diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java index 1be09182c..7b78c14fc 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java @@ -130,14 +130,6 @@ public class SameDiffConv extends SameDiffLayer { SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); - SDVariable[] vars; - if(hasBias){ - SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); - vars = new SDVariable[]{layerInput, w, b}; - } else { - vars = new SDVariable[]{layerInput, w}; - } - Conv2DConfig c = Conv2DConfig.builder() .kH(kernel[0]).kW(kernel[1]) .pH(padding[0]).pW(padding[1]) @@ -146,7 +138,13 @@ public class SameDiffConv extends SameDiffLayer { .isSameMode(this.cm == ConvolutionMode.Same) .build(); - SDVariable conv = sameDiff.cnn().conv2d(vars, c); //TODO can't set name + SDVariable conv = null; + if(hasBias){ + SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); + conv = sameDiff.cnn().conv2d(layerInput, w, b, c); + } else { + conv = sameDiff.cnn().conv2d(layerInput, w, c); + } return activation.asSameDiff("out", sameDiff, conv); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java index 4b4a69159..98bda1f3a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java @@ -31,6 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.util.ArrayUtil; import java.util.Map; @@ -99,15 +100,15 @@ public class CapsuleLayer extends SameDiffLayer { } @Override - public SDVariable defineLayer(SameDiff SD, SDVariable input, Map paramTable, SDVariable mask) { + public SDVariable defineLayer(SameDiff sd, SDVariable input, Map paramTable, SDVariable mask) { // input: [mb, inputCapsules, inputCapsuleDimensions] // [mb, inputCapsules, 1, inputCapsuleDimensions, 1] - SDVariable expanded = SD.expandDims(SD.expandDims(input, 2), 4); + SDVariable expanded = sd.expandDims(sd.expandDims(input, 2), 4); // [mb, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1] - SDVariable tiled = SD.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1); + SDVariable tiled = sd.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1); // [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions] SDVariable weights = paramTable.get(WEIGHT_PARAM); @@ -119,13 +120,13 @@ public class CapsuleLayer extends SameDiffLayer { // b is the logits of the routing procedure // [mb, inputCapsules, capsules, 1, 1] - SDVariable b = SD.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1)); + SDVariable b = sd.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1)); for(int i = 0 ; i < routings ; i++){ // c is the coupling coefficient, i.e. the edge weight between the 2 capsules // [mb, inputCapsules, capsules, 1, 1] - SDVariable c = CapsuleUtils.softmax(SD, b, 2, 5); + SDVariable c = sd.nn.softmax(b, 2); // [mb, 1, capsules, capsuleDimensions, 1] SDVariable s = c.times(uHat).sum(true, 1); @@ -135,14 +136,14 @@ public class CapsuleLayer extends SameDiffLayer { // v is the per capsule activations. On the last routing iteration, this is output // [mb, 1, capsules, capsuleDimensions, 1] - SDVariable v = CapsuleUtils.squash(SD, s, 3); + SDVariable v = CapsuleUtils.squash(sd, s, 3); if(i == routings - 1){ - return SD.squeeze(SD.squeeze(v, 1), 3); + return sd.squeeze(sd.squeeze(v, 1), 3); } // [mb, inputCapsules, capsules, capsuleDimensions, 1] - SDVariable vTiled = SD.tile(v, 1, (int) inputCapsules, 1, 1, 1); + SDVariable vTiled = sd.tile(v, 1, (int) inputCapsules, 1, 1, 1); // [mb, inputCapsules, capsules, 1, 1] b = b.plus(uHat.times(vTiled).sum(true, 3)); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index fc805f0ca..60ecbf057 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -178,9 +178,11 @@ public class LocallyConnected1D extends SameDiffLayer { //Note: for same mode, bottom/right padding can be 1 more than top/left padding //NCW format. if(cm == ConvolutionMode.Same) { - layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, paddingR}}, 0); + layerInput = sameDiff.nn().pad(layerInput, + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), 0); } else { - layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, padding}}, 0); + layerInput = sameDiff.nn().pad(layerInput, + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), 0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index ef07c9dc5..5044017a0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -184,9 +184,11 @@ public class LocallyConnected2D extends SameDiffLayer { //Note: for same mode, bottom/right padding can be 1 more than top/left padding //NCHW format if(cm == ConvolutionMode.Same){ - layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}}, 0); + layerInput = sameDiff.nn().pad(layerInput, + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), 0.0); } else { - layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}}, 0); + layerInput = sameDiff.nn().pad(layerInput, + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), 0.0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/CapsuleUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/CapsuleUtils.java index ff605d028..66d732907 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/CapsuleUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/CapsuleUtils.java @@ -45,15 +45,4 @@ public class CapsuleUtils { return x.times(squaredNorm).div(squaredNorm.plus(1.0).times(scale)); } - /** - * Compute softmax along a given dimension - */ - public static SDVariable softmax(SameDiff SD, SDVariable x, int dimension, int rank){ - int[] permutation = ArrayUtil.range(0, rank); - permutation[0] = dimension; - permutation[dimension] = 0; - - return SD.nn.softmax(x.permute(permutation)).permute(ArrayUtil.invertPermutation(permutation)); - } - } diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java index 8b060a77c..e94ffbb40 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java @@ -495,7 +495,7 @@ public class JsonModelServerTest extends BaseDL4JTest { SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28); SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10)); SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10)); - SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b)); + SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b), -1); val server = new JsonModelServer.Builder(sd) .outputSerializer( new IntSerde()) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java index 4ba24eafa..7401874d3 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java @@ -58,7 +58,7 @@ public class TestSameDiffUI extends BaseDL4JTest { SDVariable b = sd.var("b", DataType.FLOAT, 1, 4); SDVariable z = in.mmul(w).add(b); - SDVariable a = sd.nn().tanh(z); + SDVariable a = sd.math().tanh(z); LogFileWriter lfw = new LogFileWriter(f); lfw.writeGraphStructure(sd); diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java index a3fdc0c3f..ced461089 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java @@ -20,6 +20,7 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter; import org.deeplearning4j.integration.ModelType; import org.deeplearning4j.integration.TestCase; +import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 8a9bd8edc..fcb63ea0a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -28,6 +28,7 @@ import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.enums.DataFormat; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; @@ -1489,7 +1490,7 @@ public class DifferentialFunctionFactory { } public SDVariable reciprocal(SDVariable a) { - return new Reciprocal(sameDiff(), a, false).outputVariable(); + return new Reciprocal(sameDiff(), a).outputVariable(); } @@ -1990,13 +1991,13 @@ public class DifferentialFunctionFactory { .outputVariable(); } - public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, String dataFormat) { + public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { validateDifferentialFunctionsameDiff(differentialFunction); return new DepthToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) .outputVariable(); } - public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, String dataFormat) { + public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { validateDifferentialFunctionsameDiff(differentialFunction); return new SpaceToDepth(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) .outputVariable(); @@ -2635,7 +2636,7 @@ public class DifferentialFunctionFactory { return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable(); } - public SDVariable[] maxPoolWithArgmaxs(SDVariable x, Pooling2DConfig pooling2DConfig) { + public SDVariable[] maxPoolWithArgmax(SDVariable x, Pooling2DConfig pooling2DConfig) { return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 3411e2007..ab3279fd0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -181,6 +181,11 @@ public class SameDiff extends SDBaseOps { */ public final SDBitwise bitwise = new SDBitwise(this); + /** + * Op creator object for linalg operations + */ + public final SDLinalg linalg = new SDLinalg(this); + /** * Op creator object for math operations */ @@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps { return bitwise; } + /** + * Op creator object for linalg operations + */ + public SDLinalg linalg(){ + return linalg; + } + private Map sameDiffFunctionInstances; private Table fieldVariableResolutionMapping; @@ -3448,6 +3460,12 @@ public class SameDiff extends SDBaseOps { sd.renameVariable(from, to); } } + + //Check losses: + if(lossVariables.contains(from)){ + int idx = lossVariables.indexOf(from); + lossVariables.set(idx, to); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java index a255afbc3..956444ffe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java @@ -1,217 +1,416 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; +public class SDBitwise extends SDOps { + public SDBitwise(SameDiff sameDiff) { + super(sameDiff); + } -/** - * - */ -public class SDBitwise extends SDOps { - public SDBitwise(SameDiff sameDiff) { - super(sameDiff); - } + /** + * Bitwise AND operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param x First input array (INT type) + * @param y Second input array (INT type) + * @return output Bitwise AND array (INT type) + */ + public SDVariable and(SDVariable x, SDVariable y) { + SDValidation.validateInteger("and", "x", x); + SDValidation.validateInteger("and", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); + } - /** - * See {@link #leftShift(String, SDVariable, SDVariable)} - */ - public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){ - return leftShift(null, x, y); - } + /** + * Bitwise AND operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param name name May be null. Name for the output variable + * @param x First input array (INT type) + * @param y Second input array (INT type) + * @return output Bitwise AND array (INT type) + */ + public SDVariable and(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("and", "x", x); + SDValidation.validateInteger("and", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Bitwise left shift operation. Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x Input to be bit shifted (must be an integer type) - * @param y Amount to shift elements of x array (must be an integer type) - * @return Bitwise shifted input x - */ - public SDVariable leftShift(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise left shift", x); - validateInteger("bitwise left shift", y); + /** + * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitRotl(SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitRotl", "x", x); + SDValidation.validateInteger("bitRotl", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + } - SDVariable ret = f().shift(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitRotl", "x", x); + SDValidation.validateInteger("bitRotl", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #rightShift(String, SDVariable, SDVariable)} - */ - public SDVariable rightShift(SDVariable x, SDVariable y){ - return rightShift(null, x, y); - } + /** + * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitRotr(SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitRotr", "x", x); + SDValidation.validateInteger("bitRotr", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + } - /** - * Bitwise right shift operation. Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x Input to be bit shifted (must be an integer type) - * @param y Amount to shift elements of x array (must be an integer type) - * @return Bitwise shifted input x - */ - public SDVariable rightShift(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise right shift", x); - validateInteger("bitwise right shift", y); + /** + * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitRotr", "x", x); + SDValidation.validateInteger("bitRotr", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - SDVariable ret = f().rshift(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Shift integer bits to the left, i.e. var << 4
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitShift(SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitShift", "x", x); + SDValidation.validateInteger("bitShift", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + } - /** - * See {@link #leftShiftCyclic(String, SDVariable, SDVariable)} - */ - public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){ - return leftShiftCyclic(null, x, y); - } + /** + * Shift integer bits to the left, i.e. var << 4
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitShift", "x", x); + SDValidation.validateInteger("bitShift", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Bitwise left cyclical shift operation. Supports broadcasting. - * Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around": - * {@code leftShiftCyclic(01110000, 2) -> 11000001} - * - * @param name Name of the output variable. May be null. - * @param x Input to be bit shifted (must be an integer type) - * @param y Amount to shift elements of x array (must be an integer type) - * @return Bitwise cyclic shifted input x - */ - public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise left shift (cyclic)", x); - validateInteger("bitwise left shift (cyclic)", y); + /** + * Shift integer bits to the right, i.e. var >> 4
+ * + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitShiftRight(SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitShiftRight", "x", x); + SDValidation.validateInteger("bitShiftRight", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + } - SDVariable ret = f().rotl(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Shift integer bits to the right, i.e. var >> 4
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) + * @param shift Number of bits to shift. (INT type) + * @return output SDVariable with shifted bits (INT type) + */ + public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { + SDValidation.validateInteger("bitShiftRight", "x", x); + SDValidation.validateInteger("bitShiftRight", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #rightShiftCyclic(String, SDVariable, SDVariable)} - */ - public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){ - return rightShiftCyclic(null, x, y); - } + /** + * Bitwise Hamming distance reduction over all elements of both input arrays.
+ * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * + * @param x First input array. (INT type) + * @param y Second input array. (INT type) + * @return output bitwise Hamming distance (INT type) + */ + public SDVariable bitsHammingDistance(SDVariable x, SDVariable y) { + SDValidation.validateInteger("bitsHammingDistance", "x", x); + SDValidation.validateInteger("bitsHammingDistance", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); + } - /** - * Bitwise right cyclical shift operation. Supports broadcasting. - * Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around": - * {@code rightShiftCyclic(00001110, 2) -> 10000011} - * - * @param name Name of the output variable. May be null. - * @param x Input to be bit shifted (must be an integer type) - * @param y Amount to shift elements of x array (must be an integer type) - * @return Bitwise cyclic shifted input x - */ - public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise right shift (cyclic)", x); - validateInteger("bitwise right shift (cyclic)", y); + /** + * Bitwise Hamming distance reduction over all elements of both input arrays.
+ * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * + * @param name name May be null. Name for the output variable + * @param x First input array. (INT type) + * @param y Second input array. (INT type) + * @return output bitwise Hamming distance (INT type) + */ + public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("bitsHammingDistance", "x", x); + SDValidation.validateInteger("bitsHammingDistance", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - SDVariable ret = f().rotr(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Bitwise left shift operation. Supports broadcasting.
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public SDVariable leftShift(SDVariable x, SDVariable y) { + SDValidation.validateInteger("leftShift", "x", x); + SDValidation.validateInteger("leftShift", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); + } - /** - * See {@link #bitsHammingDistance(String, SDVariable, SDVariable)} - */ - public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){ - return bitsHammingDistance(null, x, y); - } + /** + * Bitwise left shift operation. Supports broadcasting.
+ * + * @param name name May be null. Name for the output variable + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public SDVariable leftShift(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("leftShift", "x", x); + SDValidation.validateInteger("leftShift", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Bitwise Hamming distance reduction over all elements of both input arrays.
- * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1) - * - * @param name Name of the output variable. May be null. - * @param x First input array. Must be integer type. - * @param y First input array. Must be integer type, same type as x - * @return - */ - public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise hamming distance", x); - validateInteger("bitwise hamming distance", y); + /** + * Bitwise left cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
+ * {@code leftShiftCyclic(01110000, 2) -> 11000001}
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) { + SDValidation.validateInteger("leftShiftCyclic", "x", x); + SDValidation.validateInteger("leftShiftCyclic", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); + } - SDVariable ret = f().bitwiseHammingDist(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Bitwise left cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
+ * {@code leftShiftCyclic(01110000, 2) -> 11000001}
+ * + * @param name name May be null. Name for the output variable + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("leftShiftCyclic", "x", x); + SDValidation.validateInteger("leftShiftCyclic", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #and(String, SDVariable, SDVariable)} - */ - public SDVariable and(SDVariable x, SDVariable y){ - return and(null, x, y); - } + /** + * Bitwise OR operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise OR array (INT type) + */ + public SDVariable or(SDVariable x, SDVariable y) { + SDValidation.validateInteger("or", "x", x); + SDValidation.validateInteger("or", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); + } - /** - * Bitwise AND operation. Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x First input array. Must be integer type. - * @param y First input array. Must be integer type, same type as x - * @return Bitwise AND array - */ - public SDVariable and(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise AND", x); - validateInteger("bitwise AND", y); + /** + * Bitwise OR operation. Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param name name May be null. Name for the output variable + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise OR array (INT type) + */ + public SDVariable or(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("or", "x", x); + SDValidation.validateInteger("or", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - SDVariable ret = f().bitwiseAnd(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Bitwise right shift operation. Supports broadcasting.
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public SDVariable rightShift(SDVariable x, SDVariable y) { + SDValidation.validateInteger("rightShift", "x", x); + SDValidation.validateInteger("rightShift", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); + } - /** - * See {@link #or(String, SDVariable, SDVariable)} - */ - public SDVariable or(SDVariable x, SDVariable y){ - return or(null, x, y); - } + /** + * Bitwise right shift operation. Supports broadcasting.
+ * + * @param name name May be null. Name for the output variable + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise shifted input x (INT type) + */ + public SDVariable rightShift(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("rightShift", "x", x); + SDValidation.validateInteger("rightShift", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Bitwise OR operation. Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x First input array. Must be integer type. - * @param y First input array. Must be integer type, same type as x - * @return Bitwise OR array - */ - public SDVariable or(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise OR", x); - validateInteger("bitwise OR", y); + /** + * Bitwise right cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":
+ * {@code rightShiftCyclic(00001110, 2) -> 10000011}
+ * + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) { + SDValidation.validateInteger("rightShiftCyclic", "x", x); + SDValidation.validateInteger("rightShiftCyclic", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); + } - SDVariable ret = f().bitwiseOr(x, y); - return updateVariableNameAndReference(ret, name); - } + /** + * Bitwise right cyclical shift operation. Supports broadcasting.
+ * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":
+ * {@code rightShiftCyclic(00001110, 2) -> 10000011}
+ * + * @param name name May be null. Name for the output variable + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) + * @return output Bitwise cyclic shifted input x (INT type) + */ + public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("rightShiftCyclic", "x", x); + SDValidation.validateInteger("rightShiftCyclic", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #xor(String, SDVariable, SDVariable)} - */ - public SDVariable xor(SDVariable x, SDVariable y){ - return xor(null, x, y); - } + /** + * Bitwise XOR operation (exclusive OR). Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise XOR array (INT type) + */ + public SDVariable xor(SDVariable x, SDVariable y) { + SDValidation.validateInteger("xor", "x", x); + SDValidation.validateInteger("xor", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); + } - /** - * Bitwise XOR operation (exclusive OR). Supports broadcasting. - * - * @param name Name of the output variable. May be null. - * @param x First input array. Must be integer type. - * @param y First input array. Must be integer type, same type as x - * @return Bitwise XOR array - */ - public SDVariable xor(String name, SDVariable x, SDVariable y){ - validateInteger("bitwise XOR", x); - validateInteger("bitwise XOR", y); - - SDVariable ret = f().bitwiseXor(x, y); - return updateVariableNameAndReference(ret, name); - } - - /** - * Flip bits - * - * @param name Name of the output variable - * @param x input array - * @return array after flipping each input bit - */ - public SDVariable toggleBits(String name, SDVariable x) { - SDVariable res = f().toggleBits(x); - return updateVariableNameAndReference(res, name); - } + /** + * Bitwise XOR operation (exclusive OR). Supports broadcasting.
+ * + * Inputs must satisfy the following constraints:
+ * Must be same types: isSameType(x, y)
+ * Must have broadcastable shapes: isBroadcastableShapes(x, y)
+ * + * @param name name May be null. Name for the output variable + * @param x First input array (INT type) + * @param y First input array (INT type) + * @return output Bitwise XOR array (INT type) + */ + public SDVariable xor(String name, SDVariable x, SDVariable y) { + SDValidation.validateInteger("xor", "x", x); + SDValidation.validateInteger("xor", "y", y); + Preconditions.checkArgument(isSameType(x, y), "Must be same types"); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java index 7b56ca266..d367e3d4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,777 +14,1015 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; +import org.nd4j.enums.DataFormat; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateFloatingPoint; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateNumerical; - -/** - * SameDiff Convolutional Neural Network operations - CNN1d, 2d and 3d ops - as well as related functions.
- * Accessible via {@link SameDiff#cnn()}
- * See also {@link SDNN} (accessible via {@link SameDiff#nn()} for general neural network ops.
- * See also {@link SDRNN} (accessible via {@link SameDiff#rnn()} for recurrent neural network ops.
- * - * @author Alex Black - */ public class SDCNN extends SDOps { - - public SDCNN(SameDiff sameDiff) { - super(sameDiff); - } - - /** - * See {@link #avgPooling2d(String, SDVariable, Pooling2DConfig)}. - */ - public SDVariable avgPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { - return avgPooling2d(null, input, pooling2DConfig); - } - - /** - * 2D Convolution layer operation - average pooling 2d - * - * @param name name of the operation in SameDiff - * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param pooling2DConfig the configuration - * @return Result after applying average pooling on the input - */ - public SDVariable avgPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { - validateFloatingPoint("avgPooling2d", input); - SDVariable ret = f().avgPooling2d(input, pooling2DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #avgPooling3d(String, SDVariable, Pooling3DConfig)}. - */ - public SDVariable avgPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { - return avgPooling3d(null, input, pooling3DConfig); - } - - /** - * 3D convolution layer operation - average pooling 3d - * - * @param name name of the operation in SameDiff - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param pooling3DConfig the configuration - * @return Result after applying average pooling on the input - */ - public SDVariable avgPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { - validateFloatingPoint("avgPooling3d", input); - SDVariable ret = f().avgPooling3d(input, pooling3DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #batchToSpace(String, SDVariable, int[], int[][]) - */ - public SDVariable batchToSpace(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) { - return batchToSpace(null, x, blocks, crops); - } - - /** - * Convolution 2d layer batch to space operation on 4d input. - * Reduces input batch dimension by rearranging data into a larger spatial dimensions - * - * @param name Output variable name - * @param x Input variable. 4d input - * @param blocks Block size, in the height/width dimension - * @param crops Optional 2d int[] array: values [[crop top, crop bottom], [crop left, crop right]] - * @return Output variable - * @see #spaceToBatch(String, SDVariable, int[], int[][]) - */ - public SDVariable batchToSpace(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] crops) { - validateNumerical("batchToSpace", x); - SDVariable ret = f().batchToSpace(x, blocks, crops); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #col2Im(String, SDVariable, Conv2DConfig)}. - */ - public SDVariable col2Im(@NonNull SDVariable in, @NonNull Conv2DConfig config) { - return col2Im(null, in, config); - } - - /** - * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape - * [minibatch, inputChannels, height, width] - * - * @param name Name of the output variable - * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] - * @param config Convolution configuration for the col2im operation - * @return Col2Im output variable - */ - public SDVariable col2Im(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) { - SDVariable ret = f().col2Im(in, config); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias. - */ - public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) { - return conv1d((String) null, input, weights, conv1DConfig); - } - - /** - * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}, no bias. - */ - public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv1DConfig conv1DConfig) { - validateFloatingPoint("conv1d", input); - validateFloatingPoint("conv1d", weights); - SDVariable ret = f().conv1d(input, weights, conv1DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #conv1d(String, SDVariable, SDVariable, SDVariable, Conv1DConfig)}. - */ - public SDVariable conv1d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { - return conv1d(null, input, weights, bias, conv1DConfig); - } - - /** - * Conv1d operation. - * - * @param name name of the operation in SameDiff - * @param input the inputs to conv1d - * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] - * @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. - * @param conv1DConfig the configuration - * @return - */ - public SDVariable conv1d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { - validateFloatingPoint("conv1d", input); - validateFloatingPoint("conv1d", weights); - validateFloatingPoint("conv1d", bias); - SDVariable ret = f().conv1d(input, weights, bias, conv1DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) { - return conv2d(layerInput, weights, null, config); - } - - /** - * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull Conv2DConfig config) { - return conv2d(name, layerInput, weights, null, config); - } - - /** - * See {@link #conv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. - */ - public SDVariable conv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) { - return conv2d(null, layerInput, weights, bias, config); - } - - /** - * 2D Convolution operation with optional bias - * - * @param name name of the operation in SameDiff - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] - * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param config Conv2DConfig configuration - * @return result of conv2d op - */ - public SDVariable conv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig config) { - validateFloatingPoint("conv2d", "input", layerInput); - validateFloatingPoint("conv2d", "weights", weights); - validateFloatingPoint("conv2d", "bias", bias); - SDVariable[] arr = new SDVariable[bias == null ? 2 : 3]; - arr[0] = layerInput; - arr[1] = weights; - if (bias != null) - arr[2] = bias; - return conv2d(name, arr, config); - } - - /** - * See {@link #conv2d(String, SDVariable[], Conv2DConfig)}. - */ - public SDVariable conv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) { - return conv2d(null, inputs, config); - } - - /** - * 2D Convolution operation with optional bias - * - * @param name Name of the output SDVariable - * @param inputs an array with either 2 elements (layerInput, weights) or 3 elements (layerInput, weights, bias) as - * described in {@link #conv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param config Conv2DConfig configuration - * @return result of convolution 2d operation - */ - public SDVariable conv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig config) { - for(SDVariable v : inputs) - validateNumerical("conv2d", v); - SDVariable ret = f().conv2d(inputs, config); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias. - */ - public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) { - return conv3d(null, input, weights, null, conv3DConfig); - } - - /** - * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}, no bias. - */ - public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull Conv3DConfig conv3DConfig) { - return conv3d(name, input, weights, null, conv3DConfig); - } - - /** - * See {@link #conv3d(String, SDVariable, SDVariable, SDVariable, Conv3DConfig)}. - */ - public SDVariable conv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) { - return conv3d(null, input, weights, bias, conv3DConfig); - } - - /** - * Convolution 3D operation with optional bias - * - * @param name Name of the output variable - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. - * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param conv3DConfig the configuration - * @return Conv3d output variable - */ - public SDVariable conv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv3DConfig conv3DConfig) { - validateFloatingPoint("conv3d", "input", input); - validateFloatingPoint("conv3d", "weights", weights); - validateFloatingPoint("conv3d", "bias", bias); - SDVariable[] args; - if (bias == null) { - args = new SDVariable[]{input, weights}; - } else { - args = new SDVariable[]{input, weights, bias}; - } - SDVariable ret = f().conv3d(args, conv3DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias. - */ - public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) { - return deconv2d(layerInput, weights, null, deconv2DConfig); - } - - /** - * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}, no bias. - */ - public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, @NonNull DeConv2DConfig deconv2DConfig) { - return deconv2d(name, layerInput, weights, null, deconv2DConfig); - } - - /** - * See {@link #deconv2d(String, SDVariable, SDVariable, SDVariable, DeConv2DConfig)}. - */ - public SDVariable deconv2d(@NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) { - return deconv2d(null, layerInput, weights, bias, deconv2DConfig); - } - - /** - * 2D deconvolution operation with optional bias - * - * @param name name of the operation in SameDiff - * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth]. - * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param deconv2DConfig DeConv2DConfig configuration - * @return result of deconv2d op - */ - public SDVariable deconv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv2DConfig deconv2DConfig) { - validateFloatingPoint("deconv2d", "input", layerInput); - validateFloatingPoint("deconv2d", "weights", weights); - validateFloatingPoint("deconv2d", "bias", bias); - SDVariable[] arr = new SDVariable[bias == null ? 2 : 3]; - arr[0] = layerInput; - arr[1] = weights; - if (bias != null) - arr[2] = bias; - return deconv2d(name, arr, deconv2DConfig); - } - - /** - * See {@link #deconv2d(String, SDVariable[], DeConv2DConfig)}. - */ - public SDVariable deconv2d(@NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) { - return deconv2d(null, inputs, deconv2DConfig); - } - - /** - * 2D deconvolution operation with or without optional bias - * - * @param name Name of the output variable - * @param inputs Inputs to the deconvolution 2d operation - input array of length 2 (layerInput, weights) - * or length 3 (layerInput, weights, bias) as described in {@link #deconv2d(SDVariable[], DeConv2DConfig)} - * @param deconv2DConfig the configuration - * @return result of deconv2d op - */ - public SDVariable deconv2d(String name, @NonNull SDVariable[] inputs, @NonNull DeConv2DConfig deconv2DConfig) { - for(SDVariable v : inputs) - validateNumerical("deconv2d", v); - SDVariable ret = f().deconv2d(inputs, deconv2DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias. - */ - public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { - return deconv3d(input, weights, null, config); - } - - /** - * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}, no bias. - */ - public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { - return deconv3d(name, input, weights, null, config); - } - - /** - * See {@link #deconv3d(String, SDVariable, SDVariable, SDVariable, DeConv3DConfig)}. - */ - public SDVariable deconv3d(@NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { - return deconv3d(null, input, weights, bias, config); - } - - /** - * 3D CNN deconvolution operation with or without optional bias - * - * @param name Name of the output variable - * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - * @param weights Weights array - shape [kD, kH, kW, oC, iC] - * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] - * @param config Configuration - */ - public SDVariable deconv3d(String name, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { - validateFloatingPoint("conv3d", input); - validateFloatingPoint("conv3d", weights); - validateFloatingPoint("conv3d", bias); - SDVariable ret = f().deconv3d(input, weights, bias, config); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #depthToSpace(String, SDVariable, int, String)}. - */ - public SDVariable depthToSpace(@NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) { - return depthToSpace(null, x, blockSize, dataFormat); - } - - /** - * Convolution 2d layer batch to space operation on 4d input.
- * Reduces input channels dimension by rearranging data into a larger spatial dimensions
- * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2] - * = [mb, 2, 4, 4] - * - * @param name Output variable name - * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param blockSize Block size, in the height/width dimension - * @param dataFormat Data format: "NCHW" or "NHWC" - * @return Output variable - * @see #depthToSpace(String, SDVariable, int, String) - */ - public SDVariable depthToSpace(String name, @NonNull SDVariable x, @NonNull int blockSize, @NonNull String dataFormat) { - SDVariable ret = f().depthToSpace(x, blockSize, dataFormat); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) { - return depthWiseConv2d(layerInput, depthWeights, null, config); - } - - /** - * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, @NonNull Conv2DConfig config) { - return depthWiseConv2d(name, layerInput, depthWeights, null, config); - } - - /** - * See {@link #depthWiseConv2d(String, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. - */ - public SDVariable depthWiseConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) { - return depthWiseConv2d(null, layerInput, depthWeights, bias, config); - } - - /** - * Depth-wise 2D convolution operation with optional bias - * - * @param name name of the operation in SameDiff - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] - * @param bias Optional 1D bias array with shape [outputChannels]. May be null. - * @param config Conv2DConfig configuration - * @return result of depthwise conv2d op - */ - public SDVariable depthWiseConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable bias, @NonNull Conv2DConfig config) { - validateFloatingPoint("depthwiseConv2d", "input", layerInput); - validateFloatingPoint("depthwiseConv2d", "depth weights", depthWeights); - validateFloatingPoint("depthwiseConv2d", "bias", bias); - SDVariable[] arr = new SDVariable[bias == null ? 2 : 3]; - arr[0] = layerInput; - arr[1] = depthWeights; - if (bias != null) - arr[2] = bias; - return depthWiseConv2d(name, arr, config); - } - - /** - * See {@link #depthWiseConv2d(String, SDVariable[], Conv2DConfig)}. - */ - public SDVariable depthWiseConv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) { - return depthWiseConv2d(null, inputs, depthConv2DConfig); - } - - /** - * Depth-wise convolution 2D operation. - * - * @param name name of the output variable - * @param inputs the inputs to depth-wise conv2d. An array with either 2 elements (layerInput, depthWeights) - * or 3 elements (layerInput, depthWeights, bias) as described in - * {@link #depthWiseConv2d(SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param depthConv2DConfig the configuration - * @return result of depthwise conv2d op - */ - public SDVariable depthWiseConv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig depthConv2DConfig) { - for(SDVariable v : inputs) - validateFloatingPoint("depthWiseConv2d", v); - SDVariable ret = f().depthWiseConv2d(inputs, depthConv2DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #dilation2D(String, SDVariable, SDVariable, int[], int[], boolean)}. - */ - public SDVariable dilation2D(@NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides, - @NonNull int[] rates, @NonNull boolean isSameMode) { - return dilation2D(null, df, weights, strides, rates, isSameMode); - } - - /** - * TODO doc string - * - * @param name - * @param df - * @param weights - * @param strides - * @param rates - * @param isSameMode - * @return - */ - public SDVariable dilation2D(String name, @NonNull SDVariable df, @NonNull SDVariable weights, @NonNull int[] strides, - @NonNull int[] rates, @NonNull boolean isSameMode) { - SDVariable ret = f().dilation2D(df, weights, strides, rates, isSameMode); - return updateVariableNameAndReference(ret, name); - } - - - /** - * Extract image patches - * - * @param name Name of the output variable - * @param input Input array. Must be rank 4, with shape [minibatch, height, width, channels] - * @param kH Kernel height - * @param kW Kernel width - * @param sH Stride height - * @param sW Stride width - * @param rH Rate height - * @param rW Rate width - * @param sameMode If true: use same mode padding. If false - * @return - */ - public SDVariable extractImagePatches(String name, @NonNull SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { - SDVariable ret = f().extractImagePatches(input, kH, kW, sH, sW, rH, rW, sameMode); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #im2Col(String, SDVariable, Conv2DConfig)}. - */ - public SDVariable im2Col(@NonNull SDVariable in, @NonNull Conv2DConfig config) { - return im2Col(null, in, config); - } - - /** - * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape - * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] - * - * @param name Name of the output variable - * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] - * @param config Convolution configuration for the im2col operation - * @return Im2Col output variable - */ - public SDVariable im2Col(String name, @NonNull SDVariable in, @NonNull Conv2DConfig config) { - SDVariable ret = f().im2Col(in, config); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #localResponseNormalization(String, SDVariable, LocalResponseNormalizationConfig)}. - */ - public SDVariable localResponseNormalization(@NonNull SDVariable inputs, @NonNull LocalResponseNormalizationConfig lrnConfig) { - return localResponseNormalization(null, inputs, lrnConfig); - } - - /** - * 2D convolution layer operation - local response normalization - * - * @param name name of the operation in SameDiff - * @param input the inputs to lrn - * @param lrnConfig the configuration - * @return - */ - public SDVariable localResponseNormalization(String name, @NonNull SDVariable input, - @NonNull LocalResponseNormalizationConfig lrnConfig) { - validateFloatingPoint("local response normalization", input); - SDVariable ret = f().localResponseNormalization(input, lrnConfig); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #maxPooling2d(String, SDVariable, Pooling2DConfig)}. - */ - public SDVariable maxPooling2d(@NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { - return maxPooling2d(null, input, pooling2DConfig); - } - - /** - * 2D Convolution layer operation - max pooling 2d - * - * @param name name of the operation in SameDiff - * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param pooling2DConfig the configuration - * @return Result after applying max pooling on the input - */ - public SDVariable maxPooling2d(String name, @NonNull SDVariable input, @NonNull Pooling2DConfig pooling2DConfig) { - validateNumerical("maxPooling2d", input); - SDVariable ret = f().maxPooling2d(input, pooling2DConfig); - return updateVariableNameAndReference(ret, name); - } - - /** - * See {@link #maxPooling3d(String, SDVariable, Pooling3DConfig)}. - */ - public SDVariable maxPooling3d(@NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { - return maxPooling3d(null, input, pooling3DConfig); - } - - /** - * 3D convolution layer operation - max pooling 3d operation. - * - * @param name name of the operation in SameDiff - * @param input the input to average pooling 3d operation - 5d activations in NCDHW format - * (shape [minibatch, channels, depth, height, width]) or NDHWC format - * (shape [minibatch, depth, height, width, channels]) - * @param pooling3DConfig the configuration - * @return Result after applying max pooling on the input - */ - public SDVariable maxPooling3d(String name, @NonNull SDVariable input, @NonNull Pooling3DConfig pooling3DConfig) { - validateNumerical("maxPooling3d", input); - SDVariable ret = f().maxPooling3d(input, pooling3DConfig); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable separableConv2d(SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, - @NonNull Conv2DConfig config) { - return separableConv2d(layerInput, depthWeights, pointWeights, null, config); - } - - - /** - * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}, no bias. - */ - public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, - @NonNull Conv2DConfig config) { - return separableConv2d(layerInput, depthWeights, pointWeights, null, config); - } - - /** - * See {@link #separableConv2d(String, SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)}. - */ - public SDVariable separableConv2d(@NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, - SDVariable bias, @NonNull Conv2DConfig config) { - return separableConv2d(null, layerInput, depthWeights, pointWeights, bias, config); - } - - /** - * Separable 2D convolution operation with optional bias - * - * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] - * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels] - * May be null - * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. - * @param config Conv2DConfig configuration - * @return result of separable convolution 2d operation - */ - public SDVariable separableConv2d(String name, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, SDVariable pointWeights, - SDVariable bias, @NonNull Conv2DConfig config) { - validateFloatingPoint("separableConv2d", "input", layerInput); - validateFloatingPoint("separableConv2d", "depthWeights", depthWeights); - validateFloatingPoint("separableConv2d", "pointWeights", pointWeights); - validateFloatingPoint("separableConv2d", "bias", bias); - SDVariable[] arr = new SDVariable[bias == null ? 3 : 4]; - arr[0] = layerInput; - arr[1] = depthWeights; - arr[2] = pointWeights; - if (bias != null) - arr[3] = bias; - return sconv2d(name, arr, config); - } - - /** - * See {@link #sconv2d(String, SDVariable[], Conv2DConfig)}. - */ - public SDVariable sconv2d(@NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) { - return sconv2d(null, inputs, conv2DConfig); - } - - /** - * Separable 2D convolution operation with/without optional bias - * - * @param name name of the output variable - * @param inputs the inputs to separable conv2 operation. Should be length 3 (layerInput, depthWeights, pointWeights) - * or length 4 (layerInput, depthWeights, pointWeights, bias) as described in {@link #separableConv2d(SDVariable, SDVariable, SDVariable, SDVariable, Conv2DConfig)} - * @param conv2DConfig the configuration - * @return result of separable convolution 2d operation - */ - public SDVariable sconv2d(String name, @NonNull SDVariable[] inputs, @NonNull Conv2DConfig conv2DConfig) { - for(SDVariable v : inputs) - validateFloatingPoint("sconv2d", v); - SDVariable ret = f().sconv2d(inputs, conv2DConfig); - return updateVariableNameAndReference(ret, name); - } - - - /** - * @see #spaceToBatch(String, SDVariable, int[], int[][]) - */ - public SDVariable spaceToBatch(@NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) { - return spaceToBatch(null, x, blocks, padding); - } - - /** - * Convolution 2d layer space to batch operation on 4d input. - * Increases input batch dimension by rearranging data from spatial dimensions into batch dimension - * - * @param name Output variable name - * @param x Input variable. 4d input - * @param blocks Block size, in the height/width dimension - * @param padding Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] - * @return Output variable - * @see #batchToSpace(String, SDVariable, int[], int[][]) - */ - public SDVariable spaceToBatch(String name, @NonNull SDVariable x, @NonNull int[] blocks, @NonNull int[][] padding) { - SDVariable ret = f().spaceToBatch(x, blocks, padding); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #spaceToDepth(String, SDVariable, int, String) - */ - public SDVariable spaceToDepth(@NonNull SDVariable x, int blockSize, @NonNull String dataFormat) { - return spaceToDepth(null, x, blockSize, dataFormat); - } - - /** - * Convolution 2d layer space to depth operation on 4d input.
- * Increases input channels (reduced spatial dimensions) by rearranging data into a larger channels dimension
- * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2] - * = [mb, 2, 4, 4] - * - * @param name Output variable name - * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format - * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) - * @param blockSize Block size, in the height/width dimension - * @param dataFormat Data format: "NCHW" or "NHWC" - * @return Output variable - * @see #depthToSpace(String, SDVariable, int, String) - */ - public SDVariable spaceToDepth(String name, @NonNull SDVariable x, int blockSize, @NonNull String dataFormat) { - SDVariable ret = f().spaceToDepth(x, blockSize, dataFormat); - return updateVariableNameAndReference(ret, name); - } - - - /** - * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}, - * scale is used for both height and width dimensions. - * - * @param scale The scale for both height and width dimensions. - */ - public SDVariable upsampling2d(@NonNull SDVariable input, int scale) { - return upsampling2d(null, input, true, scale, scale); - } - - /** - * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}, - * scale is used for both height and width dimensions. - * - * @param scale The scale for both height and width dimensions. - */ - public SDVariable upsampling2d(String name, @NonNull SDVariable input, int scale) { - return upsampling2d(name, input, true, scale, scale); - } - - /** - * See {@link #upsampling2d(String, SDVariable, boolean, int, int)}. - */ - public SDVariable upsampling2d(@NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) { - return upsampling2d(null, input, nchw, scaleH, scaleW); - } - - /** - * 2D Convolution layer operation - Upsampling 2d - * - * @param input Input, in NCHW format - * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format - * @param scaleH Scale to upsample in height dimension - * @param scaleW Scale to upsample in width dimension - * @return Upsampled input - */ - public SDVariable upsampling2d(String name, @NonNull SDVariable input, boolean nchw, int scaleH, int scaleW) { - SDVariable ret = f().upsampling2d(input, nchw, scaleH, scaleW); - return updateVariableNameAndReference(ret, name); - } + public SDCNN(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * 2D Convolution layer operation - average pooling 2d
+ * + * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying average pooling on the input (NUMERIC type) + */ + public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("avgPooling2d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D(sd,input, Pooling2DConfig).outputVariable(); + } + + /** + * 2D Convolution layer operation - average pooling 2d
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying average pooling on the input (NUMERIC type) + */ + public SDVariable avgPooling2d(String name, SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("avgPooling2d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D(sd,input, Pooling2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 3D convolution layer operation - average pooling 3d
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output after applying average pooling on the input (NUMERIC type) + */ + public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig Pooling3DConfig) { + SDValidation.validateNumerical("avgPooling3d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D(sd,input, Pooling3DConfig).outputVariable(); + } + + /** + * 3D convolution layer operation - average pooling 3d
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output after applying average pooling on the input (NUMERIC type) + */ + public SDVariable avgPooling3d(String name, SDVariable input, Pooling3DConfig Pooling3DConfig) { + SDValidation.validateNumerical("avgPooling3d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D(sd,input, Pooling3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces input batch dimension by rearranging data into a larger spatial dimensions
+ * + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param croppingTop (Size: Exactly(count=2)) + * @param croppingBottom (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable batchToSpace(SDVariable x, int[] blocks, int[] croppingTop, + int... croppingBottom) { + SDValidation.validateNumerical("batchToSpace", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(croppingTop.length == 2, "croppingTop has incorrect size/length. Expected: croppingTop.length == 2, got %s", croppingTop.length); + Preconditions.checkArgument(croppingBottom.length == 2, "croppingBottom has incorrect size/length. Expected: croppingBottom.length == 2, got %s", croppingBottom.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace(sd,x, blocks, croppingTop, croppingBottom).outputVariable(); + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces input batch dimension by rearranging data into a larger spatial dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param croppingTop (Size: Exactly(count=2)) + * @param croppingBottom (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable batchToSpace(String name, SDVariable x, int[] blocks, int[] croppingTop, + int... croppingBottom) { + SDValidation.validateNumerical("batchToSpace", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(croppingTop.length == 2, "croppingTop has incorrect size/length. Expected: croppingTop.length == 2, got %s", croppingTop.length); + Preconditions.checkArgument(croppingBottom.length == 2, "croppingBottom has incorrect size/length. Expected: croppingBottom.length == 2, got %s", croppingBottom.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace(sd,x, blocks, croppingTop, croppingBottom).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
+ * [minibatch, inputChannels, height, width]
+ * + * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Col2Im output variable (NUMERIC type) + */ + public SDVariable col2Im(SDVariable in, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("col2Im", "in", in); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im(sd,in, Conv2DConfig).outputVariable(); + } + + /** + * col2im operation for use in 2D convolution operations. Outputs a 4d array with shape
+ * [minibatch, inputChannels, height, width]
+ * + * @param name name May be null. Name for the output variable + * @param in Input - rank 6 input with shape [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Col2Im output variable (NUMERIC type) + */ + public SDVariable col2Im(String name, SDVariable in, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("col2Im", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im(sd,in, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Conv1d operation.
+ * + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, + Conv1DConfig Conv1DConfig) { + SDValidation.validateNumerical("conv1d", "input", input); + SDValidation.validateNumerical("conv1d", "weights", weights); + SDValidation.validateNumerical("conv1d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(sd,input, weights, bias, Conv1DConfig).outputVariable(); + } + + /** + * Conv1d operation.
+ * + * @param name name May be null. Name for the output variable + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param bias bias for conv1d op - rank 1 array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public SDVariable conv1d(String name, SDVariable input, SDVariable weights, SDVariable bias, + Conv1DConfig Conv1DConfig) { + SDValidation.validateNumerical("conv1d", "input", input); + SDValidation.validateNumerical("conv1d", "weights", weights); + SDValidation.validateNumerical("conv1d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(sd,input, weights, bias, Conv1DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Conv1d operation.
+ * + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig Conv1DConfig) { + SDValidation.validateNumerical("conv1d", "input", input); + SDValidation.validateNumerical("conv1d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(sd,input, weights, null, Conv1DConfig).outputVariable(); + } + + /** + * Conv1d operation.
+ * + * @param name name May be null. Name for the output variable + * @param input the inputs to conv1d (NUMERIC type) + * @param weights weights for conv1d op - rank 3 array with shape [kernelSize, inputChannels, outputChannels] (NUMERIC type) + * @param Conv1DConfig Configuration Object + * @return output result of conv1d op (NUMERIC type) + */ + public SDVariable conv1d(String name, SDVariable input, SDVariable weights, + Conv1DConfig Conv1DConfig) { + SDValidation.validateNumerical("conv1d", "input", input); + SDValidation.validateNumerical("conv1d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D(sd,input, weights, null, Conv1DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D Convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public SDVariable conv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("conv2d", "layerInput", layerInput); + SDValidation.validateNumerical("conv2d", "weights", weights); + SDValidation.validateNumerical("conv2d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(sd,layerInput, weights, bias, Conv2DConfig).outputVariable(); + } + + /** + * 2D Convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public SDVariable conv2d(String name, SDVariable layerInput, SDVariable weights, SDVariable bias, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("conv2d", "layerInput", layerInput); + SDValidation.validateNumerical("conv2d", "weights", weights); + SDValidation.validateNumerical("conv2d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(sd,layerInput, weights, bias, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D Convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public SDVariable conv2d(SDVariable layerInput, SDVariable weights, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("conv2d", "layerInput", layerInput); + SDValidation.validateNumerical("conv2d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(sd,layerInput, weights, null, Conv2DConfig).outputVariable(); + } + + /** + * 2D Convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param weights Weights for the convolution operation. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, outputChannels] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of conv2d op (NUMERIC type) + */ + public SDVariable conv2d(String name, SDVariable layerInput, SDVariable weights, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("conv2d", "layerInput", layerInput); + SDValidation.validateNumerical("conv2d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D(sd,layerInput, weights, null, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 3D operation with optional bias
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public SDVariable conv3d(SDVariable input, SDVariable weights, SDVariable bias, + Conv3DConfig Conv3DConfig) { + SDValidation.validateNumerical("conv3d", "input", input); + SDValidation.validateNumerical("conv3d", "weights", weights); + SDValidation.validateNumerical("conv3d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(sd,input, weights, bias, Conv3DConfig).outputVariable(); + } + + /** + * Convolution 3D operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public SDVariable conv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, + Conv3DConfig Conv3DConfig) { + SDValidation.validateNumerical("conv3d", "input", input); + SDValidation.validateNumerical("conv3d", "weights", weights); + SDValidation.validateNumerical("conv3d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(sd,input, weights, bias, Conv3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 3D operation with optional bias
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public SDVariable conv3d(SDVariable input, SDVariable weights, Conv3DConfig Conv3DConfig) { + SDValidation.validateNumerical("conv3d", "input", input); + SDValidation.validateNumerical("conv3d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(sd,input, weights, null, Conv3DConfig).outputVariable(); + } + + /** + * Convolution 3D operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type) + * @param Conv3DConfig Configuration Object + * @return output Conv3d output variable (NUMERIC type) + */ + public SDVariable conv3d(String name, SDVariable input, SDVariable weights, + Conv3DConfig Conv3DConfig) { + SDValidation.validateNumerical("conv3d", "input", input); + SDValidation.validateNumerical("conv3d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D(sd,input, weights, null, Conv3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D deconvolution operation with optional bias
+ * + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, SDVariable bias, + DeConv2DConfig DeConv2DConfig) { + SDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + SDValidation.validateNumerical("deconv2d", "weights", weights); + SDValidation.validateNumerical("deconv2d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(sd,layerInput, weights, bias, DeConv2DConfig).outputVariable(); + } + + /** + * 2D deconvolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public SDVariable deconv2d(String name, SDVariable layerInput, SDVariable weights, + SDVariable bias, DeConv2DConfig DeConv2DConfig) { + SDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + SDValidation.validateNumerical("deconv2d", "weights", weights); + SDValidation.validateNumerical("deconv2d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(sd,layerInput, weights, bias, DeConv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D deconvolution operation with optional bias
+ * + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public SDVariable deconv2d(SDVariable layerInput, SDVariable weights, + DeConv2DConfig DeConv2DConfig) { + SDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + SDValidation.validateNumerical("deconv2d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(sd,layerInput, weights, null, DeConv2DConfig).outputVariable(); + } + + /** + * 2D deconvolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type) + * @param DeConv2DConfig Configuration Object + * @return output result of deconv2d op (NUMERIC type) + */ + public SDVariable deconv2d(String name, SDVariable layerInput, SDVariable weights, + DeConv2DConfig DeConv2DConfig) { + SDValidation.validateNumerical("deconv2d", "layerInput", layerInput); + SDValidation.validateNumerical("deconv2d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D(sd,layerInput, weights, null, DeConv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, + DeConv3DConfig DeConv3DConfig) { + SDValidation.validateNumerical("deconv3d", "input", input); + SDValidation.validateNumerical("deconv3d", "weights", weights); + SDValidation.validateNumerical("deconv3d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(sd,input, weights, bias, DeConv3DConfig).outputVariable(); + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param name name May be null. Name for the output variable + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param bias Bias array - optional, may be null. If non-null, must have shape [outputChannels] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, SDVariable bias, + DeConv3DConfig DeConv3DConfig) { + SDValidation.validateNumerical("deconv3d", "input", input); + SDValidation.validateNumerical("deconv3d", "weights", weights); + SDValidation.validateNumerical("deconv3d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(sd,input, weights, bias, DeConv3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public SDVariable deconv3d(SDVariable input, SDVariable weights, DeConv3DConfig DeConv3DConfig) { + SDValidation.validateNumerical("deconv3d", "input", input); + SDValidation.validateNumerical("deconv3d", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(sd,input, weights, null, DeConv3DConfig).outputVariable(); + } + + /** + * 3D CNN deconvolution operation with or without optional bias
+ * + * @param name name May be null. Name for the output variable + * @param input Input array - shape [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) (NUMERIC type) + * @param weights Weights array - shape [kD, kH, kW, oC, iC] (NUMERIC type) + * @param DeConv3DConfig Configuration Object + * @return output result of 3D CNN deconvolution operation (NUMERIC type) + */ + public SDVariable deconv3d(String name, SDVariable input, SDVariable weights, + DeConv3DConfig DeConv3DConfig) { + SDValidation.validateNumerical("deconv3d", "input", input); + SDValidation.validateNumerical("deconv3d", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3D(sd,input, weights, null, DeConv3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces input channels dimension by rearranging data into a larger spatial dimensions
+ * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
+ * = [mb, 2, 4, 4]
+ * + * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public SDVariable depthToSpace(SDVariable x, int blockSize, DataFormat dataFormat) { + SDValidation.validateNumerical("depthToSpace", "x", x); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace(sd,x, blockSize, dataFormat).outputVariable(); + } + + /** + * Convolution 2d layer batch to space operation on 4d input.
+ * Reduces input channels dimension by rearranging data into a larger spatial dimensions
+ * Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
+ * = [mb, 2, 4, 4]
+ * + * @param name name May be null. Name for the output variable + * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public SDVariable depthToSpace(String name, SDVariable x, int blockSize, DataFormat dataFormat) { + SDValidation.validateNumerical("depthToSpace", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace(sd,x, blockSize, dataFormat).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Depth-wise 2D convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, SDVariable bias, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("depthWiseConv2d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(sd,layerInput, depthWeights, bias, Conv2DConfig).outputVariable(); + } + + /** + * Depth-wise 2D convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public SDVariable depthWiseConv2d(String name, SDVariable layerInput, SDVariable depthWeights, + SDVariable bias, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("depthWiseConv2d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(sd,layerInput, depthWeights, bias, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Depth-wise 2D convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public SDVariable depthWiseConv2d(SDVariable layerInput, SDVariable depthWeights, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(sd,layerInput, depthWeights, null, Conv2DConfig).outputVariable(); + } + + /** + * Depth-wise 2D convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (NUMERIC type) + * @param depthWeights Depth-wise conv2d weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of depthwise conv2d op (NUMERIC type) + */ + public SDVariable depthWiseConv2d(String name, SDVariable layerInput, SDVariable depthWeights, + Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("depthWiseConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("depthWiseConv2d", "depthWeights", depthWeights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D(sd,layerInput, depthWeights, null, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * TODO doc string
+ * + * @param df (NUMERIC type) + * @param weights df (NUMERIC type) + * @param strides weights (Size: Exactly(count=2)) + * @param rates strides (Size: Exactly(count=2)) + * @param isSameMode isSameMode + * @return output Computed the grayscale dilation of 4-D input and 3-D filters tensors. (NUMERIC type) + */ + public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, int[] rates, + boolean isSameMode) { + SDValidation.validateNumerical("dilation2D", "df", df); + SDValidation.validateNumerical("dilation2D", "weights", weights); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length == 2, "rates has incorrect size/length. Expected: rates.length == 2, got %s", rates.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D(sd,df, weights, strides, rates, isSameMode).outputVariable(); + } + + /** + * TODO doc string
+ * + * @param name name May be null. Name for the output variable + * @param df (NUMERIC type) + * @param weights df (NUMERIC type) + * @param strides weights (Size: Exactly(count=2)) + * @param rates strides (Size: Exactly(count=2)) + * @param isSameMode isSameMode + * @return output Computed the grayscale dilation of 4-D input and 3-D filters tensors. (NUMERIC type) + */ + public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides, + int[] rates, boolean isSameMode) { + SDValidation.validateNumerical("dilation2D", "df", df); + SDValidation.validateNumerical("dilation2D", "weights", weights); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length == 2, "rates has incorrect size/length. Expected: rates.length == 2, got %s", rates.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D(sd,df, weights, strides, rates, isSameMode).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Extract image patches
+ * + * @param input Input array. Must be rank 4, with shape [minibatch, height, width, channels] (NUMERIC type) + * @param kH Kernel height + * @param kW Kernel width + * @param sH Stride height + * @param sW Stride width + * @param rH Rate height + * @param rW Rate width + * @param sameMode If true: use same mode padding. If false + * @return output The result is a 4D tensor which is indexed by batch, row, and column. (NUMERIC type) + */ + public SDVariable extractImagePatches(SDVariable input, int kH, int kW, int sH, int sW, int rH, + int rW, boolean sameMode) { + SDValidation.validateNumerical("extractImagePatches", "input", input); + return new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,input, kH, kW, sH, sW, rH, rW, sameMode).outputVariable(); + } + + /** + * Extract image patches
+ * + * @param name name May be null. Name for the output variable + * @param input Input array. Must be rank 4, with shape [minibatch, height, width, channels] (NUMERIC type) + * @param kH Kernel height + * @param kW Kernel width + * @param sH Stride height + * @param sW Stride width + * @param rH Rate height + * @param rW Rate width + * @param sameMode If true: use same mode padding. If false + * @return output The result is a 4D tensor which is indexed by batch, row, and column. (NUMERIC type) + */ + public SDVariable extractImagePatches(String name, SDVariable input, int kH, int kW, int sH, + int sW, int rH, int rW, boolean sameMode) { + SDValidation.validateNumerical("extractImagePatches", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,input, kH, kW, sH, sW, rH, rW, sameMode).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
+ * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
+ * + * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Im2Col output variable (NUMERIC type) + */ + public SDVariable im2Col(SDVariable in, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("im2Col", "in", in); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col(sd,in, Conv2DConfig).outputVariable(); + } + + /** + * im2col operation for use in 2D convolution operations. Outputs a 6d array with shape
+ * [minibatch, inputChannels, kernelHeight, kernelWidth, outputHeight, outputWidth]
+ * + * @param name name May be null. Name for the output variable + * @param in Input - rank 4 input with shape [minibatch, inputChannels, height, width] (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output Im2Col output variable (NUMERIC type) + */ + public SDVariable im2Col(String name, SDVariable in, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("im2Col", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col(sd,in, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D convolution layer operation - local response normalization
+ * + * @param input the inputs to lrn (NUMERIC type) + * @param LocalResponseNormalizationConfig Configuration Object + * @return output Result after Local Response Normalization (NUMERIC type) + */ + public SDVariable localResponseNormalization(SDVariable input, + LocalResponseNormalizationConfig LocalResponseNormalizationConfig) { + SDValidation.validateNumerical("localResponseNormalization", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(sd,input, LocalResponseNormalizationConfig).outputVariable(); + } + + /** + * 2D convolution layer operation - local response normalization
+ * + * @param name name May be null. Name for the output variable + * @param input the inputs to lrn (NUMERIC type) + * @param LocalResponseNormalizationConfig Configuration Object + * @return output Result after Local Response Normalization (NUMERIC type) + */ + public SDVariable localResponseNormalization(String name, SDVariable input, + LocalResponseNormalizationConfig LocalResponseNormalizationConfig) { + SDValidation.validateNumerical("localResponseNormalization", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization(sd,input, LocalResponseNormalizationConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D Convolution layer operation - max pooling 2d
+ * + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("maxPooling2d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D(sd,input, Pooling2DConfig).outputVariable(); + } + + /** + * 2D Convolution layer operation - max pooling 2d
+ * + * @param name name May be null. Name for the output variable + * @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param Pooling2DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public SDVariable maxPooling2d(String name, SDVariable input, Pooling2DConfig Pooling2DConfig) { + SDValidation.validateNumerical("maxPooling2d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D(sd,input, Pooling2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 3D convolution layer operation - max pooling 3d operation.
+ * + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig Pooling3DConfig) { + SDValidation.validateNumerical("maxPooling3d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D(sd,input, Pooling3DConfig).outputVariable(); + } + + /** + * 3D convolution layer operation - max pooling 3d operation.
+ * + * @param name name May be null. Name for the output variable + * @param input the input to average pooling 3d operation - 5d activations in NCDHW format + * (shape [minibatch, channels, depth, height, width]) or NDHWC format + * (shape [minibatch, depth, height, width, channels]) (NUMERIC type) + * @param Pooling3DConfig Configuration Object + * @return output Result after applying max pooling on the input (NUMERIC type) + */ + public SDVariable maxPooling3d(String name, SDVariable input, Pooling3DConfig Pooling3DConfig) { + SDValidation.validateNumerical("maxPooling3d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling3D(sd,input, Pooling3DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, + SDVariable pointWeights, SDVariable bias, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + SDValidation.validateNumerical("separableConv2d", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(sd,layerInput, depthWeights, pointWeights, bias, Conv2DConfig).outputVariable(); + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public SDVariable separableConv2d(String name, SDVariable layerInput, SDVariable depthWeights, + SDVariable pointWeights, SDVariable bias, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + SDValidation.validateNumerical("separableConv2d", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(sd,layerInput, depthWeights, pointWeights, bias, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public SDVariable separableConv2d(SDVariable layerInput, SDVariable depthWeights, + SDVariable pointWeights, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(sd,layerInput, depthWeights, pointWeights, null, Conv2DConfig).outputVariable(); + } + + /** + * Separable 2D convolution operation with optional bias
+ * + * @param name name May be null. Name for the output variable + * @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type) + * @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type) + * @param Conv2DConfig Configuration Object + * @return output result of separable convolution 2d operation (NUMERIC type) + */ + public SDVariable separableConv2d(String name, SDVariable layerInput, SDVariable depthWeights, + SDVariable pointWeights, Conv2DConfig Conv2DConfig) { + SDValidation.validateNumerical("separableConv2d", "layerInput", layerInput); + SDValidation.validateNumerical("separableConv2d", "depthWeights", depthWeights); + SDValidation.validateNumerical("separableConv2d", "pointWeights", pointWeights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2D(sd,layerInput, depthWeights, pointWeights, null, Conv2DConfig).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 2d layer space to batch operation on 4d input.
+ * Increases input batch dimension by rearranging data from spatial dimensions into batch dimension
+ * + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param paddingTop Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @param paddingBottom Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable spaceToBatch(SDVariable x, int[] blocks, int[] paddingTop, + int... paddingBottom) { + SDValidation.validateNumerical("spaceToBatch", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(paddingTop.length == 2, "paddingTop has incorrect size/length. Expected: paddingTop.length == 2, got %s", paddingTop.length); + Preconditions.checkArgument(paddingBottom.length == 2, "paddingBottom has incorrect size/length. Expected: paddingBottom.length == 2, got %s", paddingBottom.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch(sd,x, blocks, paddingTop, paddingBottom).outputVariable(); + } + + /** + * Convolution 2d layer space to batch operation on 4d input.
+ * Increases input batch dimension by rearranging data from spatial dimensions into batch dimension
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable. 4d input (NUMERIC type) + * @param blocks Block size, in the height/width dimension (Size: Exactly(count=2)) + * @param paddingTop Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @param paddingBottom Optional 2d int[] array for padding the result: values [[pad top, pad bottom], [pad left, pad right]] (Size: Exactly(count=2)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable spaceToBatch(String name, SDVariable x, int[] blocks, int[] paddingTop, + int... paddingBottom) { + SDValidation.validateNumerical("spaceToBatch", "x", x); + Preconditions.checkArgument(blocks.length == 2, "blocks has incorrect size/length. Expected: blocks.length == 2, got %s", blocks.length); + Preconditions.checkArgument(paddingTop.length == 2, "paddingTop has incorrect size/length. Expected: paddingTop.length == 2, got %s", paddingTop.length); + Preconditions.checkArgument(paddingBottom.length == 2, "paddingBottom has incorrect size/length. Expected: paddingBottom.length == 2, got %s", paddingBottom.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch(sd,x, blocks, paddingTop, paddingBottom).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Convolution 2d layer space to depth operation on 4d input.
+ * Increases input channels (reduced spatial dimensions) by rearranging data into a larger channels dimension
+ * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
+ * = [mb, 2, 4, 4]
+ * + * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public SDVariable spaceToDepth(SDVariable x, int blockSize, DataFormat dataFormat) { + SDValidation.validateNumerical("spaceToDepth", "x", x); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth(sd,x, blockSize, dataFormat).outputVariable(); + } + + /** + * Convolution 2d layer space to depth operation on 4d input.
+ * Increases input channels (reduced spatial dimensions) by rearranging data into a larger channels dimension
+ * Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]
+ * = [mb, 2, 4, 4]
+ * + * @param name name May be null. Name for the output variable + * @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format + * (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type) + * @param blockSize Block size, in the height/width dimension + * @param dataFormat Data format: "NCHW" or "NHWC" + * @return output Output variable (NUMERIC type) + */ + public SDVariable spaceToDepth(String name, SDVariable x, int blockSize, DataFormat dataFormat) { + SDValidation.validateNumerical("spaceToDepth", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth(sd,x, blockSize, dataFormat).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Upsampling layer for 2D inputs.
+ * scale is used for both height and width dimensions.
+ * + * @param input Input in NCHW format (NUMERIC type) + * @param scale The scale for both height and width dimensions. + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling2d(SDVariable input, int scale) { + SDValidation.validateNumerical("upsampling2d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scale).outputVariable(); + } + + /** + * Upsampling layer for 2D inputs.
+ * scale is used for both height and width dimensions.
+ * + * @param name name May be null. Name for the output variable + * @param input Input in NCHW format (NUMERIC type) + * @param scale The scale for both height and width dimensions. + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling2d(String name, SDVariable input, int scale) { + SDValidation.validateNumerical("upsampling2d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scale).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * 2D Convolution layer operation - Upsampling 2d
+ * + * @param input Input in NCHW format (NUMERIC type) + * @param scaleH Scale to upsample in height dimension + * @param scaleW Scale to upsample in width dimension + * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling2d(SDVariable input, int scaleH, int scaleW, boolean nchw) { + SDValidation.validateNumerical("upsampling2d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scaleH, scaleW, nchw).outputVariable(); + } + + /** + * 2D Convolution layer operation - Upsampling 2d
+ * + * @param name name May be null. Name for the output variable + * @param input Input in NCHW format (NUMERIC type) + * @param scaleH Scale to upsample in height dimension + * @param scaleW Scale to upsample in width dimension + * @param nchw If true: input is in NCHW (minibatch, channels, height, width) format. False: NHWC format + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling2d(String name, SDVariable input, int scaleH, int scaleW, + boolean nchw) { + SDValidation.validateNumerical("upsampling2d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scaleH, scaleW, nchw).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index 7b662b960..70940863a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -1,185 +1,440 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.custom.*; -import org.nd4j.linalg.api.ops.impl.image.CropAndResize; -import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; -import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; +import org.nd4j.base.Preconditions; -/** - * @author Alex Black - */ public class SDImage extends SDOps { - public SDImage(SameDiff sameDiff) { - super(sameDiff); - } + public SDImage(SameDiff sameDiff) { + super(sameDiff); + } - /** - * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size. - * - * @param name May be null. Name for the output variable. - * @param image Input image, with shape [batch, height, width, channels] - * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 - * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] - * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] - * @param method Image resize method - * @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default - * @return Cropped and resized images - */ - public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, SDVariable cropOutSize, - CropAndResize.Method method, double extrapolationValue) { - SDVariable out = new CropAndResize(sd, image, cropBoxes, boxIndices, cropOutSize, method, extrapolationValue).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
+ * + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default + * @return output Cropped and resized images (NUMERIC type) + */ + public SDVariable cropAndResize(SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, + SDVariable cropOutSize, double extrapolationValue) { + SDValidation.validateNumerical("CropAndResize", "image", image); + SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + return new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, extrapolationValue).outputVariable(); + } - /** - * Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension. - * - * @param name Map be null. Name for the output variable - * @param image Input image to extract image patches from - shape [batch, height, width, channels] - * @param kSizes Kernel size - size of the image patches, [height, width] - * @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] - * @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels - * in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken - * along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension - * @param sameMode Padding algorithm. If true: use Same padding - * @return The extracted image patches - */ - public SDVariable extractImagePatches(String name, SDVariable image, @NonNull int[] kSizes, - @NonNull int[] strides, @NonNull int[] rates, boolean sameMode) { - SDVariable out = new ExtractImagePatches(sd, image, kSizes, strides, rates, sameMode).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
+ * + * @param name name May be null. Name for the output variable + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default + * @return output Cropped and resized images (NUMERIC type) + */ + public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, + SDVariable boxIndices, SDVariable cropOutSize, double extrapolationValue) { + SDValidation.validateNumerical("CropAndResize", "image", image); + SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, extrapolationValue).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Greedily selects a subset of bounding boxes in descending order of score - * @param name Might be null. Name for the output variable - * @param boxes 2D array of shape [num_boxes,4] - * @param scores vector of shape [num_boxes] - * @param maxOutSize scalar representing the maximum number of boxes to be selected - * @param iouThreshold float - threshold for deciding whether boxes overlap too much with respect to IOU - * @param scoreThreshold float - threshold for deciding when to remove boxes based on score - * @return vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size - */ - public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize, - @NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){ - SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
+ * + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @return output Cropped and resized images (NUMERIC type) + */ + public SDVariable cropAndResize(SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, + SDVariable cropOutSize) { + SDValidation.validateNumerical("CropAndResize", "image", image); + SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + return new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, 0.0).outputVariable(); + } - /** - * Adjusts contrast of RGB or grayscale images. - * @param name name for the output variable - * @param in images to adjust. 3D shape or higher. - * @param factor float multiplier for adjusting contrast. - * @return Contrast-adjusted image - */ - public SDVariable adjustContrast(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { - SDVariable out = new AdjustContrast(sd, in, factor).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
+ * + * @param name name May be null. Name for the output variable + * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type) + * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type) + * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type) + * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type) + * @return output Cropped and resized images (NUMERIC type) + */ + public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, + SDVariable boxIndices, SDVariable cropOutSize) { + SDValidation.validateNumerical("CropAndResize", "image", image); + SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes); + SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices); + SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, 0.0).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Adjust saturation of RGB images - * @param name name for the output variable - * @param in RGB image as 3D array - * @param factor factor for saturation - * @return adjusted image - */ - public SDVariable adjustSaturation(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { - SDVariable out = new AdjustSaturation(sd, in, factor).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjusts contrast of RGB or grayscale images.
+ * + * @param in images to adjust. 3D shape or higher (NUMERIC type) + * @param factor multiplier for adjusting contrast + * @return output Contrast-adjusted image (NUMERIC type) + */ + public SDVariable adjustContrast(SDVariable in, double factor) { + SDValidation.validateNumerical("adjustContrast", "in", in); + return new org.nd4j.linalg.api.ops.custom.AdjustContrast(sd,in, factor).outputVariable(); + } - /** - * Adjust hue of RGB image - * @param name name for the output variable - * @param in RGB image as 3D array - * @param delta value to add to hue channel - * @return adjusted image - */ - public SDVariable adjustHue(String name, @NonNull SDVariable in, @NonNull SDVariable delta) { - SDVariable out = new AdjustHue(sd, in, delta).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjusts contrast of RGB or grayscale images.
+ * + * @param name name May be null. Name for the output variable + * @param in images to adjust. 3D shape or higher (NUMERIC type) + * @param factor multiplier for adjusting contrast + * @return output Contrast-adjusted image (NUMERIC type) + */ + public SDVariable adjustContrast(String name, SDVariable in, double factor) { + SDValidation.validateNumerical("adjustContrast", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustContrast(sd,in, factor).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Randomly crops image - * @param name name for the output variable - * @param input input array - * @param shape shape for crop - * @return cropped array - */ - public SDVariable randomCrop(String name, @NonNull SDVariable input, @NonNull SDVariable shape) { - SDVariable out = new RandomCrop(sd, input, shape).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjust hue of RGB image
+ * + * @param in image as 3D array (NUMERIC type) + * @param delta value to add to hue channel + * @return output adjusted image (NUMERIC type) + */ + public SDVariable adjustHue(SDVariable in, double delta) { + SDValidation.validateNumerical("adjustHue", "in", in); + return new org.nd4j.linalg.api.ops.custom.AdjustHue(sd,in, delta).outputVariable(); + } - /** - * Converting array from HSV to RGB format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable rgbToHsv(String name, @NonNull SDVariable input) { - SDVariable out = new RgbToHsv(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjust hue of RGB image
+ * + * @param name name May be null. Name for the output variable + * @param in image as 3D array (NUMERIC type) + * @param delta value to add to hue channel + * @return output adjusted image (NUMERIC type) + */ + public SDVariable adjustHue(String name, SDVariable in, double delta) { + SDValidation.validateNumerical("adjustHue", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustHue(sd,in, delta).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Converting image from HSV to RGB format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable hsvToRgb(String name, @NonNull SDVariable input) { - SDVariable out = new HsvToRgb(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjust saturation of RGB images
+ * + * @param in RGB image as 3D array (NUMERIC type) + * @param factor factor for saturation + * @return output adjusted image (NUMERIC type) + */ + public SDVariable adjustSaturation(SDVariable in, double factor) { + SDValidation.validateNumerical("adjustSaturation", "in", in); + return new org.nd4j.linalg.api.ops.custom.AdjustSaturation(sd,in, factor).outputVariable(); + } - /** - * Converting array from RGB to YIQ format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable rgbToYiq(String name, @NonNull SDVariable input) { - SDVariable out = new RgbToYiq(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Adjust saturation of RGB images
+ * + * @param name name May be null. Name for the output variable + * @param in RGB image as 3D array (NUMERIC type) + * @param factor factor for saturation + * @return output adjusted image (NUMERIC type) + */ + public SDVariable adjustSaturation(String name, SDVariable in, double factor) { + SDValidation.validateNumerical("adjustSaturation", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustSaturation(sd,in, factor).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Converting image from YIQ to RGB format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable yiqToRgb(String name, @NonNull SDVariable input) { - SDVariable out = new YiqToRgb(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension.
+ * + * @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type) + * @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2)) + * @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2)) + * @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels + * in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken + * along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension (Size: AtLeast(min=0)) + * @param sameMode Padding algorithm. If true: use Same padding + * @return output The extracted image patches (NUMERIC type) + */ + public SDVariable extractImagePatches(SDVariable image, int[] kSizes, int[] strides, int[] rates, + boolean sameMode) { + SDValidation.validateNumerical("extractImagePatches", "image", image); + Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length); + return new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,image, kSizes, strides, rates, sameMode).outputVariable(); + } - /** - * Converting array from RGB to YUV format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable rgbToYuv(String name, @NonNull SDVariable input) { - SDVariable out = new RgbToYuv(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension.
+ * + * @param name name May be null. Name for the output variable + * @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type) + * @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2)) + * @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2)) + * @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels + * in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken + * along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension (Size: AtLeast(min=0)) + * @param sameMode Padding algorithm. If true: use Same padding + * @return output The extracted image patches (NUMERIC type) + */ + public SDVariable extractImagePatches(String name, SDVariable image, int[] kSizes, int[] strides, + int[] rates, boolean sameMode) { + SDValidation.validateNumerical("extractImagePatches", "image", image); + Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length); + Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length); + Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,image, kSizes, strides, rates, sameMode).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Converting image from YUV to RGB format - * @param name name - * @param input 3D image - * @return 3D image - */ - public SDVariable yuvToRgb(String name, @NonNull SDVariable input) { - SDVariable out = new YuvToRgb(sd, input).outputVariable(); - return updateVariableNameAndReference(out, name); - } + /** + * Converting image from HSV to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable hsvToRgb(SDVariable input) { + SDValidation.validateNumerical("hsvToRgb", "input", input); + return new org.nd4j.linalg.api.ops.custom.HsvToRgb(sd,input).outputVariable(); + } + + /** + * Converting image from HSV to RGB format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable hsvToRgb(String name, SDVariable input) { + SDValidation.validateNumerical("hsvToRgb", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.HsvToRgb(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Greedily selects a subset of bounding boxes in descending order of score
+ * + * @param boxes Might be null. Name for the output variable (NUMERIC type) + * @param scores vector of shape [num_boxes] (NUMERIC type) + * @param maxOutSize scalar representing the maximum number of boxes to be selected + * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU + * @param scoreThreshold threshold for deciding when to remove boxes based on score + * @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type) + */ + public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize, + double iouThreshold, double scoreThreshold) { + SDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes); + SDValidation.validateNumerical("nonMaxSuppression", "scores", scores); + return new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(sd,boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable(); + } + + /** + * Greedily selects a subset of bounding boxes in descending order of score
+ * + * @param name name May be null. Name for the output variable + * @param boxes Might be null. Name for the output variable (NUMERIC type) + * @param scores vector of shape [num_boxes] (NUMERIC type) + * @param maxOutSize scalar representing the maximum number of boxes to be selected + * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU + * @param scoreThreshold threshold for deciding when to remove boxes based on score + * @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type) + */ + public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores, + int maxOutSize, double iouThreshold, double scoreThreshold) { + SDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes); + SDValidation.validateNumerical("nonMaxSuppression", "scores", scores); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(sd,boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Randomly crops image
+ * + * @param input input array (NUMERIC type) + * @param shape shape for crop (INT type) + * @return output cropped array (NUMERIC type) + */ + public SDVariable randomCrop(SDVariable input, SDVariable shape) { + SDValidation.validateNumerical("randomCrop", "input", input); + SDValidation.validateInteger("randomCrop", "shape", shape); + return new org.nd4j.linalg.api.ops.custom.RandomCrop(sd,input, shape).outputVariable(); + } + + /** + * Randomly crops image
+ * + * @param name name May be null. Name for the output variable + * @param input input array (NUMERIC type) + * @param shape shape for crop (INT type) + * @return output cropped array (NUMERIC type) + */ + public SDVariable randomCrop(String name, SDVariable input, SDVariable shape) { + SDValidation.validateNumerical("randomCrop", "input", input); + SDValidation.validateInteger("randomCrop", "shape", shape); + SDVariable out = new org.nd4j.linalg.api.ops.custom.RandomCrop(sd,input, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Converting array from HSV to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToHsv(SDVariable input) { + SDValidation.validateNumerical("rgbToHsv", "input", input); + return new org.nd4j.linalg.api.ops.custom.RgbToHsv(sd,input).outputVariable(); + } + + /** + * Converting array from HSV to RGB format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToHsv(String name, SDVariable input) { + SDValidation.validateNumerical("rgbToHsv", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToHsv(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Converting array from RGB to YIQ format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToYiq(SDVariable input) { + SDValidation.validateNumerical("rgbToYiq", "input", input); + return new org.nd4j.linalg.api.ops.custom.RgbToYiq(sd,input).outputVariable(); + } + + /** + * Converting array from RGB to YIQ format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToYiq(String name, SDVariable input) { + SDValidation.validateNumerical("rgbToYiq", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToYiq(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Converting array from RGB to YUV format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToYuv(SDVariable input) { + SDValidation.validateNumerical("rgbToYuv", "input", input); + return new org.nd4j.linalg.api.ops.custom.RgbToYuv(sd,input).outputVariable(); + } + + /** + * Converting array from RGB to YUV format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable rgbToYuv(String name, SDVariable input) { + SDValidation.validateNumerical("rgbToYuv", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToYuv(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Converting image from YIQ to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable yiqToRgb(SDVariable input) { + SDValidation.validateNumerical("yiqToRgb", "input", input); + return new org.nd4j.linalg.api.ops.custom.YiqToRgb(sd,input).outputVariable(); + } + + /** + * Converting image from YIQ to RGB format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable yiqToRgb(String name, SDVariable input) { + SDValidation.validateNumerical("yiqToRgb", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.YiqToRgb(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Converting image from YUV to RGB format
+ * + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable yuvToRgb(SDVariable input) { + SDValidation.validateNumerical("yuvToRgb", "input", input); + return new org.nd4j.linalg.api.ops.custom.YuvToRgb(sd,input).outputVariable(); + } + + /** + * Converting image from YUV to RGB format
+ * + * @param name name May be null. Name for the output variable + * @param input 3D image (NUMERIC type) + * @return output 3D image (NUMERIC type) + */ + public SDVariable yuvToRgb(String name, SDVariable input) { + SDValidation.validateNumerical("yuvToRgb", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.YuvToRgb(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java new file mode 100644 index 000000000..8dbb9d3b3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java @@ -0,0 +1,561 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.autodiff.samediff.ops; + +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +public class SDLinalg extends SDOps { + public SDLinalg(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * Computes the Cholesky decomposition of one or more square matrices.
+ * + * @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type) + * @return output Transformed tensor (NUMERIC type) + */ + public SDVariable cholesky(SDVariable input) { + SDValidation.validateNumerical("Cholesky", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(sd,input).outputVariable(); + } + + /** + * Computes the Cholesky decomposition of one or more square matrices.
+ * + * @param name name May be null. Name for the output variable + * @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type) + * @return output Transformed tensor (NUMERIC type) + */ + public SDVariable cholesky(String name, SDVariable input) { + SDValidation.validateNumerical("Cholesky", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Solver for linear squares problems.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param l2_reguralizer regularizer + * @param fast fast mode, defaults to True + * @return output Transformed tensor (FLOATING_POINT type) + */ + public SDVariable lstsq(SDVariable matrix, SDVariable rhs, double l2_reguralizer, boolean fast) { + SDValidation.validateNumerical("Lstsq", "matrix", matrix); + SDValidation.validateNumerical("Lstsq", "rhs", rhs); + return new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, fast).outputVariable(); + } + + /** + * Solver for linear squares problems.
+ * + * @param name name May be null. Name for the output variable + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param l2_reguralizer regularizer + * @param fast fast mode, defaults to True + * @return output Transformed tensor (FLOATING_POINT type) + */ + public SDVariable lstsq(String name, SDVariable matrix, SDVariable rhs, double l2_reguralizer, + boolean fast) { + SDValidation.validateNumerical("Lstsq", "matrix", matrix); + SDValidation.validateNumerical("Lstsq", "rhs", rhs); + SDVariable out = new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, fast).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Solver for linear squares problems.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param l2_reguralizer regularizer + * @return output Transformed tensor (FLOATING_POINT type) + */ + public SDVariable lstsq(SDVariable matrix, SDVariable rhs, double l2_reguralizer) { + SDValidation.validateNumerical("Lstsq", "matrix", matrix); + SDValidation.validateNumerical("Lstsq", "rhs", rhs); + return new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, true).outputVariable(); + } + + /** + * Solver for linear squares problems.
+ * + * @param name name May be null. Name for the output variable + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param l2_reguralizer regularizer + * @return output Transformed tensor (FLOATING_POINT type) + */ + public SDVariable lstsq(String name, SDVariable matrix, SDVariable rhs, double l2_reguralizer) { + SDValidation.validateNumerical("Lstsq", "matrix", matrix); + SDValidation.validateNumerical("Lstsq", "rhs", rhs); + SDVariable out = new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, true).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Computes LU decomposition.
+ * + * @param input input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable lu(SDVariable input) { + SDValidation.validateNumerical("Lu", "input", input); + return new org.nd4j.linalg.api.ops.custom.Lu(sd,input).outputVariable(); + } + + /** + * Computes LU decomposition.
+ * + * @param name name May be null. Name for the output variable + * @param input input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable lu(String name, SDVariable input) { + SDValidation.validateNumerical("Lu", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.Lu(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Performs matrix mutiplication on input tensors.
+ * + * @param a input tensor (NUMERIC type) + * @param b input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable matmul(SDVariable a, SDVariable b) { + SDValidation.validateNumerical("Matmul", "a", a); + SDValidation.validateNumerical("Matmul", "b", b); + return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,a, b).outputVariable(); + } + + /** + * Performs matrix mutiplication on input tensors.
+ * + * @param name name May be null. Name for the output variable + * @param a input tensor (NUMERIC type) + * @param b input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable matmul(String name, SDVariable a, SDVariable b) { + SDValidation.validateNumerical("Matmul", "a", a); + SDValidation.validateNumerical("Matmul", "b", b); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,a, b).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Copy a tensor setting outside a central band in each innermost matrix.
+ * + * @param input input tensor (NUMERIC type) + * @param minLower lower diagonal count + * @param maxUpper upper diagonal count + */ + public SDVariable[] matrixBandPart(SDVariable input, int minLower, int maxUpper) { + SDValidation.validateNumerical("MatrixBandPart", "input", input); + return new org.nd4j.linalg.api.ops.custom.MatrixBandPart(sd,input, minLower, maxUpper).outputVariables(); + } + + /** + * Copy a tensor setting outside a central band in each innermost matrix.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input input tensor (NUMERIC type) + * @param minLower lower diagonal count + * @param maxUpper upper diagonal count + */ + public SDVariable[] matrixBandPart(String[] names, SDVariable input, int minLower, int maxUpper) { + SDValidation.validateNumerical("MatrixBandPart", "input", input); + SDVariable[] out = new org.nd4j.linalg.api.ops.custom.MatrixBandPart(sd,input, minLower, maxUpper).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param input input tensor (NUMERIC type) + * @param full full matrices mode + */ + public SDVariable[] qr(SDVariable input, boolean full) { + SDValidation.validateNumerical("Qr", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, full).outputVariables(); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input input tensor (NUMERIC type) + * @param full full matrices mode + */ + public SDVariable[] qr(String[] names, SDVariable input, boolean full) { + SDValidation.validateNumerical("Qr", "input", input); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, full).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param input input tensor (NUMERIC type) + */ + public SDVariable[] qr(SDVariable input) { + SDValidation.validateNumerical("Qr", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, false).outputVariables(); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input input tensor (NUMERIC type) + */ + public SDVariable[] qr(String[] names, SDVariable input) { + SDValidation.validateNumerical("Qr", "input", input); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, false).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Solver for systems of linear equations.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param adjoint adjoint mode, defaults to False + * @return output Output tensor (FLOATING_POINT type) + */ + public SDVariable solve(SDVariable matrix, SDVariable rhs, boolean adjoint) { + SDValidation.validateNumerical("Solve", "matrix", matrix); + SDValidation.validateNumerical("Solve", "rhs", rhs); + return new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, adjoint).outputVariable(); + } + + /** + * Solver for systems of linear equations.
+ * + * @param name name May be null. Name for the output variable + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param adjoint adjoint mode, defaults to False + * @return output Output tensor (FLOATING_POINT type) + */ + public SDVariable solve(String name, SDVariable matrix, SDVariable rhs, boolean adjoint) { + SDValidation.validateNumerical("Solve", "matrix", matrix); + SDValidation.validateNumerical("Solve", "rhs", rhs); + SDVariable out = new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, adjoint).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Solver for systems of linear equations.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @return output Output tensor (FLOATING_POINT type) + */ + public SDVariable solve(SDVariable matrix, SDVariable rhs) { + SDValidation.validateNumerical("Solve", "matrix", matrix); + SDValidation.validateNumerical("Solve", "rhs", rhs); + return new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, false).outputVariable(); + } + + /** + * Solver for systems of linear equations.
+ * + * @param name name May be null. Name for the output variable + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @return output Output tensor (FLOATING_POINT type) + */ + public SDVariable solve(String name, SDVariable matrix, SDVariable rhs) { + SDValidation.validateNumerical("Solve", "matrix", matrix); + SDValidation.validateNumerical("Solve", "rhs", rhs); + SDVariable out = new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Solver for systems of linear questions.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param lower defines whether innermost matrices in matrix are lower or upper triangular + * @param adjoint adjoint mode + * @return output (FLOATING_POINT type) + */ + public SDVariable triangularSolve(SDVariable matrix, SDVariable rhs, boolean lower, + boolean adjoint) { + SDValidation.validateNumerical("TriangularSolve", "matrix", matrix); + SDValidation.validateNumerical("TriangularSolve", "rhs", rhs); + return new org.nd4j.linalg.api.ops.custom.TriangularSolve(sd,matrix, rhs, lower, adjoint).outputVariable(); + } + + /** + * Solver for systems of linear questions.
+ * + * @param name name May be null. Name for the output variable + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param lower defines whether innermost matrices in matrix are lower or upper triangular + * @param adjoint adjoint mode + * @return output (FLOATING_POINT type) + */ + public SDVariable triangularSolve(String name, SDVariable matrix, SDVariable rhs, boolean lower, + boolean adjoint) { + SDValidation.validateNumerical("TriangularSolve", "matrix", matrix); + SDValidation.validateNumerical("TriangularSolve", "rhs", rhs); + SDVariable out = new org.nd4j.linalg.api.ops.custom.TriangularSolve(sd,matrix, rhs, lower, adjoint).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Computes pairwise cross product.
+ * + * @param a (NUMERIC type) + * @param b (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable cross(SDVariable a, SDVariable b) { + SDValidation.validateNumerical("cross", "a", a); + SDValidation.validateNumerical("cross", "b", b); + return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + } + + /** + * Computes pairwise cross product.
+ * + * @param name name May be null. Name for the output variable + * @param a (NUMERIC type) + * @param b (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable cross(String name, SDVariable a, SDVariable b) { + SDValidation.validateNumerical("cross", "a", a); + SDValidation.validateNumerical("cross", "b", b); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates diagonal tensor.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable diag(SDVariable input) { + SDValidation.validateNumerical("diag", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,input).outputVariable(); + } + + /** + * Calculates diagonal tensor.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable diag(String name, SDVariable input) { + SDValidation.validateNumerical("diag", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates diagonal tensor.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable diag_part(SDVariable input) { + SDValidation.validateNumerical("diag_part", "input", input); + return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,input).outputVariable(); + } + + /** + * Calculates diagonal tensor.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable diag_part(String name, SDVariable input) { + SDValidation.validateNumerical("diag_part", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates log of determinant.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable logdet(SDVariable input) { + SDValidation.validateNumerical("logdet", "input", input); + return new org.nd4j.linalg.api.ops.custom.Logdet(sd,input).outputVariable(); + } + + /** + * Calculates log of determinant.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable logdet(String name, SDVariable input) { + SDValidation.validateNumerical("logdet", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.Logdet(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public SDVariable mmul(SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public SDVariable mmul(String name, SDVariable x, SDVariable y, boolean transposeX, + boolean transposeY, boolean transposeZ) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable mmul(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable(); + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable mmul(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mmul", "x", x); + SDValidation.validateNumerical("mmul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates singular value decomposition.
+ * + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @param switchNum + * @return output (FLOATING_POINT type) + */ + public SDVariable svd(SDVariable input, boolean fullUV, boolean computeUV, int switchNum) { + SDValidation.validateNumerical("svd", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, switchNum).outputVariable(); + } + + /** + * Calculates singular value decomposition.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @param switchNum + * @return output (FLOATING_POINT type) + */ + public SDVariable svd(String name, SDVariable input, boolean fullUV, boolean computeUV, + int switchNum) { + SDValidation.validateNumerical("svd", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, switchNum).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates singular value decomposition.
+ * + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @return output (FLOATING_POINT type) + */ + public SDVariable svd(SDVariable input, boolean fullUV, boolean computeUV) { + SDValidation.validateNumerical("svd", "input", input); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable(); + } + + /** + * Calculates singular value decomposition.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @return output (FLOATING_POINT type) + */ + public SDVariable svd(String name, SDVariable input, boolean fullUV, boolean computeUV) { + SDValidation.validateNumerical("svd", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java index f0e94a4e5..9a1ef1249 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java @@ -1,5 +1,5 @@ -/* ***************************************************************************** - * Copyright (c) 2015-2019 Skymind, Inc. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,544 +14,1045 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.ops.impl.loss.LogLoss; -import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss; -import org.nd4j.linalg.factory.Nd4j; -import static org.nd4j.autodiff.samediff.ops.SDValidation.*; - -/** - * SameDiff loss functions
- * Accessible via {@link SameDiff#loss()} - * - * @author Alex Black - */ -@SuppressWarnings("unused") public class SDLoss extends SDOps { - public SDLoss(SameDiff sameDiff) { - super(sameDiff); - } + public SDLoss(SameDiff sameDiff) { + super(sameDiff); + } - /** - * helper to refactor duplicate code - */ - private SDVariable getWeights(SDVariable weights, String name, SDVariable predictions){ - String weightName = (name == null) ? null : name + "/weight"; - return (weights == null) ? sd.constant(weightName, Nd4j.scalar(predictions.dataType(), 1.0)) : weights; - } + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output loss variable (NUMERIC type) + */ + public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + SDValidation.validateNumerical("absoluteDifference", "label", label); + SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + SDValidation.validateNumerical("absoluteDifference", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return absoluteDifference(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output loss variable (NUMERIC type) + */ + public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("absoluteDifference", "label", label); + SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + SDValidation.validateNumerical("absoluteDifference", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] ) - * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("absolute difference loss", "predictions", predictions); - validateNumerical("absolute difference loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossAbsoluteDifference(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output loss variable (NUMERIC type) + */ + public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("absoluteDifference", "label", label); + SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + SDValidation.validateNumerical("absoluteDifference", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #absoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable absoluteDifference(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return absoluteDifference(name, label, predictions, null, lossReduce); - } + /** + * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output loss variable (NUMERIC type) + */ + public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("absoluteDifference", "label", label); + SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); + SDValidation.validateNumerical("absoluteDifference", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #cosineDistance(String, SDVariable, SDVariable, SDVariable, LossReduce, int)}. - */ - public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, int dimension) { - return cosineDistance(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension); - } + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public SDVariable cosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, int dimension) { + SDValidation.validateNumerical("cosineDistance", "label", label); + SDValidation.validateNumerical("cosineDistance", "predictions", predictions); + SDValidation.validateNumerical("cosineDistance", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, lossReduce, dimension).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is - * equivalent to cosine distance when both the predictions and labels are normalized.
- * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm. - * If this is not the case, you should normalize them first by dividing by {@link SameDiff#norm2(String, SDVariable, boolean, int...)} - * along the cosine distance dimension (with keepDims=true). - * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param dimension Dimension to perform the cosine distance over - * @return Cosine distance loss variable - */ - public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce, int dimension) { - validateFloatingPoint("cosine distance loss", "predictions", predictions); - validateNumerical("cosine distance loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossCosineDistance(label, predictions, weights, lossReduce, dimension); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public SDVariable cosineDistance(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce, int dimension) { + SDValidation.validateNumerical("cosineDistance", "label", label); + SDValidation.validateNumerical("cosineDistance", "predictions", predictions); + SDValidation.validateNumerical("cosineDistance", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, lossReduce, dimension).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #cosineDistance(String, SDVariable, SDVariable, SDVariable, LossReduce, int)}. - */ - public SDVariable cosineDistance(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - @NonNull LossReduce lossReduce, int dimension) { - return cosineDistance(name, label, predictions, null, lossReduce, dimension); - } + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public SDVariable cosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, + int dimension) { + SDValidation.validateNumerical("cosineDistance", "label", label); + SDValidation.validateNumerical("cosineDistance", "predictions", predictions); + SDValidation.validateNumerical("cosineDistance", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #hingeLoss(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return hingeLoss(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
+ * equivalent to cosine distance when both the predictions and labels are normalized.
+ * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.
+ * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)
+ * along the cosine distance dimension (with keepDims=true).
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param dimension Dimension to perform the cosine distance over + * @return output Cosine distance loss (NUMERIC type) + */ + public SDVariable cosineDistance(String name, SDVariable label, SDVariable predictions, + SDVariable weights, int dimension) { + SDValidation.validateNumerical("cosineDistance", "label", label); + SDValidation.validateNumerical("cosineDistance", "predictions", predictions); + SDValidation.validateNumerical("cosineDistance", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Hinge loss: a loss function used for training classifiers. - * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1} - * from the user specified {0,1}. Note that Labels should be provided with values {0,1}. - * - * @param name Name of the operation - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("hinge loss", "predictions", predictions); - validateNumerical("hinge loss", "labels", label); - if (weights == null) - weights = sd.scalar(null, predictions.dataType(), 1.0); - SDVariable result = f().lossHinge(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
+ * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + SDValidation.validateNumerical("hingeLoss", "label", label); + SDValidation.validateNumerical("hingeLoss", "predictions", predictions); + SDValidation.validateNumerical("hingeLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #hingeLoss(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable hingeLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return hingeLoss(name, label, predictions, null, lossReduce); - } + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
+ * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("hingeLoss", "label", label); + SDValidation.validateNumerical("hingeLoss", "predictions", predictions); + SDValidation.validateNumerical("hingeLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #huberLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, double delta) { - return huberLoss(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta); - } + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
+ * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights) { + SDValidation.validateNumerical("hingeLoss", "label", label); + SDValidation.validateNumerical("hingeLoss", "predictions", predictions); + SDValidation.validateNumerical("hingeLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss, - * though is less sensitive to outliers than squared error.
- * Huber loss implements: - *
-     * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta
-     *  L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise
-     *     }
-     * 
- * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param delta Loss function delta value - * @return Huber loss variable - */ - public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce, double delta) { - validateFloatingPoint("huber loss", "predictions", predictions); - validateNumerical("huber loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossHuber(label, predictions, weights, lossReduce, delta); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Hinge loss: a loss function used for training classifiers.
+ * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}
+ * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("hingeLoss", "label", label); + SDValidation.validateNumerical("hingeLoss", "predictions", predictions); + SDValidation.validateNumerical("hingeLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #huberLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable huberLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce, double delta) { - return huberLoss(name, label, predictions, null, lossReduce, delta); - } + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *

+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *

+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public SDVariable huberLoss(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, double delta) { + SDValidation.validateNumerical("huberLoss", "label", label); + SDValidation.validateNumerical("huberLoss", "predictions", predictions); + SDValidation.validateNumerical("huberLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, lossReduce, delta).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * L2 loss: 1/2 * sum(x^2) - * - * @param var Variable to calculate L2 loss of - * @return L2 loss - */ - public SDVariable l2Loss(@NonNull SDVariable var) { - return l2Loss(null, var); - } + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *

+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *

+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public SDVariable huberLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce, double delta) { + SDValidation.validateNumerical("huberLoss", "label", label); + SDValidation.validateNumerical("huberLoss", "predictions", predictions); + SDValidation.validateNumerical("huberLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, lossReduce, delta).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * L2 loss: 1/2 * sum(x^2) - * - * @param name Name of the output variable - * @param var Variable to calculate L2 loss of - * @return L2 loss - */ - public SDVariable l2Loss(String name, @NonNull SDVariable var) { - validateNumerical("l2 loss", var); - SDVariable result = f().lossL2(var); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *

+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *

+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public SDVariable huberLoss(SDVariable label, SDVariable predictions, SDVariable weights, + double delta) { + SDValidation.validateNumerical("huberLoss", "label", label); + SDValidation.validateNumerical("huberLoss", "predictions", predictions); + SDValidation.validateNumerical("huberLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #logLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return logLoss(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, LogLoss.DEFAULT_EPSILON); - } + /** + * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
+ * though is less sensitive to outliers than squared error.
+ * Huber loss implements:
+ *

+ * {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
+ * {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
+ *

+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param delta Loss function delta value + * @return output Huber loss (NUMERIC type) + */ + public SDVariable huberLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights, double delta) { + SDValidation.validateNumerical("huberLoss", "label", label); + SDValidation.validateNumerical("huberLoss", "predictions", predictions); + SDValidation.validateNumerical("huberLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements: - * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))} - * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Log loss variable - */ - public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce, double epsilon) { - validateFloatingPoint("log loss", "predictions", predictions); - validateNumerical("log loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossLog(label, predictions, weights, lossReduce, epsilon); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * L2 loss: 1/2 * sum(x^2)
+ * + * @param var Variable to calculate L2 loss of (NUMERIC type) + * @return output L2 loss (NUMERIC type) + */ + public SDVariable l2Loss(SDVariable var) { + SDValidation.validateNumerical("l2Loss", "var", var); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd,var).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #logLoss(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable logLoss(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return logLoss(name, label, predictions, null, lossReduce, LogLoss.DEFAULT_EPSILON); - } + /** + * L2 loss: 1/2 * sum(x^2)
+ * + * @param name name May be null. Name for the output variable + * @param var Variable to calculate L2 loss of (NUMERIC type) + * @return output L2 loss (NUMERIC type) + */ + public SDVariable l2Loss(String name, SDVariable var) { + SDValidation.validateNumerical("l2Loss", "var", var); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd,var).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #logPoisson(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return logPoisson(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param epsilon epsilon + * @return output Log loss (NUMERIC type) + */ + public SDVariable logLoss(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, double epsilon) { + SDValidation.validateNumerical("logLoss", "label", label); + SDValidation.validateNumerical("logLoss", "predictions", predictions); + SDValidation.validateNumerical("logLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, weights, lossReduce, epsilon).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Log poisson loss: a loss function used for training classifiers. - * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels. - * - * @param name Name of the operation - * @param label Label array. Each value should be 0.0 or 1.0 - * @param predictions Predictions array (has to be log(x) of actual predictions) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("log poisson loss", "predictions", predictions); - validateNumerical("log poisson loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossLogPoisson(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param epsilon epsilon + * @return output Log loss (NUMERIC type) + */ + public SDVariable logLoss(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce, double epsilon) { + SDValidation.validateNumerical("logLoss", "label", label); + SDValidation.validateNumerical("logLoss", "predictions", predictions); + SDValidation.validateNumerical("logLoss", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, weights, lossReduce, epsilon).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #logPoisson(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable logPoisson(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return logPoisson(name, label, predictions, null, lossReduce); - } + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @return output Log loss (NUMERIC type) + */ + public SDVariable logLoss(SDVariable label, SDVariable predictions) { + SDValidation.validateNumerical("logLoss", "label", label); + SDValidation.validateNumerical("logLoss", "predictions", predictions); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #logPoissonFull(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return logPoissonFull(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @return output Log loss (NUMERIC type) + */ + public SDVariable logLoss(String name, SDVariable label, SDVariable predictions) { + SDValidation.validateNumerical("logLoss", "label", label); + SDValidation.validateNumerical("logLoss", "predictions", predictions); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Log poisson loss: a loss function used for training classifiers. - * Implements {@code L = exp(c) - z * c + z * log(z) - z + 0.5 * log(2 * pi * z)} - * where c is log(predictions) and z is labels. - * - * @param name Name of the operation - * @param label Label array. Each value should be 0.0 or 1.0 - * @param predictions Predictions array (has to be log(x) of actual predictions) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("log poisson (full) loss", "predictions", predictions); - validateNumerical("log poisson (full) loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossLogPoissonFull(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public SDVariable logPoisson(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, boolean full) { + SDValidation.validateNumerical("logPoisson", "label", label); + SDValidation.validateNumerical("logPoisson", "predictions", predictions); + SDValidation.validateNumerical("logPoisson", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, lossReduce, full).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #logPoissonFull(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable logPoissonFull(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return logPoissonFull(name, label, predictions, null, lossReduce); - } + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public SDVariable logPoisson(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce, boolean full) { + SDValidation.validateNumerical("logPoisson", "label", label); + SDValidation.validateNumerical("logPoisson", "predictions", predictions); + SDValidation.validateNumerical("logPoisson", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, lossReduce, full).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #meanPairwiseSquaredError(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return meanPairwiseSquaredError(name, label, predictions, null, lossReduce); - } + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public SDVariable logPoisson(SDVariable label, SDVariable predictions, SDVariable weights, + boolean full) { + SDValidation.validateNumerical("logPoisson", "label", label); + SDValidation.validateNumerical("logPoisson", "predictions", predictions); + SDValidation.validateNumerical("logPoisson", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays. - * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is: - * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
- * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] - * @return Loss variable, scalar output - */ - public SDVariable meanPairwiseSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("main pairwise squared error loss", "predictions", predictions); - validateNumerical("mean pairwise squared error loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossMeanPairwiseSquaredError(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Log poisson loss: a loss function used for training classifiers.
+ * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @return output Loss variable (NUMERIC type) + */ + public SDVariable logPoisson(String name, SDVariable label, SDVariable predictions, + SDVariable weights, boolean full) { + SDValidation.validateNumerical("logPoisson", "label", label); + SDValidation.validateNumerical("logPoisson", "predictions", predictions); + SDValidation.validateNumerical("logPoisson", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #meanSquaredError(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return meanSquaredError(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT); - } + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable, scalar output (NUMERIC type) + */ + public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis. - * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default)) - * this is the mean squared error loss function. - * - * @param name Name of the operation - * @param label Label array - * @param predictions Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, - SDVariable weights, @NonNull LossReduce lossReduce) { - validateFloatingPoint("mean squared error loss", "predictions", predictions); - validateNumerical("mean squared error loss", "labels", label); - weights = getWeights(weights, name, predictions); - SDVariable result = f().lossMeanSquaredError(label, predictions, weights, lossReduce); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable, scalar output (NUMERIC type) + */ + public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #meanSquaredError(String, SDVariable, SDVariable, SDVariable, LossReduce)}. - */ - public SDVariable meanSquaredError(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return meanSquaredError(name, label, predictions, null, lossReduce); - } + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @return output Loss variable, scalar output (NUMERIC type) + */ + public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #sigmoidCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return sigmoidCrossEntropy(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, SigmoidCrossEntropyLoss.DEFAULT_LABEL_SMOOTHING); - } + /** + * Mean pairwise squared error.
+ * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
+ * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @return output Loss variable, scalar output (NUMERIC type) + */ + public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); + SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions") - * and implements the binary cross entropy loss function. This implementation is numerically more stable than using - * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements: - * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))} - * though this is done in a mathematically equivalent but more numerical stable form.
- *
- * When label smoothing is > 0, the following label smoothing is used:
- *
-     * {@code numClasses = labels.size(1);
-     * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
-     * 
- * - * @param name Name of the operation - * @param label Label array - * @param predictionLogits Predictions array - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @return Loss variable - */ - public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictionLogits, - SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { - validateFloatingPoint("sigmoid cross entropy loss", "predictions", predictionLogits); - validateNumerical("sigmoid cross entropy loss", "labels", label); - weights = getWeights(weights, name, predictionLogits); - SDVariable result = f().lossSigmoidCrossEntropy(label, predictionLogits, weights, lossReduce, labelSmoothing); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + SDValidation.validateNumerical("meanSquaredError", "label", label); + SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #sigmoidCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable sigmoidCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return sigmoidCrossEntropy(name, label, predictions, null, lossReduce, SigmoidCrossEntropyLoss.DEFAULT_LABEL_SMOOTHING); - } + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @return output Loss variable (NUMERIC type) + */ + public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + SDValidation.validateNumerical("meanSquaredError", "label", label); + SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #softmaxCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions) { - return softmaxCrossEntropy(name, label, predictions, null, LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, SoftmaxCrossEntropyLoss.DEFAULT_LABEL_SMOOTHING); - } + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ * + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights) { + SDValidation.validateNumerical("meanSquaredError", "label", label); + SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels; - * otherwise, the output is a scalar.
- *

- * When label smoothing is > 0, the following label smoothing is used:
- *

-     * {@code numClasses = labels.size(1);
-     * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
-     * 
- * - * @param name Name of the operation - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) - * @param logitPredictions Predictions array (pre-softmax) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 - * @return Loss variable - */ - public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable oneHotLabels, @NonNull SDVariable logitPredictions, - SDVariable weights, @NonNull LossReduce lossReduce, double labelSmoothing) { - validateFloatingPoint("softmax cross entropy loss", "predictions", logitPredictions); - validateNumerical("softmax cross entropy loss", "oneHotLabels", oneHotLabels); - weights = getWeights(weights, name, logitPredictions); - SDVariable result = f().lossSoftmaxCrossEntropy(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
+ * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
+ * this is the mean squared error loss function.
+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictions Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions, + SDVariable weights) { + SDValidation.validateNumerical("meanSquaredError", "label", label); + SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); + SDValidation.validateNumerical("meanSquaredError", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #softmaxCrossEntropy(String, SDVariable, SDVariable, SDVariable, LossReduce, double)}. - */ - public SDVariable softmaxCrossEntropy(String name, @NonNull SDVariable label, @NonNull SDVariable predictions, @NonNull LossReduce lossReduce) { - return softmaxCrossEntropy(name, label, predictions, null, lossReduce, SoftmaxCrossEntropyLoss.DEFAULT_LABEL_SMOOTHING); - } + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerical stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *

+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *

+ * + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public SDVariable sigmoidCrossEntropy(SDVariable label, SDVariable predictionLogits, + SDVariable weights, LossReduce lossReduce, double labelSmoothing) { + SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * See {@link #sparseSoftmaxCrossEntropy(String, SDVariable, SDVariable)} - */ - public SDVariable sparseSoftmaxCrossEntropy(@NonNull SDVariable logits, @NonNull SDVariable labels) { - return sparseSoftmaxCrossEntropy(null, logits, labels); - } + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerical stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *

+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *

+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public SDVariable sigmoidCrossEntropy(String name, SDVariable label, SDVariable predictionLogits, + SDVariable weights, LossReduce lossReduce, double labelSmoothing) { + SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * As per {@link #softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce)} but the labels variable - * is represented as an integer array instead of the equivalent one-hot array.
- * i.e., if logits are rank N, then labels have rank N-1 - * - * @param name Name of the output variable. May be null - * @param logits Logits array ("pre-softmax activations") - * @param labels Labels array. Must be an integer type. - * @return Softmax cross entropy - */ - public SDVariable sparseSoftmaxCrossEntropy(String name, @NonNull SDVariable logits, @NonNull SDVariable labels) { - validateFloatingPoint("sparse softmax cross entropy", "logits (predictions)", logits); - validateInteger("sparse softmax cross entropy", "labels", labels); - Preconditions.checkState(labels.dataType().isIntType(), "Labels variable must be an integer type: got %s", logits); - SDVariable result = f().lossSparseSoftmaxCrossEntropy(logits, labels); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerical stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *

+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *

+ * + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable sigmoidCrossEntropy(SDVariable label, SDVariable predictionLogits, + SDVariable weights) { + SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return out; + } - /** - * TODO - * - * @param targets - * @param inputs - * @param weights - * @return - */ - public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, - SDVariable weights) { - return weightedCrossEntropyWithLogits(null, targets, inputs, weights); - } + /** + * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
+ * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
+ * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
+ * Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
+ * though this is done in a mathematically equivalent but more numerical stable form.
+ *
+ * When label smoothing is > 0, the following label smoothing is used:
+ *

+ * {@code numClasses = labels.size(1);
+ * label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
+ *

+ * + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) + * @param predictionLogits Predictions array (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable sigmoidCrossEntropy(String name, SDVariable label, SDVariable predictionLogits, + SDVariable weights) { + SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); + SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); + SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * TODO - * - * @param name - * @param targets - * @param inputs - * @param weights - * @return - */ - public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, SDVariable inputs, - SDVariable weights) { - validateFloatingPoint("weighted cross entropy with logits", "inputs", inputs); - validateNumerical("weighted cross entropy with logits", "targets", targets); - SDVariable result = f().weightedCrossEntropyWithLogits(targets, inputs, weights); - result = updateVariableNameAndReference(result, name); - result.markAsLoss(); - return result; - } + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
+ * otherwise, the output is a scalar.
+ *


+ * When label smoothing is > 0, the following label smoothing is used:
+ *


+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *

+ * + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public SDVariable softmaxCrossEntropy(SDVariable oneHotLabels, SDVariable logitPredictions, + SDVariable weights, LossReduce lossReduce, double labelSmoothing) { + SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); + out.markAsLoss(); + return out; + } + + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
+ * otherwise, the output is a scalar.
+ *


+ * When label smoothing is > 0, the following label smoothing is used:
+ *


+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *

+ * + * @param name name May be null. Name for the output variable + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 + * @return output Loss variable (NUMERIC type) + */ + public SDVariable softmaxCrossEntropy(String name, SDVariable oneHotLabels, + SDVariable logitPredictions, SDVariable weights, LossReduce lossReduce, + double labelSmoothing) { + SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
+ * otherwise, the output is a scalar.
+ *


+ * When label smoothing is > 0, the following label smoothing is used:
+ *


+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *

+ * + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable softmaxCrossEntropy(SDVariable oneHotLabels, SDVariable logitPredictions, + SDVariable weights) { + SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return out; + } + + /** + * Applies the softmax activation function to the input, then implement multi-class cross entropy:
+ * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
+ * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
+ * otherwise, the output is a scalar.
+ *


+ * When label smoothing is > 0, the following label smoothing is used:
+ *


+ * {@code numClasses = labels.size(1);
+ * oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
+ *

+ * + * @param name name May be null. Name for the output variable + * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable softmaxCrossEntropy(String name, SDVariable oneHotLabels, + SDVariable logitPredictions, SDVariable weights) { + SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); + SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); + SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
+ * is represented as an integer array instead of the equivalent one-hot array.
+ * i.e., if logits are rank N, then labels have rank N-1
+ * + * @param logits Logits array ("pre-softmax activations") (NUMERIC type) + * @param labels Labels array. Must be an integer type. (INT type) + * @return output Softmax cross entropy (NUMERIC type) + */ + public SDVariable sparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels) { + SDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits); + SDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(sd,logits, labels).outputVariable(); + out.markAsLoss(); + return out; + } + + /** + * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
+ * is represented as an integer array instead of the equivalent one-hot array.
+ * i.e., if logits are rank N, then labels have rank N-1
+ * + * @param name name May be null. Name for the output variable + * @param logits Logits array ("pre-softmax activations") (NUMERIC type) + * @param labels Labels array. Must be an integer type. (INT type) + * @return output Softmax cross entropy (NUMERIC type) + */ + public SDVariable sparseSoftmaxCrossEntropy(String name, SDVariable logits, SDVariable labels) { + SDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits); + SDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(sd,logits, labels).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Weighted cross entropy loss with logits
+ * + * @param targets targets array (NUMERIC type) + * @param inputs input array (NUMERIC type) + * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, + SDVariable weights) { + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets); + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs); + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd,targets, inputs, weights).outputVariable(); + out.markAsLoss(); + return out; + } + + /** + * Weighted cross entropy loss with logits
+ * + * @param name name May be null. Name for the output variable + * @param targets targets array (NUMERIC type) + * @param inputs input array (NUMERIC type) + * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @return output Loss variable (NUMERIC type) + */ + public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, + SDVariable inputs, SDVariable weights) { + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets); + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs); + SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd,targets, inputs, weights).outputVariable(); + out.markAsLoss(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 1e038e193..f4a490813 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,2539 +14,2955 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; -import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; -import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix; -import org.nd4j.linalg.api.ops.impl.shape.Eye; import org.nd4j.linalg.indexing.conditions.Condition; -import java.util.List; - -import static org.nd4j.autodiff.samediff.ops.SDValidation.*; - -/** - * SameDiff math operations
- * Accessible via {@link SameDiff#math()} - * - * @author Alex Black - */ public class SDMath extends SDOps { - public SDMath(SameDiff sameDiff) { - super(sameDiff); - } - - /** - * Elementwise absolute value operation: out = abs(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable abs(SDVariable x) { - return abs(null, x); - } - - /** - * Elementwise absolute value operation: out = abs(x) - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable abs(String name, SDVariable x) { - validateNumerical("abs", x); - SDVariable result = f().abs(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable acos(SDVariable x) { - return acos(null, x); - } - - /** - * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable acos(String name, SDVariable x) { - validateNumerical("acos", x); - SDVariable result = f().acos(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable acosh(SDVariable x) { - return acosh(null, x); - } - - /** - * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable acosh(String name, SDVariable x) { - validateNumerical("acosh", x); - SDVariable result = f().acosh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amax(SDVariable in, int... dimensions) { - return amax(null, in, dimensions); - } - - /** - * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amax(String name, SDVariable in, int... dimensions) { - validateNumerical("amax", in); - SDVariable ret = f().amax(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amean(SDVariable in, int... dimensions) { - validateNumerical("amean", in); - return amean(null, in, dimensions); - } - - /** - * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amean(String name, SDVariable in, int... dimensions) { - validateNumerical("amean", in); - SDVariable ret = f().amean(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amin(SDVariable in, int... dimensions) { - return amin(null, in, dimensions); - } - - /** - * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable amin(String name, SDVariable in, int... dimensions) { - validateNumerical("amin", in); - SDVariable ret = f().amin(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Boolean AND operation: elementwise (x != 0) && (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable and(SDVariable x, SDVariable y) { - return and(null, x, y); - } - - /** - * Boolean AND operation: elementwise (x != 0) && (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable and(String name, SDVariable x, SDVariable y) { - validateBool("boolean and", x, y); - SDVariable result = f().and(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable asin(SDVariable x) { - return asin(null, x); - } - - /** - * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable asin(String name, SDVariable x) { - validateNumerical("asin", x); - SDVariable result = f().asin(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable asinh(SDVariable x) { - return asinh(null, x); - } - - /** - * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable asinh(String name, SDVariable x) { - validateNumerical("asinh", x); - SDVariable result = f().asinh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable asum(SDVariable in, int... dimensions) { - return asum(null, in, dimensions); - } - - /** - * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable asum(String name, SDVariable in, int... dimensions) { - validateNumerical("asum", in); - SDVariable ret = f().asum(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable atan(SDVariable x) { - return atan(null, x); - } - - /** - * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable atan(String name, SDVariable x) { - validateNumerical("atan", x); - SDVariable result = f().atan(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y). - * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result - * - * @param y Input Y variable - * @param x Input X variable - * @return Output variable - */ - public SDVariable atan2(SDVariable y, SDVariable x) { - return atan2(null, y, x); - } - - /** - * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y). - * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result - * - * @param name Name of the output variable - * @param y Input Y variable - * @param x Input X variable - * @return Output variable - */ - public SDVariable atan2(String name, SDVariable y, SDVariable x) { - validateNumerical("atan2", y, x); - SDVariable ret = f().atan2(y, x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable atanh(SDVariable x) { - return atanh(null, x); - } - - /** - * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable atanh(String name, SDVariable x) { - validateNumerical("atanh", x); - SDVariable result = f().atanh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise ceiling function: out = ceil(x). - * Rounds each value up to the nearest integer value (if not already an integer) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable ceil(SDVariable x) { - return ceil(null, x); - } - - /** - * Element-wise ceiling function: out = ceil(x). - * Rounds each value up to the nearest integer value (if not already an integer) - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable ceil(String name, SDVariable x) { - validateFloatingPoint("ceil", x); - SDVariable ret = f().ceil(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Clipping by L2 norm
- * if l2Norm(x) < clipValue, then input is returned unmodifed
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in) - * - * @param x Input variable - * @param clipValue Clipping value (maximum l2 norm) - * @return Output variable - */ - public SDVariable clipByNorm(SDVariable x, double clipValue) { - return clipByNorm(null, x, clipValue); - } - - /** - * Clipping by L2 norm
- * if l2Norm(x) < clipValue, then input is returned unmodifed
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in) - * - * @param name Name of the output variable - * @param x Input variable - * @param clipValue Clipping value (maximum l2 norm) - * @return Output variable - */ - public SDVariable clipByNorm(String name, SDVariable x, double clipValue) { - validateFloatingPoint("clip by norm", x); - SDVariable ret = f().clipByNorm(x, clipValue); - return updateVariableNameAndReference(ret, name); - } - - /** - * Clipping by L2 norm, optionally along dimension(s)
- * if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according - * to the corresponding l2Norm along the specified dimensions - * - * @param x Input variable - * @param clipValue Clipping value (maximum l2 norm) - * @param dimensions If not specified, all dimensions are used - * @return Output variable - */ - public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) { - return clipByNorm(null, x, clipValue, dimensions); - } - - /** - * Clipping by L2 norm, optionally along dimension(s)
- * if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according - * to the corresponding l2Norm along the specified dimensions - * - * @param name Output variable name - * @param x Input variable - * @param clipValue Clipping value (maximum l2 norm) - * @param dimensions If not specified, all dimensions are used - * @return Output variable - */ - public SDVariable clipByNorm(String name, SDVariable x, double clipValue, int... dimensions) { - validateFloatingPoint("clip by norm", x); - SDVariable ret = f().clipByNorm(x, clipValue, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise clipping function:
- * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
- * out[i] = clipValueMin if in[i] < clipValueMin
- * out[i] = clipValueMax if in[i] > clipValueMax
- * - * @param x Input variable - * @param clipValueMin Minimum value for clipping - * @param clipValueMax Maximum value for clipping - * @return Output variable - */ - public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { - return clipByValue(null, x, clipValueMin, clipValueMax); - } - - /** - * Element-wise clipping function:
- * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
- * out[i] = clipValueMin if in[i] < clipValueMin
- * out[i] = clipValueMax if in[i] > clipValueMax
- * - * @param name Name of the output variable - * @param x Input variable - * @param clipValueMin Minimum value for clipping - * @param clipValueMax Maximum value for clipping - * @return Output variable - */ - public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, double clipValueMax) { - validateNumerical("clip by value", x); - SDVariable ret = f().clipByValue(x, clipValueMin, clipValueMax); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #confusionMatrix(String, SDVariable, SDVariable) - */ - public SDVariable confusionMatrix(SDVariable labels, SDVariable predictions) { - return confusionMatrix((String) null, labels, predictions); - } - - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred) { - return confusionMatrix(name, labels, pred, ConfusionMatrix.DEFAULT_DTYPE); - } - - /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of - * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
- * [1, 0, 0]
- * [0, 1, 1]
- * [0, 0, 0]
- * - * @param name Name of the output variable - * @param labels Labels - 1D array of integer values representing label values - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels - * @return Output variable (2D, shape [numClasses, numClasses}) - */ - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, DataType dataType) { - validateInteger("confusionMatrix", "labels", labels); - validateInteger("confusionMatrix", "prediction", pred); - SDVariable result = f().confusionMatrix(labels, pred, dataType); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #confusionMatrix(String, SDVariable, SDVariable, Integer) - */ - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses) { - return confusionMatrix(null, labels, pred, numClasses); - } - - /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of - * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
- * [1, 0, 0, 0]
- * [0, 1, 1, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
- * - * @param name Name of the output variable - * @param labels Labels - 1D array of integer values representing label values - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels - * @param numClasses Number of classes - * @return Output variable (2D, shape [numClasses, numClasses}) - */ - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses) { - validateInteger("confusionMatrix", "labels", labels); - validateInteger("confusionMatrix", "prediction", pred); - SDVariable result = f().confusionMatrix(labels, pred, numClasses); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #confusionMatrix(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { - return confusionMatrix(null, labels, pred, weights); - } - - /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of - * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3] - * [1, 0, 0]
- * [0, 3, 2]
- * [0, 0, 0]
- * - * @param name Name of the output variable - * @param labels Labels - 1D array of integer values representing label values - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of - * each prediction. Must be same length as both labels and predictions arrays - * @return Output variable (2D, shape [numClasses, numClasses}) - */ - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, SDVariable weights) { - validateInteger("confusionMatrix", "labels", labels); - validateInteger("confusionMatrix", "prediction", pred); - validateNumerical("confusionMatrix", "weights", weights); - SDVariable result = f().confusionMatrix(labels, pred, weights); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #confusionMatrix(String, SDVariable, SDVariable, Integer, SDVariable) - */ - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { - return confusionMatrix(null, labels, pred, numClasses, weights); - } - - /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of - * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3] - * [1, 0, 0, 0]
- * [0, 3, 2, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
- * - * @param name Name of the output variable - * @param labels Labels - 1D array of integer values representing label values - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of - * each prediction. Must be same length as both labels and predictions arrays - * @return Output variable (2D, shape [numClasses, numClasses}) - */ - public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { - validateInteger("confusionMatrix", "labels", labels); - validateInteger("confusionMatrix", "prediction", pred); - validateNumerical("confusionMatrix", "weights", weights); - SDVariable result = f().confusionMatrix(labels, pred, numClasses, weights); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise cosine operation: out = cos(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable cos(SDVariable x) { - return cos(null, x); - } - - /** - * Elementwise cosine operation: out = cos(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable cos(String name, SDVariable x) { - validateNumerical("cos", x); - SDVariable result = f().cos(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable cosh(SDVariable x) { - return cosh(null, x); - } - - /** - * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable cosh(String name, SDVariable x) { - validateNumerical("cosh", x); - SDVariable result = f().cosh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #cosineDistance(String, SDVariable, SDVariable, int...) - */ - public SDVariable cosineDistance(SDVariable x, SDVariable y, int... dimensions) { - return cosineDistance(null, x, y, dimensions); - } - - /** - * Cosine distance reduction operation. The output contains the cosine distance for each - * tensor/subset along the specified dimensions:
- * out = 1.0 - cosineSimilarity(x,y)
- * See {@link #cosineSimilarity(String, SDVariable, SDVariable, int...)} - * - * @param name Name of the output variable - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable cosineDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("cosine distance", x, y); - SDVariable result = f().cosineDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #cosineSimilarity(String, SDVariable, SDVariable, int...) - */ - public SDVariable cosineSimilarity(SDVariable x, SDVariable y, int... dimensions) { - return cosineSimilarity(sd.generateNewVarName(CosineSimilarity.OP_NAME, 0), x, y, dimensions); - } - - /** - * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset - * along the specified dimensions:
- * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2) - * - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable cosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("cosine similarity", x, y); - SDVariable cosim = f().cosineSimilarity(x, y, dimensions); - return updateVariableNameAndReference(cosim, name); - } - - /** - * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0) - * - * @param input Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable countNonZero(SDVariable input, int... dimensions) { - return countNonZero(null, input, dimensions); - } - - /** - * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0) - * - * @param name Name of the output variable - * @param input Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable countNonZero(String name, SDVariable input, int... dimensions) { - validateNumerical("countNonZero", input); - SDVariable res = f().countNonZero(input, dimensions); - return updateVariableNameAndReference(res, name); - } - - /** - * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0) - * - * @param input Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable countZero(SDVariable input, int... dimensions) { - return countZero(null, input, dimensions); - } - - /** - * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0) - * - * @param name Name of the output variable - * @param input Input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable countZero(String name, SDVariable input, int... dimensions) { - validateNumerical("countNonZero", input); - SDVariable res = f().countZero(input, dimensions); - return updateVariableNameAndReference(res, name); - } - - /** - * @see #cross(String, SDVariable, SDVariable) - */ - public SDVariable cross(SDVariable a, SDVariable b) { - return cross(null, a, b); - } - - /** - * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta). - * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3 - * - * @param a First input - * @param b Second input - * @return Element-wise cross product - */ - public SDVariable cross(String name, SDVariable a, SDVariable b) { - validateNumerical("cross", a, b); - SDVariable ret = f().cross(a, b); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise cube function: out = x^3 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable cube(SDVariable x) { - return cube(null, x); - } - - /** - * Element-wise cube function: out = x^3 - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable cube(String name, SDVariable x) { - validateNumerical("cube", x); - SDVariable result = f().cube(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #diag(String, SDVariable) - */ - public SDVariable diag(SDVariable x) { - return diag(null, x); - } - - /** - * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
- * For example, if input = [1,2,3], then output is given by:
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
- *
- * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k]. - * i.e., for input rank R, output has rank 2R - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable diag(String name, SDVariable x) { - SDVariable ret = f().diag(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #diagPart(String, SDVariable) - */ - public SDVariable diagPart(SDVariable x) { - return diagPart(null, x); - } - - /** - * Extract the diagonal part from the input array.
- * If input is
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
- * then output is [1, 2, 3].
- * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k] - * - * @param x Input variable - * @return Diagonal part of the input - * @see #diag(String, SDVariable) - */ - public SDVariable diagPart(String name, SDVariable x) { - SDVariable ret = f().diagPart(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Entropy reduction: -sum(x * log(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce on (null/empty for full array) - * @return Output variable - */ - public SDVariable entropy(SDVariable in, int... dimensions) { - return entropy(null, in, dimensions); - } - - /** - * Entropy reduction: -sum(x * log(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce on (null/empty for full array) - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable entropy(String name, SDVariable in, int... dimensions) { - validateNumerical("entropy reduction", in); - SDVariable ret = f().entropy(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise Gaussian error function - out = erf(in) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable erf(SDVariable x) { - return erf(null, x); - } - - /** - * Element-wise Gaussian error function - out = erf(in) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable erf(String name, SDVariable x) { - validateNumerical("erf (error function)", x); - SDVariable ret = f().erf(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable erfc(SDVariable x) { - return erfc(null, x); - } - - /** - * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in) - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable erfc(String name, SDVariable x) { - validateNumerical("erfc", x); - SDVariable ret = f().erfc(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #euclideanDistance(String, SDVariable, SDVariable, int...) - */ - public SDVariable euclideanDistance(SDVariable x, SDVariable y, int... dimensions) { - return euclideanDistance(sd.generateNewVarName(EuclideanDistance.OP_NAME, 0), x, y, dimensions); - } - - /** - * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each - * tensor/subset along the specified dimensions:
- * out = sqrt( sum_i (x[i] - y[i])^2 ) - * - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable euclideanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("euclidean distance", x, y); - SDVariable result = f().euclideanDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise exponent function: out = exp(x) = 2.71828...^x - * - * @param x Input variable - * @return Output variable - */ - public SDVariable exp(SDVariable x) { - return exp(null, x); - } - - /** - * Elementwise exponent function: out = exp(x) = 2.71828...^x - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable exp(String name, SDVariable x) { - validateNumerical("exp", x); - SDVariable result = f().exp(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x - * - * @param x Input variable - * @return Output variable - */ - public SDVariable expm1(SDVariable x) { - return expm1(null, x); - } - - /** - * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable expm1(String name, SDVariable x) { - validateNumerical("expm1", x); - SDVariable result = f().expm1(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Generate a square identity matrix with the specified number of rows. - * - * @param rows Number of rows (and columns) - * @return SDVariable with an identity matrix array - */ - public SDVariable eye(int rows) { - return eye(rows, rows); - } - - /** - * Generate an identity matrix with the specified number of rows and columns. - * - * @param rows Number of rows - */ - public SDVariable eye(String name, int rows) { - return eye(name, rows, rows); - } - - /** - * @see #eye(String, int, int) - */ - public SDVariable eye(int rows, int cols) { - return eye(null, rows, cols); - } - - /** - * As per {@link #eye(String, int, int, DataType)} but with the default datatype, {@link Eye#DEFAULT_DTYPE} - */ - public SDVariable eye(String name, int rows, int cols) { - return eye(name, rows, cols, Eye.DEFAULT_DTYPE); - } - - /** - * Generate an identity matrix with the specified number of rows and columns - * Example:
- *
-     * {@code SDVariable eye = eye(3,2)
-     * eye:
-     * [ 1, 0]
-     * [ 0, 1]
-     * [ 0, 0]}
-     * 
- * - * @param name Name of the new SDVariable - * @param rows Number of rows - * @param cols Number of columns - * @return SDVaribable identity matrix - */ - public SDVariable eye(String name, int rows, int cols, DataType dataType) { - return eye(name, rows, cols, dataType); - } - - /** - * see {@link #eye(String, int, int, DataType, int...)} - */ - public SDVariable eye(int rows, int cols, DataType dataType, int... batchDimension) { - return eye(null, rows, cols, dataType, batchDimension); - } - - /** - * Generate an identity matrix with the specified number of rows and columns, with optional leading dims
- * Example:
- * batchShape: [3,3]
- * numRows: 2
- * numCols: 4
- * returns a tensor of shape (3, 3, 2, 4) that consists of 3 * 3 batches of (2,4)-shaped identity matrices:
- * 1 0 0 0
- * 0 1 0 0
- * - * @param rows Number of rows - * @param cols Number of columns - * @param batchDimension Batch dimensions. May be null - */ - public SDVariable eye(String name, int rows, int cols, DataType dataType, int... batchDimension) { - SDVariable eye = new Eye(sd, rows, cols, dataType, batchDimension).outputVariables()[0]; - return updateVariableNameAndReference(eye, name); - } - - /** - * As per {@link #eye(int, int, DataType, int...)} bit with the number of rows/columns specified as scalar SDVariables, - * and the batch dimension specified as a 1D SDVariable - */ - public SDVariable eye(SDVariable rows, SDVariable cols, SDVariable batchDimension) { - return eye(null, rows, cols, batchDimension); - } - - /** - * As per {@link #eye(String, int, int, int...)} bit with the number of rows/columns specified as scalar SDVariables, - * and the batch dimension specified as a 1D SDVariable - */ - public SDVariable eye(String name, SDVariable rows, SDVariable cols, SDVariable batchDimension) { - SDVariable eye = new Eye(sd, rows, cols, batchDimension).outputVariable(); - return updateVariableNameAndReference(eye, name); - } - - /** - * As per {@link #eye(String, int, int)} bit with the number of rows/columns specified as scalar SDVariables - */ - public SDVariable eye(String name, SDVariable rows, SDVariable cols) { - SDVariable eye = new Eye(sd, rows, cols).outputVariables()[0]; - return updateVariableNameAndReference(eye, name); - } - - /** - * As per {@link #eye(int, int)} bit with the number of rows/columns specified as scalar SDVariables - */ - public SDVariable eye(SDVariable rows, SDVariable cols) { - SDVariable eye = new Eye(sd, rows, cols).outputVariables()[0]; - return updateVariableNameAndReference(eye, null); - } - - /** - * As per {@link #eye(String, int)} but with the number of rows specified as a scalar SDVariable - */ - public SDVariable eye(String name, SDVariable rows) { - SDVariable eye = new Eye(sd, rows).outputVariables()[0]; - return updateVariableNameAndReference(eye, name); - } - - /** - * As per {@link #eye(int)} but with the number of rows specified as a scalar SDVariable - */ - public SDVariable eye(SDVariable rows) { - SDVariable eye = new Eye(sd, rows).outputVariables()[0]; - return updateVariableNameAndReference(eye, null); - } - - /** - * @see #firstIndex(String, SDVariable, Condition, int...) - */ - public SDVariable firstIndex(SDVariable in, Condition condition, int... dimensions) { - return firstIndex(null, in, condition, dimensions); - } - - /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each - * slice along the specified dimensions) - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable firstIndex(String name, SDVariable in, Condition condition, int... dimensions) { - return firstIndex(name, in, condition, false, dimensions); - } - - /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each - * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable firstIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - SDVariable ret = f().firstIndex(in, condition, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #firstIndex(String, SDVariable, Condition, boolean, int...) - */ - public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return firstIndex(null, in, condition, keepDims, dimensions); - } - - /** - * Element-wise floor function: out = floor(x). - * Rounds each value down to the nearest integer value (if not already an integer) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable floor(SDVariable x) { - return floor(null, x); - } - - /** - * Element-wise floor function: out = floor(x). - * Rounds each value down to the nearest integer value (if not already an integer) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable floor(String name, SDVariable x) { - validateFloatingPoint("floor", x); - SDVariable result = f().floor(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #hammingDistance(String, SDVariable, SDVariable, int...) - */ - public SDVariable hammingDistance(SDVariable x, SDVariable y, int... dimensions) { - return hammingDistance(null, x, y, dimensions); - } - - /** - * Hamming distance reduction operation. The output contains the cosine distance for each - * tensor/subset along the specified dimensions:
- * out = count( x[i] != y[i] ) - * - * @param name Name of the output variable - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable hammingDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("hamming distance reduction", x, y); - SDVariable result = f().hammingDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Index of the max absolute value: argmax(abs(in)) - * - * @see SameDiff#argmax(SDVariable, int...) - */ - public SDVariable iamax(SDVariable in, int... dimensions) { - return iamax(null, in, dimensions); - } - - /** - * Index of the max absolute value: argmax(abs(in)) - * - * @see SameDiff#argmax(String, SDVariable, boolean, int...) - */ - public SDVariable iamax(String name, SDVariable in, int... dimensions) { - return iamax(name, in, false, dimensions); - } - - /** - * Index of the max absolute value: argmax(abs(in)) - * - * @see SameDiff#argmax(String, SDVariable, boolean, int...) - */ - public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) { - validateNumerical("iamax", in); - SDVariable ret = f().iamax(in, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Index of the max absolute value: argmax(abs(in)) - * - * @see SameDiff#argmax(String, SDVariable, boolean, int...) - */ - public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { - return iamax(null, in, keepDims, dimensions); - } - - /** - * Index of the min absolute value: argmin(abs(in)) - * - * @see SameDiff#argmin(String, SDVariable, boolean, int...) - */ - public SDVariable iamin(SDVariable in, int... dimensions) { - return iamin(null, in, dimensions); - } - - /** - * Index of the min absolute value: argmin(abs(in)) - * - * @see SameDiff#argmin(String, SDVariable, boolean, int...) - */ - public SDVariable iamin(String name, SDVariable in, int... dimensions) { - return iamin(name, in, false, dimensions); - } - - /** - * Index of the min absolute value: argmin(abs(in)) - * - * @see SameDiff#argmin(String, SDVariable, boolean, int...) - */ - public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) { - validateNumerical("iamin", in); - SDVariable ret = f().iamin(in, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Index of the min absolute value: argmin(abs(in)) - * - * @see SameDiff#argmin(String, SDVariable, boolean, int...) - */ - public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) { - return iamin(null, in, keepDims, dimensions); - } - - /** - * Is finite operation: elementwise isFinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isFinite(SDVariable x) { - return isFinite(null, x); - } - - /** - * Is finite operation: elementwise isFinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Output variable name - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isFinite(String name, SDVariable x) { - validateFloatingPoint("isFinite", x); - SDVariable result = f().isFinite(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Is infinite operation: elementwise isInfinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isInfinite(SDVariable x) { - return isInfinite(null, x); - } - - /** - * Is infinite operation: elementwise isInfinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Output variable name - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isInfinite(String name, SDVariable x) { - validateFloatingPoint("isInfinite", x); - SDVariable result = f().isInfinite(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Is maximum operation: elementwise x == max(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isMax(SDVariable x) { - return isMax(null, x); - } - - /** - * Is maximum operation: elementwise x == max(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Name of the output variable - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isMax(String name, SDVariable x) { - validateNumerical("isMax", x); - SDVariable ret = f().isMax(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Is Not a Number operation: elementwise isNaN(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isNaN(SDVariable x) { - return isNaN(null, x); - } - - /** - * Is Not a Number operation: elementwise isNaN(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or - * value 0 otherwise - * - * @param name Output variable name - * @param x Input array - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable isNaN(String name, SDVariable x) { - validateFloatingPoint("isNaN", x); - SDVariable result = f().isNaN(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Is the array non decreasing?
- * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared - * in 'c' (row major) order - * - * @param x Input variable - * @return Scalar variable with value 1 if non-decreasing, or 0 otherwise - */ - public SDVariable isNonDecreasing(SDVariable x) { - return isNonDecreasing(null, x); - } - - /** - * Is the array non decreasing?
- * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared - * in 'c' (row major) order - * - * @param name Output name - * @param x Input variable - * @return Scalar variable with value 1 if non-decreasing, or 0 otherwise - */ - public SDVariable isNonDecreasing(String name, SDVariable x) { - validateNumerical("isNonDecreasing", x); - SDVariable result = f().isNonDecreasing(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Is the array strictly increasing?
- * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared - * in 'c' (row major) order - * - * @param x Input variable - * @return Scalar variable with value 1 if strictly increasing, or 0 otherwise - */ - public SDVariable isStrictlyIncreasing(SDVariable x) { - return isStrictlyIncreasing(null, x); - - } - - /** - * Is the array strictly increasing?
- * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared - * in 'c' (row major) order - * - * @param name Output variable name - * @param x Input variable - * @return Scalar variable with value 1 if strictly increasing, or 0 otherwise - */ - public SDVariable isStrictlyIncreasing(String name, SDVariable x) { - validateNumerical("isStrictlyIncreasing", x); - SDVariable result = f().isStrictlyIncreasing(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Jaccard similarity reduction operation. The output contains the Jaccard distance for each - * tensor along the specified dimensions. - * - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate Jaccard similarity over - * @return Output variable - */ - public SDVariable jaccardDistance(SDVariable x, SDVariable y, int... dimensions) { - return jaccardDistance(null, x, y, dimensions); - } - - /** - * Jaccard similarity reduction operation. The output contains the Jaccard distance for each - * tensor along the specified dimensions. - * - * @param name Name of the output variable - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate Jaccard similarity over - * @return Output variable - */ - public SDVariable jaccardDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("Jaccard distance reduction", x, y); - SDVariable result = f().jaccardDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #lastIndex(String, SDVariable, Condition, int...) - */ - public SDVariable lastIndex(SDVariable in, Condition condition, int... dimensions) { - return lastIndex(null, in, condition, dimensions); - } - - /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each - * slice along the specified dimensions) - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable lastIndex(String name, SDVariable in, Condition condition, int... dimensions) { - return lastIndex(name, in, condition, false, dimensions); - } - - /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each - * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable, - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting - * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape: - * keepDims = true: [a,1,c]
- * keepDims = false: [a,c] - * - * @param name Name of the output variable - * @param in Input variable - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed - * @return Reduced array of rank (input rank - num dimensions) - */ - public SDVariable lastIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - SDVariable ret = f().lastIndex(in, condition, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #lastIndex(String, SDVariable, Condition, boolean, int...) - */ - public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return lastIndex(null, in, condition, keepDims, dimensions); - } - - /** - * List diff operation computes the difference between two 1d arrays, and also returns the indices - i.e., the positions - * where the output appears in the input X.
- * For inputs X and Y, listDiff returns everything in X but not in Y.
- * For example, if {@code X=[1,10,3,7,6]} and {@code Y=[10, 6]), then: - * output 0 (difference) = {@code [1,3,7]}
- * output 1 (indices) = {@code [0, 2, 3]}
- * @param x Input 1 - input values - * @param y Input 2 - values to remove - * @return 2 outputs - difference, and indices - */ - public SDVariable[] listDiff(SDVariable x, SDVariable y){ - return f().listdiff(x, y); - } - - /** - * Element-wise logarithm function (base e - natural logarithm): out = log(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable log(SDVariable x) { - return log(null, x); - } - - /** - * Element-wise logarithm function (base e - natural logarithm): out = log(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable log(String name, SDVariable x) { - validateNumerical("log", x); - SDVariable result = f().log(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise logarithm function (with specified base): out = log_{base}(x) - * - * @param in Input variable - * @param base Logarithm base - * @return Output variable - */ - public SDVariable log(SDVariable in, double base) { - return log(null, in, base); - } - - /** - * Element-wise logarithm function (with specified base): out = log_{base}(x) - * - * @param name Name of the output variable - * @param in Input variable - * @param base Logarithm base - * @return Output variable - */ - public SDVariable log(String name, SDVariable in, double base) { - validateNumerical("log", in); - SDVariable ret = f().log(in, base); - return updateVariableNameAndReference(ret, name); - } - - /** - * Elementwise natural logarithm function: out = log_e (1 + x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable log1p(SDVariable x) { - return log1p(null, x); - } - - /** - * Elementwise natural logarithm function: out = log_e (1 + x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable log1p(String name, SDVariable x) { - validateNumerical("log1p", x); - SDVariable result = f().log1p(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Log entropy reduction: log(-sum(x * log(x))) - * - * @param in Input variable - * @param dimensions Dimensions to reduce on (null for full array) - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable logEntropy(SDVariable in, int... dimensions) { - return logEntropy(null, in, dimensions); - } - - /** - * Log entropy reduction: log(-sum(x * log(x))) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce on (null for full array) - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable logEntropy(String name, SDVariable in, int... dimensions) { - validateNumerical("logEntropy reduction", in); - SDVariable ret = f().logEntropy(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Log-sum-exp reduction (optionally along dimension). - * Computes log(sum(exp(x)) - * - * @param input Input variable - * @param dimensions Optional dimensions to reduce along - * @return Output variable - */ - public SDVariable logSumExp(SDVariable input, int... dimensions) { - return logSumExp(null, input, dimensions); - } - - /** - * Log-sum-exp reduction (optionally along dimension). - * Computes log(sum(exp(x)) - * - * @param name Name of the output variable - * @param input Input variable - * @param dimensions Optional dimensions to reduce along - * @return Output variable - */ - public SDVariable logSumExp(String name, SDVariable input, int... dimensions) { - return logSumExp(name, input, false, dimensions); - } - - public SDVariable logSumExp(String name, SDVariable input, boolean keepDims, int... dimensions) { - validateNumerical("logSumExp reduction", input); - SDVariable ret = f().logSumExp(input, keepDims, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #manhattanDistance(String, SDVariable, SDVariable, int...) - */ - public SDVariable manhattanDistance(SDVariable x, SDVariable y, int... dimensions) { - return manhattanDistance(sd.generateNewVarName(ManhattanDistance.OP_NAME, 0), x, y, dimensions); - } - - /** - * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each - * tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i]-y[i]) - * - * @param name Name of the output variable - * @param x Input variable x - * @param y Input variable y - * @param dimensions Dimensions to calculate cosine similarity over - * @return Output variable - */ - public SDVariable manhattanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { - validateNumerical("manhattan distance", x, y); - SDVariable result = f().manhattanDistance(x, y, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #matrixDeterminant(String, SDVariable) - */ - public SDVariable matrixDeterminant(SDVariable in) { - return matrixDeterminant(null, in); - } - - /** - * Matrix determinant op. For 2D input, this returns the standard matrix determinant. - * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each - * shape [m,m] sub-matrix. - * - * @param name Name of the output variable - * @param in Input - * @return Matrix determinant variable - */ - public SDVariable matrixDeterminant(String name, SDVariable in) { - validateNumerical("matrix determinant", in); - SDVariable ret = f().matrixDeterminant(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #matrixInverse(String, SDVariable) - */ - public SDVariable matrixInverse(SDVariable in) { - return matrixInverse(null, in); - } - - /** - * Matrix inverse op. For 2D input, this returns the standard matrix inverse. - * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each - * shape [m,m] sub-matrix. - * - * @param name Name of the output variable - * @param in Input - * @return Matrix inverse variable - */ - public SDVariable matrixInverse(String name, SDVariable in) { - validateFloatingPoint("matrix inverse", in); - SDVariable ret = f().matrixInverse(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * Merge add function: merges an arbitrary number of equal shaped arrays using elementwise addition: - * out = sum_i in[i] - * - * @param x Input variables - * @return Output variable - */ - public SDVariable mergeAdd(SDVariable... x) { - return mergeAdd(null, x); - } - - /** - * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition: - * out = sum_i in[i] - * - * @param name Name of the output variable - * @param inputs Input variables - * @return Output variable - */ - public SDVariable mergeAdd(String name, SDVariable... inputs) { - validateSameType("mergeAdd", true, inputs); - SDVariable ret = f().mergeAdd(inputs); - return updateVariableNameAndReference(ret, name); - } - - /** - * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation: - * out = mean_i in[i] - * - * @param inputs Input variables - * @return Output variable - */ - public SDVariable mergeAvg(SDVariable... inputs) { - return mergeAvg(null, inputs); - } - - /** - * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation: - * out = mean_i in[i] - * - * @param name Name of the output variable - * @param inputs Input variables - * @return Output variable - */ - public SDVariable mergeAvg(String name, SDVariable... inputs) { - validateSameType("mergeAvg", true, inputs); - SDVariable ret = f().mergeAvg(inputs); - return updateVariableNameAndReference(ret, name); - } - - /** - * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation: - * out = max_i in[i] - * - * @param x Input variables - * @return Output variable - */ - public SDVariable mergeMax(SDVariable... x) { - return mergeMax(null, x); - } - - /** - * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation: - * out = max_i in[i] - * - * @param inputs Input variables - * @return Output variable - */ - public SDVariable mergeMax(String name, SDVariable... inputs) { - validateSameType("mergeMax", true, inputs); - SDVariable ret = f().mergeMax(inputs); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #meshgrid(List, SDVariable...) - */ - public SDVariable[] meshgrid(SDVariable... inputs) { - return meshgrid(null, inputs); - } - - /** - * Broadcast the 1D input variables onto an n-dimensional grid.
- * The resulting variable can be used for example for evaluating functions at all locations on a grid.
- * Example:
- *
-     * {@code input1 = [1, 2, 3]
-     * input2 = [4, 5, 6]
-     * SDVariable[] out = meshgrid(input1, input2)
-     * out[0]:
-     * [ 1, 2, 3]
-     * [ 1, 2, 3]
-     * [ 1, 2, 3]
-     *
-     * out[1]:
-     * [ 4, 4, 4]
-     * [ 5, 5, 5]
-     * [ 6, 6, 6]}
-     * 
- *
- * - * @param names List of names for the output variables. Must have exactly N names for N input arrays - * @param inputs N x 1D input variables - * @return an array of exactly N SDVariables (for N inputs), of rank N - */ - public SDVariable[] meshgrid(List names, SDVariable... inputs) { - return meshgrid(names, true, inputs); - } - - /** - * @see #meshgrid(List, SDVariable...) - */ - public SDVariable[] meshgrid(List names, boolean cartesian, SDVariable... inputs) { - Preconditions.checkState(names == null || names.size() == inputs.length, - "Got %s names but %s inputs", (names == null ? 0 : names.size()), inputs.length); - validateSameType("meshgrid", false, inputs); - SDVariable[] ret = f().meshgrid(cartesian, inputs); - for (int i = 0; i < ret.length; i++) { - ret[i] = updateVariableNameAndReference(ret[i], names == null ? null : names.get(i)); - } - return ret; - } - - /** - * @see #moments(String[], SDVariable, int...) - */ - public SDVariable[] moments(SDVariable input, int... axes) { - return moments(null, input, axes); - } - - /** - * Calculate the mean and (population) variance for the input variable, for the specified axis - * - * @param name Name of the output variables. Can be null; if non-null, must be length 2 - * @param input Input to calculate moments for - * @param axes Dimensions to perform calculation over - * @return Mean and variance variables - */ - public SDVariable[] moments(String[] name, SDVariable input, int... axes) { - validateNumerical("moments", input); - SDVariable[] res = f().moments(input, axes); - return sd.updateVariableNamesAndReferences(res, name); - } - - /** - * Elementwise negative operation: out = -x - * - * @param x Input variable - * @return Output variable - */ - public SDVariable neg(SDVariable x) { - return neg(null, x); - } - - /** - * Elementwise negative operation: out = -x - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable neg(String name, SDVariable x) { - validateNumerical("neg", x); - SDVariable result = f().neg(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #normalizeMoments(String[], SDVariable, SDVariable, SDVariable, double) - */ - public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) { - return normalizeMoments(null, counts, means, variances, shift); - } - - /** - * Calculate the mean and variance from the sufficient statistics - * - * @param name Name of the output variables. Can be null; if non-null, must be length 2 - * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics - * @param means Mean-value sufficient statistics: this is the SUM of all data values - * @param variances Variaance sufficient statistics: this is the squared sum of all data values - * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) - * @return Output variables: mean and population variance - */ - public SDVariable[] normalizeMoments(String[] name, SDVariable counts, SDVariable means, SDVariable variances, - double shift) { - SDVariable[] res = f().normalizeMoments(counts, means, variances, shift); - return sd.updateVariableNamesAndReferences(res, name); - } - - /** - * Boolean OR operation: elementwise (x != 0) || (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable or(SDVariable x, SDVariable y) { - return or(null, x, y); - } - - /** - * Boolean OR operation: elementwise (x != 0) || (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable or(String name, SDVariable x, SDVariable y) { - validateBool("or", x, y); - SDVariable result = f().or(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise power function: out = x^value - * - * @param x Input variable - * @param value Power to raise each element to - * @return Output variable - */ - public SDVariable pow(SDVariable x, double value) { - return pow(null, x, value); - } - - /** - * Element-wise power function: out = x^value - * - * @param name Output variable name - * @param x Input variable - * @param value Power to raise each element to - * @return Output variable - */ - public SDVariable pow(String name, SDVariable x, double value) { - validateNumerical("pow", x); - SDVariable result = f().pow(x, value); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise (broadcastable) power function: out = x[i]^y[i] - * - * @param x Input variable - * @param y Power - * @return Output variable - */ - public SDVariable pow(SDVariable x, SDVariable y) { - return pow(null, x, y); - } - - /** - * Element-wise (broadcastable) power function: out = x[i]^y[i] - * - * @param name Output variable name - * @param x Input variable - * @param y Power - * @return Output variable - */ - public SDVariable pow(String name, SDVariable x, SDVariable y) { - validateNumerical("pow", x, y); - SDVariable result = f().pow(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i] - * - * @param a Input variable - * @return Output variable - */ - public SDVariable reciprocal(SDVariable a) { - return reciprocal(null, a); - } - - /** - * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i] - * - * @param name Name of the output variable - * @param a Input variable - * @return Output variable - */ - public SDVariable reciprocal(String name, SDVariable a) { - validateNumerical("reciprocal", a); - SDVariable ret = f().reciprocal(a); - return updateVariableNameAndReference(ret, name); - } - - /** - * Elementwise round function: out = round(x). - * Rounds (up or down depending on value) to the nearest integer value. - * - * @param x Input variable - * @return Output variable - */ - public SDVariable round(SDVariable x) { - return round(null, x); - } - - /** - * Element-wise round function: out = round(x). - * Rounds (up or down depending on value) to the nearest integer value. - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable round(String name, SDVariable x) { - validateFloatingPoint("round", x); - SDVariable result = f().round(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable rsqrt(SDVariable x) { - return rsqrt(null, x); - } - - /** - * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable rsqrt(String name, SDVariable x) { - validateNumerical("rsqrt", x); - SDVariable result = f().rsqrt(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #setDiag(String, SDVariable, SDVariable) - */ - public SDVariable setDiag(SDVariable in, SDVariable diag) { - return setDiag(null, in, diag); - } - - /** - * Set the diagonal value to the specified values
- * If input is
- * [ a, b, c]
- * [ d, e, f]
- * [ g, h, i]
- * and diag = [ 1, 2, 3] then output is
- * [ 1, b, c]
- * [ d, 2, f]
- * [ g, h, 3]
- * - * @param name Name of the output variable - * @param in Input variable - * @param diag Diagonal - * @return Output variable - */ - public SDVariable setDiag(String name, SDVariable in, SDVariable diag) { - SDVariable ret = f().setDiag(in, diag); - return updateVariableNameAndReference(ret, name); - } - - /** - * Shannon Entropy reduction: -sum(x * log2(x)) - * - * @param in Input variable - * @param dimensions Dimensions to reduce on (null/empty for full array) - * @return Output variable - */ - public SDVariable shannonEntropy(SDVariable in, int... dimensions) { - return shannonEntropy(null, in, dimensions); - } - - /** - * Shannon Entropy reduction: -sum(x * log2(x)) - * - * @param name Name of the output variable - * @param in Input variable - * @param dimensions Dimensions to reduce on (null/empty for full array) - * @return Output variable: reduced array of rank (input rank - num dimensions) - */ - public SDVariable shannonEntropy(String name, SDVariable in, int... dimensions) { - validateNumerical("shannon entropy reduction", in); - SDVariable ret = f().shannonEntropy(in, dimensions); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise sign (signum) function:
- * out = -1 if in < 0
- * out = 0 if in = 0
- * out = 1 if in > 0 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable sign(SDVariable x) { - return sign(null, x); - } - - /** - * Element-wise sign (signum) function:
- * out = -1 if in < 0
- * out = 0 if in = 0
- * out = 1 if in > 0 - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable sign(String name, SDVariable x) { - validateNumerical("sign", x); - SDVariable result = f().sign(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise sine operation: out = sin(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable sin(SDVariable x) { - return sin(null, x); - } - - /** - * Elementwise sine operation: out = sin(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable sin(String name, SDVariable x) { - validateNumerical("sin", x); - SDVariable result = f().sin(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise sinh (hyperbolic sine) operation: out = sinh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable sinh(SDVariable x) { - return sinh(null, x); - } - - /** - * Elementwise sinh (hyperbolic sine) operation: out = sinh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable sinh(String name, SDVariable x) { - validateNumerical("sinh", x); - SDVariable result = f().sinh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise square root function: out = sqrt(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable sqrt(SDVariable x) { - return sqrt(null, x); - } - - /** - * Element-wise square root function: out = sqrt(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable sqrt(String name, SDVariable x) { - validateNumerical("sqrt", x); - SDVariable result = f().sqrt(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise square function: out = x^2 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable square(SDVariable x) { - return square(null, x); - } - - /** - * Element-wise square function: out = x^2 - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable square(String name, SDVariable x) { - validateNumerical("square", x); - SDVariable result = f().square(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise step function:
- * out(x) = 1 if x >= cutoff
- * out(x) = 0 otherwise
- * - * @param in Input variable - * @param cutoff Cutoff value for step function - * @return Output variable - */ - public SDVariable step(SDVariable in, double cutoff) { - return step(null, in, cutoff); - } - - /** - * Elementwise step function:
- * out(x) = 1 if x >= cutoff
- * out(x) = 0 otherwise
- * - * @param name Name of the output variable - * @param in Input variable - * @param cutoff Cutoff value for step function - * @return Output variable - */ - public SDVariable step(String name, SDVariable in, double cutoff) { - validateNumerical("step", in); - SDVariable ret = f().step(in, cutoff); - return updateVariableNameAndReference(ret, name); - } - - /** - * Standardize input variable along given axis - * - * @see #standardize(String, SDVariable, int...) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable standardize(SDVariable x, int... dimensions) { - return standardize(null, x, dimensions); - } - - /** - * Standardize input variable along given axis - *

- * out = (x - mean) / stdev - *

- * with mean and stdev being calculated along the given dimension. - * - *

- * For example: given x as a mini batch of the shape [numExamples, exampleLength]: - *

    - *
  • use dimension 1 too use the statistics (mean, stdev) for each example
  • - *
  • use dimension 0 if you want to use the statistics for each column across all examples
  • - *
  • use dimensions 0,1 if you want to use the statistics across all columns and examples
  • - *
- * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable standardize(String name, SDVariable x, int... dimensions) { - validateNumerical("standardize", x); - SDVariable result = f().standardize(x, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise tangent operation: out = tan(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable tan(SDVariable x) { - return tan(null, x); - } - - /** - * Elementwise tangent operation: out = tan(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable tan(String name, SDVariable x) { - validateNumerical("tan", x); - SDVariable result = f().tan(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable tanh(SDVariable x) { - return tanh(null, x); - } - - /** - * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable tanh(String name, SDVariable x) { - validateNumerical("tanh", x); - SDVariable result = f().tanh(x); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #trace(String, SDVariable) - */ - public SDVariable trace(SDVariable in) { - return trace(null, in); - } - - /** - * Matrix trace operation - * For rank 2 matrices, the output is a scalar vith the trace - i.e., sum of the main diagonal.
- * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:]) - * - * @param name Name of the output variable. May be null. - * @param in Input variable - * @return Trace - */ - public SDVariable trace(String name, SDVariable in) { - validateNumerical("trace", in); - SDVariable ret = f().trace(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable xor(SDVariable x, SDVariable y) { - return xor(null, x, y); - } - - /** - * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise. - * - * @param name Name of the output variable - * @param x Input 1 - * @param y Input 2 - * @return Output SDVariable with values 0 and 1 based on where the condition is satisfied - */ - public SDVariable xor(String name, SDVariable x, SDVariable y) { - validateBool("xor", x, y); - SDVariable result = f().xor(x, y); - return updateVariableNameAndReference(result, name); - } - - /** - * Shift integer bits to the left, i.e. var << 4 - * - * @param name Name of the output variable - * @param x Input 1 - * @return Output SDVariable with shifted bits - */ - public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { - validateInteger("shift_bits", x); - SDVariable result = f().shift(x, shift); - return updateVariableNameAndReference(result, name); - } - - /** - * Shift integer bits to the right, i.e. var >> 4 - * - * @param name Name of the output variable - * @param x Input 1 - * @return Output SDVariable with shifted bits - */ - public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { - validateInteger("rshift_bits", x); - SDVariable result = f().rshift(x, shift); - return updateVariableNameAndReference(result, name); - } - - /** - * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4) - * - * @param name Name of the output variable - * @param x Input 1 - * @return Output SDVariable with shifted bits - */ - public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) { - validateInteger("cyclic_shift_bits", x); - SDVariable result = f().rotl(x, shift); - return updateVariableNameAndReference(result, name); - } - - /** - * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4) - * - * @param name Name of the output variable - * @param x Input 1 - * @return Output SDVariable with shifted bits - */ - public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) { - validateInteger("cyclic_rshift_bits", x); - SDVariable result = f().rotr(x, shift); - return updateVariableNameAndReference(result, name); - } - - /** - * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x)) - * - * @param input Input variable - * @return Reduced array of rank 0 (scalar) - */ - public SDVariable zeroFraction(SDVariable input) { - return zeroFraction(null, input); - } - - /** - * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x)) - * - * @param name Name of the output variable - * @param input Input variable - * @return Reduced array of rank 0 (scalar) - */ - public SDVariable zeroFraction(String name, SDVariable input) { - validateNumerical("zeroFraction", input); - SDVariable res = f().zeroFraction(input); - return updateVariableNameAndReference(res, name); - } - - /** - * Compute the regularized incomplete beta integral - * - * @param name Name of the output variable - * @param a input array - * @param b input array - * @param x input array - * @return array - */ - public SDVariable betainc(String name,SDVariable a,SDVariable b,SDVariable x) { - SDVariable res = f().betainc(a,b,x); - return updateVariableNameAndReference(res, name); - } - - /** - * Copy a tensor setting everything outside a central band in each innermost matrix. - * - * @param name Name of the output variable - * @param input Rank k array - * @param minLower Number of subdiagonals to keep. - * @param maxUpper Number of superdiagonals to keep. - * @return Rank k array of the same shape as input. - */ - public SDVariable matrixBandPart(String name, SDVariable input, SDVariable minLower, SDVariable maxUpper) { - SDVariable res = f().matrixBandPart(input,minLower,maxUpper); - return updateVariableNameAndReference(res, name); - } - - /** - * Polygamma function - * - * @param name Name of the output variable - * @param n array - * @param x array - * @return array - */ - public SDVariable polygamma(String name, SDVariable n, SDVariable x) { - SDVariable res = f().polygamma(n,x); - return updateVariableNameAndReference(res, name); - } - - /** - * Rolls the elements of input - * - * @param name Name of the output variable - * @param input array - * @param shift number of places to shift elements - * @return array - */ - public SDVariable roll(String name, SDVariable input, int shift) { - SDVariable res = f().roll(input,shift); - return updateVariableNameAndReference(res, name); - } + public SDMath(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * Elementwise absolute value operation: out = abs(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable abs(SDVariable x) { + SDValidation.validateNumerical("abs", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd,x).outputVariable(); + } + + /** + * Elementwise absolute value operation: out = abs(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable abs(String name, SDVariable x) { + SDValidation.validateNumerical("abs", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable acos(SDVariable x) { + SDValidation.validateNumerical("acos", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd,x).outputVariable(); + } + + /** + * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable acos(String name, SDVariable x) { + SDValidation.validateNumerical("acos", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable acosh(SDVariable x) { + SDValidation.validateNumerical("acosh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd,x).outputVariable(); + } + + /** + * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable acosh(String name, SDVariable x) { + SDValidation.validateNumerical("acosh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amax(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd,in, dimensions).outputVariable(); + } + + /** + * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amax(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amean(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amean", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd,in, dimensions).outputVariable(); + } + + /** + * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amean(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amean", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amin(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd,in, dimensions).outputVariable(); + } + + /** + * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable amin(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("amin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Boolean AND operation: elementwise (x != 0) && (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable and(SDVariable x, SDVariable y) { + SDValidation.validateBool("and", "x", x); + SDValidation.validateBool("and", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd,x, y).outputVariable(); + } + + /** + * Boolean AND operation: elementwise (x != 0) && (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable and(String name, SDVariable x, SDVariable y) { + SDValidation.validateBool("and", "x", x); + SDValidation.validateBool("and", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable asin(SDVariable x) { + SDValidation.validateNumerical("asin", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd,x).outputVariable(); + } + + /** + * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable asin(String name, SDVariable x) { + SDValidation.validateNumerical("asin", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable asinh(SDVariable x) { + SDValidation.validateNumerical("asinh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd,x).outputVariable(); + } + + /** + * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable asinh(String name, SDVariable x) { + SDValidation.validateNumerical("asinh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable asum(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("asum", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd,in, dimensions).outputVariable(); + } + + /** + * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable asum(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("asum", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atan(SDVariable x) { + SDValidation.validateNumerical("atan", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd,x).outputVariable(); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atan(String name, SDVariable x) { + SDValidation.validateNumerical("atan", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
+ * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result
+ * + * @param y Input Y variable (NUMERIC type) + * @param x Input X variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atan2(SDVariable y, SDVariable x) { + SDValidation.validateNumerical("atan2", "y", y); + SDValidation.validateNumerical("atan2", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd,y, x).outputVariable(); + } + + /** + * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
+ * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result
+ * + * @param name name May be null. Name for the output variable + * @param y Input Y variable (NUMERIC type) + * @param x Input X variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atan2(String name, SDVariable y, SDVariable x) { + SDValidation.validateNumerical("atan2", "y", y); + SDValidation.validateNumerical("atan2", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd,y, x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atanh(SDVariable x) { + SDValidation.validateNumerical("atanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd,x).outputVariable(); + } + + /** + * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable atanh(String name, SDVariable x) { + SDValidation.validateNumerical("atanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Bit shift operation
+ * + * @param x input (NUMERIC type) + * @param shift shift value (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShift(SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShift", "x", x); + SDValidation.validateNumerical("bitShift", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + } + + /** + * Bit shift operation
+ * + * @param name name May be null. Name for the output variable + * @param x input (NUMERIC type) + * @param shift shift value (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShift", "x", x); + SDValidation.validateNumerical("bitShift", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Right bit shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShiftRight(SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRight", "x", x); + SDValidation.validateNumerical("bitShiftRight", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + } + + /** + * Right bit shift operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param shift shift argument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRight", "x", x); + SDValidation.validateNumerical("bitShiftRight", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cyclic bit shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift shift argy=ument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShiftRotl(SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRotl", "x", x); + SDValidation.validateNumerical("bitShiftRotl", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + } + + /** + * Cyclic bit shift operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param shift shift argy=ument (NUMERIC type) + * @return output shifted output (NUMERIC type) + */ + public SDVariable bitShiftRotl(String name, SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRotl", "x", x); + SDValidation.validateNumerical("bitShiftRotl", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cyclic right shift operation
+ * + * @param x Input tensor (NUMERIC type) + * @param shift Shift argument (NUMERIC type) + * @return output Shifted output (NUMERIC type) + */ + public SDVariable bitShiftRotr(SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRotr", "x", x); + SDValidation.validateNumerical("bitShiftRotr", "shift", shift); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + } + + /** + * Cyclic right shift operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param shift Shift argument (NUMERIC type) + * @return output Shifted output (NUMERIC type) + */ + public SDVariable bitShiftRotr(String name, SDVariable x, SDVariable shift) { + SDValidation.validateNumerical("bitShiftRotr", "x", x); + SDValidation.validateNumerical("bitShiftRotr", "shift", shift); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise ceiling function: out = ceil(x).
+ * Rounds each value up to the nearest integer value (if not already an integer)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable ceil(SDVariable x) { + SDValidation.validateNumerical("ceil", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd,x).outputVariable(); + } + + /** + * Element-wise ceiling function: out = ceil(x).
+ * Rounds each value up to the nearest integer value (if not already an integer)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable ceil(String name, SDVariable x) { + SDValidation.validateNumerical("ceil", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Clipping by L2 norm, optionally along dimension(s)
+ * if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
+ * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
+ * to the corresponding l2Norm along the specified dimensions
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("clipByNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd,x, clipValue, dimensions).outputVariable(); + } + + /** + * Clipping by L2 norm, optionally along dimension(s)
+ * if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
+ * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
+ * to the corresponding l2Norm along the specified dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByNorm(String name, SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("clipByNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd,x, clipValue, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise clipping function:
+ * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
+ * out[i] = clipValueMin if in[i] < clipValueMin
+ * out[i] = clipValueMax if in[i] > clipValueMax
+ * + * @param x Input variable (NUMERIC type) + * @param clipValueMin Minimum value for clipping + * @param clipValueMax Maximum value for clipping + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { + SDValidation.validateNumerical("clipByValue", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd,x, clipValueMin, clipValueMax).outputVariable(); + } + + /** + * Element-wise clipping function:
+ * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
+ * out[i] = clipValueMin if in[i] < clipValueMin
+ * out[i] = clipValueMax if in[i] > clipValueMax
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValueMin Minimum value for clipping + * @param clipValueMax Maximum value for clipping + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, + double clipValueMax) { + SDValidation.validateNumerical("clipByValue", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd,x, clipValueMin, clipValueMax).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
+ * [1, 0, 0]
+ * [0, 1, 1]
+ * [0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param dataType Data type + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, DataType dataType) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, dataType).outputVariable(); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
+ * [1, 0, 0]
+ * [0, 1, 1]
+ * [0, 0, 0]
+ * + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param dataType Data type + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, + DataType dataType) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
+ * [1, 0, 0, 0]
+ * [0, 1, 1, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param numClasses Number of classes + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, int numClasses) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, numClasses).outputVariable(); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
+ * [1, 0, 0, 0]
+ * [0, 1, 1, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param numClasses Number of classes + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, + int numClasses) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, numClasses).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]
+ * [1, 0, 0]
+ * [0, 3, 2]
+ * [0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDValidation.validateNumerical("confusionMatrix", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights).outputVariable(); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]
+ * [1, 0, 0]
+ * [0, 3, 2]
+ * [0, 0, 0]
+ * + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, + SDVariable weights) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDValidation.validateNumerical("confusionMatrix", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
+ * [1, 0, 0, 0]
+ * [0, 3, 2, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @param numClasses + * @return output Output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights, + int numClasses) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDValidation.validateNumerical("confusionMatrix", "weights", weights); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights, numClasses).outputVariable(); + } + + /** + * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
+ * which are represented as integer values.
+ * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
+ * [1, 0, 0, 0]
+ * [0, 3, 2, 0]
+ * [0, 0, 0, 0]
+ * [0, 0, 0, 0]
+ * + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @param numClasses + * @return output Output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) + */ + public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, + SDVariable weights, int numClasses) { + SDValidation.validateNumerical("confusionMatrix", "labels", labels); + SDValidation.validateNumerical("confusionMatrix", "pred", pred); + SDValidation.validateNumerical("confusionMatrix", "weights", weights); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights, numClasses).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise cosine operation: out = cos(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cos(SDVariable x) { + SDValidation.validateNumerical("cos", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd,x).outputVariable(); + } + + /** + * Elementwise cosine operation: out = cos(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cos(String name, SDVariable x) { + SDValidation.validateNumerical("cos", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosh(SDVariable x) { + SDValidation.validateNumerical("cosh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd,x).outputVariable(); + } + + /** + * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosh(String name, SDVariable x) { + SDValidation.validateNumerical("cosh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cosine distance reduction operation. The output contains the cosine distance for each
+ * tensor/subset along the specified dimensions:
+ * out = 1.0 - cosineSimilarity(x,y)
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosineDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("cosineDistance", "x", x); + SDValidation.validateNumerical("cosineDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Cosine distance reduction operation. The output contains the cosine distance for each
+ * tensor/subset along the specified dimensions:
+ * out = 1.0 - cosineSimilarity(x,y)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosineDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("cosineDistance", "x", x); + SDValidation.validateNumerical("cosineDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
+ * along the specified dimensions:
+ * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2)
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosineSimilarity(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("cosineSimilarity", "x", x); + SDValidation.validateNumerical("cosineSimilarity", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd,x, y, dimensions).outputVariable(); + } + + /** + * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
+ * along the specified dimensions:
+ * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("cosineSimilarity", "x", x); + SDValidation.validateNumerical("cosineSimilarity", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable countNonZero(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("countNonZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd,in, dimensions).outputVariable(); + } + + /** + * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable countNonZero(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("countNonZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable countZero(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("countZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd,in, dimensions).outputVariable(); + } + + /** + * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable countZero(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("countZero", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
+ * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3
+ * + * @param a First input (NUMERIC type) + * @param b Second input (NUMERIC type) + * @return output Element-wise cross product (NUMERIC type) + */ + public SDVariable cross(SDVariable a, SDVariable b) { + SDValidation.validateNumerical("cross", "a", a); + SDValidation.validateNumerical("cross", "b", b); + return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + } + + /** + * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
+ * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3
+ * + * @param name name May be null. Name for the output variable + * @param a First input (NUMERIC type) + * @param b Second input (NUMERIC type) + * @return output Element-wise cross product (NUMERIC type) + */ + public SDVariable cross(String name, SDVariable a, SDVariable b) { + SDValidation.validateNumerical("cross", "a", a); + SDValidation.validateNumerical("cross", "b", b); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise cube function: out = x^3
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cube(SDVariable x) { + SDValidation.validateNumerical("cube", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd,x).outputVariable(); + } + + /** + * Element-wise cube function: out = x^3
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cube(String name, SDVariable x) { + SDValidation.validateNumerical("cube", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
+ * For example, if input = [1,2,3], then output is given by:
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ *
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k].
+ * i.e., for input rank R, output has rank 2R
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable diag(SDVariable x) { + SDValidation.validateNumerical("diag", "x", x); + return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,x).outputVariable(); + } + + /** + * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
+ * For example, if input = [1,2,3], then output is given by:
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ *
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then output[i,...,k,i,...,k] = input[i,...,k].
+ * i.e., for input rank R, output has rank 2R
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable diag(String name, SDVariable x) { + SDValidation.validateNumerical("diag", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Extract the diagonal part from the input array.
+ * If input is
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ * then output is [1, 2, 3].
+ * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
+ * + * @param x Input variable (NUMERIC type) + * @return output Diagonal part of the input (NUMERIC type) + */ + public SDVariable diagPart(SDVariable x) { + SDValidation.validateNumerical("diagPart", "x", x); + return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,x).outputVariable(); + } + + /** + * Extract the diagonal part from the input array.
+ * If input is
+ * [ 1, 0, 0]
+ * [ 0, 2, 0]
+ * [ 0, 0, 3]
+ * then output is [1, 2, 3].
+ * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Diagonal part of the input (NUMERIC type) + */ + public SDVariable diagPart(String name, SDVariable x) { + SDValidation.validateNumerical("diagPart", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Entropy reduction: -sum(x * log(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable entropy(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("entropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd,in, dimensions).outputVariable(); + } + + /** + * Entropy reduction: -sum(x * log(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable entropy(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("entropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise Gaussian error function - out = erf(in)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable erf(SDVariable x) { + SDValidation.validateNumerical("erf", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd,x).outputVariable(); + } + + /** + * Element-wise Gaussian error function - out = erf(in)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable erf(String name, SDVariable x) { + SDValidation.validateNumerical("erf", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable erfc(SDVariable x) { + SDValidation.validateNumerical("erfc", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd,x).outputVariable(); + } + + /** + * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable erfc(String name, SDVariable x) { + SDValidation.validateNumerical("erfc", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sqrt( sum_i (x[i] - y[i])^2 )
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable euclideanDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("euclideanDistance", "x", x); + SDValidation.validateNumerical("euclideanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sqrt( sum_i (x[i] - y[i])^2 )
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable euclideanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("euclideanDistance", "x", x); + SDValidation.validateNumerical("euclideanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise exponent function: out = exp(x) = 2.71828...^x
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable exp(SDVariable x) { + SDValidation.validateNumerical("exp", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd,x).outputVariable(); + } + + /** + * Elementwise exponent function: out = exp(x) = 2.71828...^x
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable exp(String name, SDVariable x) { + SDValidation.validateNumerical("exp", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable expm1(SDVariable x) { + SDValidation.validateNumerical("expm1", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd,x).outputVariable(); + } + + /** + * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable expm1(String name, SDVariable x) { + SDValidation.validateNumerical("expm1", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate an identity matrix with the specified number of rows and columns.
+ * + * @param rows Number of rows + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(int rows) { + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + } + + /** + * Generate an identity matrix with the specified number of rows and columns.
+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(String name, int rows) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per eye(String, int, int, DataType) but with the default datatype, Eye.DEFAULT_DTYPE
+ * + * @param rows Number of rows + * @param cols Number of columns + * @return output (NUMERIC type) + */ + public SDVariable eye(int rows, int cols) { + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + } + + /** + * As per eye(String, int, int, DataType) but with the default datatype, Eye.DEFAULT_DTYPE
+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows + * @param cols Number of columns + * @return output (NUMERIC type) + */ + public SDVariable eye(String name, int rows, int cols) { + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Generate an identity matrix with the specified number of rows and columns
+ * Example:
+ *

+ * {@code INDArray eye = eye(3,2)
+ * eye:
+ * [ 1, 0]
+ * [ 0, 1]
+ * [ 0, 0]}
+ *

+ * + * @param rows Number of rows + * @param cols Number of columns + * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(int rows, int cols, DataType dataType, int... dimensions) { + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols, dataType, dimensions).outputVariable(); + } + + /** + * Generate an identity matrix with the specified number of rows and columns
+ * Example:
+ *

+ * {@code INDArray eye = eye(3,2)
+ * eye:
+ * [ 1, 0]
+ * [ 0, 1]
+ * [ 0, 0]}
+ *

+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows + * @param cols Number of columns + * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(String name, int rows, int cols, DataType dataType, int... dimensions) { + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols, dataType, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per eye(int, int) bit with the number of rows/columns specified as scalar INDArrays
+ * + * @param rows Number of rows (INT type) + * @param cols Number of columns (INT type) + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(SDVariable rows, SDVariable cols) { + SDValidation.validateInteger("eye", "rows", rows); + SDValidation.validateInteger("eye", "cols", cols); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + } + + /** + * As per eye(int, int) bit with the number of rows/columns specified as scalar INDArrays
+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows (INT type) + * @param cols Number of columns (INT type) + * @return output Identity matrix (NUMERIC type) + */ + public SDVariable eye(String name, SDVariable rows, SDVariable cols) { + SDValidation.validateInteger("eye", "rows", rows); + SDValidation.validateInteger("eye", "cols", cols); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * As per eye(String, int) but with the number of rows specified as a scalar INDArray
+ * + * @param rows Number of rows (INT type) + * @return output SDVaribable identity matrix (NUMERIC type) + */ + public SDVariable eye(SDVariable rows) { + SDValidation.validateInteger("eye", "rows", rows); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + } + + /** + * As per eye(String, int) but with the number of rows specified as a scalar INDArray
+ * + * @param name name May be null. Name for the output variable + * @param rows Number of rows (INT type) + * @return output SDVaribable identity matrix (NUMERIC type) + */ + public SDVariable eye(String name, SDVariable rows) { + SDValidation.validateInteger("eye", "rows", rows); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable firstIndex(SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, false, condition, dimensions).outputVariable(); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable firstIndex(String name, SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, false, condition, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + } + + /** + * First index reduction operation.
+ * Returns a variable that contains the index of the first element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable firstIndex(String name, SDVariable in, Condition condition, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("firstIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise floor function: out = floor(x).
+ * Rounds each value down to the nearest integer value (if not already an integer)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floor(SDVariable x) { + SDValidation.validateNumerical("floor", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd,x).outputVariable(); + } + + /** + * Element-wise floor function: out = floor(x).
+ * Rounds each value down to the nearest integer value (if not already an integer)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floor(String name, SDVariable x) { + SDValidation.validateNumerical("floor", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Hamming distance reduction operation. The output contains the cosine distance for each
+ * tensor/subset along the specified dimensions:
+ * out = count( x[i] != y[i] )
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hammingDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("hammingDistance", "x", x); + SDValidation.validateNumerical("hammingDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Hamming distance reduction operation. The output contains the cosine distance for each
+ * tensor/subset along the specified dimensions:
+ * out = count( x[i] != y[i] )
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hammingDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("hammingDistance", "x", x); + SDValidation.validateNumerical("hammingDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamax(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable(); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamax(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable(); + } + + /** + * Index of the max absolute value: argmax(abs(in))
+ * see argmax(String, INDArray, boolean, int...)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("iamax", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamin(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable(); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamin(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable(); + } + + /** + * Index of the min absolute value: argmin(abs(in))
+ * see argmin(String, INDArray, boolean, int...)
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) { + SDValidation.validateNumerical("iamin", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is finite operation: elementwise isFinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isFinite(SDVariable x) { + SDValidation.validateNumerical("isFinite", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd,x).outputVariable(); + } + + /** + * Is finite operation: elementwise isFinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isFinite(String name, SDVariable x) { + SDValidation.validateNumerical("isFinite", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is infinite operation: elementwise isInfinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isInfinite(SDVariable x) { + SDValidation.validateNumerical("isInfinite", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd,x).outputVariable(); + } + + /** + * Is infinite operation: elementwise isInfinite(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isInfinite(String name, SDVariable x) { + SDValidation.validateNumerical("isInfinite", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is maximum operation: elementwise x == max(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isMax(SDVariable x) { + SDValidation.validateNumerical("isMax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd,x).outputVariable(); + } + + /** + * Is maximum operation: elementwise x == max(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isMax(String name, SDVariable x) { + SDValidation.validateNumerical("isMax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is Not a Number operation: elementwise isNaN(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isNaN(SDVariable x) { + SDValidation.validateNumerical("isNaN", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd,x).outputVariable(); + } + + /** + * Is Not a Number operation: elementwise isNaN(x)
+ * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
+ * value 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable isNaN(String name, SDVariable x) { + SDValidation.validateNumerical("isNaN", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is the array non decreasing?
+ * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) + */ + public SDVariable isNonDecreasing(SDVariable x) { + SDValidation.validateNumerical("isNonDecreasing", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd,x).outputVariable(); + } + + /** + * Is the array non decreasing?
+ * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) + */ + public SDVariable isNonDecreasing(String name, SDVariable x) { + SDValidation.validateNumerical("isNonDecreasing", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Is the array strictly increasing?
+ * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type) + */ + public SDVariable isStrictlyIncreasing(SDVariable x) { + SDValidation.validateNumerical("isStrictlyIncreasing", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd,x).outputVariable(); + } + + /** + * Is the array strictly increasing?
+ * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
+ * in 'c' (row major) order
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type) + */ + public SDVariable isStrictlyIncreasing(String name, SDVariable x) { + SDValidation.validateNumerical("isStrictlyIncreasing", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Jaccard similarity reduction operation. The output contains the Jaccard distance for each
+ * tensor along the specified dimensions.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable jaccardDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("jaccardDistance", "x", x); + SDValidation.validateNumerical("jaccardDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Jaccard similarity reduction operation. The output contains the Jaccard distance for each
+ * tensor along the specified dimensions.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable jaccardDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("jaccardDistance", "x", x); + SDValidation.validateNumerical("jaccardDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable lastIndex(SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, false, condition, dimensions).outputVariable(); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable lastIndex(String name, SDVariable in, Condition condition, int... dimensions) { + SDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, false, condition, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + } + + /** + * Last index reduction operation.
+ * Returns a variable that contains the index of the last element that matches the specified condition (for each
+ * slice along the specified dimensions)
+ * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable lastIndex(String name, SDVariable in, Condition condition, boolean keepDims, + int... dimensions) { + SDValidation.validateNumerical("lastIndex", "in", in); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculates difference between inputs X and Y.
+ * + * @param x Input variable X (NUMERIC type) + * @param y Input variable Y (NUMERIC type) + */ + public SDVariable[] listDiff(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("listDiff", "x", x); + SDValidation.validateNumerical("listDiff", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd,x, y).outputVariables(); + } + + /** + * Calculates difference between inputs X and Y.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input variable X (NUMERIC type) + * @param y Input variable Y (NUMERIC type) + */ + public SDVariable[] listDiff(String[] names, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("listDiff", "x", x); + SDValidation.validateNumerical("listDiff", "y", y); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd,x, y).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Element-wise logarithm function (base e - natural logarithm): out = log(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable log(SDVariable x) { + SDValidation.validateNumerical("log", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd,x).outputVariable(); + } + + /** + * Element-wise logarithm function (base e - natural logarithm): out = log(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable log(String name, SDVariable x) { + SDValidation.validateNumerical("log", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise logarithm function (with specified base): out = log_{base}(x)
+ * + * @param x Input variable (NUMERIC type) + * @param base Logarithm base + * @return output Output variable (NUMERIC type) + */ + public SDVariable log(SDVariable x, double base) { + SDValidation.validateNumerical("log", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd,x, base).outputVariable(); + } + + /** + * Element-wise logarithm function (with specified base): out = log_{base}(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param base Logarithm base + * @return output Output variable (NUMERIC type) + */ + public SDVariable log(String name, SDVariable x, double base) { + SDValidation.validateNumerical("log", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd,x, base).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise natural logarithm function: out = log_e (1 + x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable log1p(SDVariable x) { + SDValidation.validateNumerical("log1p", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd,x).outputVariable(); + } + + /** + * Elementwise natural logarithm function: out = log_e (1 + x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable log1p(String name, SDVariable x) { + SDValidation.validateNumerical("log1p", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Log entropy reduction: log(-sum(x * log(x)))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable logEntropy(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("logEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd,in, dimensions).outputVariable(); + } + + /** + * Log entropy reduction: log(-sum(x * log(x)))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable logEntropy(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("logEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Log-sum-exp reduction (optionally along dimension).
+ * Computes log(sum(exp(x))
+ * + * @param input Input variable (NUMERIC type) + * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable logSumExp(SDVariable input, int... dimensions) { + SDValidation.validateNumerical("logSumExp", "input", input); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd,input, dimensions).outputVariable(); + } + + /** + * Log-sum-exp reduction (optionally along dimension).
+ * Computes log(sum(exp(x))
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable logSumExp(String name, SDVariable input, int... dimensions) { + SDValidation.validateNumerical("logSumExp", "input", input); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd,input, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i]-y[i])
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable manhattanDistance(SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("manhattanDistance", "x", x); + SDValidation.validateNumerical("manhattanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd,x, y, dimensions).outputVariable(); + } + + /** + * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
+ * tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i]-y[i])
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable manhattanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { + SDValidation.validateNumerical("manhattanDistance", "x", x); + SDValidation.validateNumerical("manhattanDistance", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd,x, y, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
+ * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param in Input (NUMERIC type) + * @return output Matrix determinant variable (NUMERIC type) + */ + public SDVariable matrixDeterminant(SDVariable in) { + SDValidation.validateNumerical("matrixDeterminant", "in", in); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd,in).outputVariable(); + } + + /** + * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
+ * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param name name May be null. Name for the output variable + * @param in Input (NUMERIC type) + * @return output Matrix determinant variable (NUMERIC type) + */ + public SDVariable matrixDeterminant(String name, SDVariable in) { + SDValidation.validateNumerical("matrixDeterminant", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
+ * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param in Input (NUMERIC type) + * @return output Matrix inverse variable (NUMERIC type) + */ + public SDVariable matrixInverse(SDVariable in) { + SDValidation.validateNumerical("matrixInverse", "in", in); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd,in).outputVariable(); + } + + /** + * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
+ * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
+ * shape [m,m] sub-matrix.
+ * + * @param name name May be null. Name for the output variable + * @param in Input (NUMERIC type) + * @return output Matrix inverse variable (NUMERIC type) + */ + public SDVariable matrixInverse(String name, SDVariable in) { + SDValidation.validateNumerical("matrixInverse", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
+ * out = sum_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeAdd(SDVariable[] inputs) { + SDValidation.validateNumerical("mergeAdd", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); + } + + /** + * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
+ * out = sum_i in[i]
+ * + * @param name name May be null. Name for the output variable + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeAdd(String name, SDVariable[] inputs) { + SDValidation.validateNumerical("mergeAdd", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
+ * out = mean_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeAvg(SDVariable[] inputs) { + SDValidation.validateNumerical("mergeAvg", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); + } + + /** + * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
+ * out = mean_i in[i]
+ * + * @param name name May be null. Name for the output variable + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeAvg(String name, SDVariable[] inputs) { + SDValidation.validateNumerical("mergeAvg", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
+ * out = max_i in[i]
+ * + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeMax(SDVariable[] inputs) { + SDValidation.validateNumerical("mergeMax", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); + } + + /** + * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
+ * out = max_i in[i]
+ * + * @param name name May be null. Name for the output variable + * @param inputs Input variables (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mergeMax(String name, SDVariable[] inputs) { + SDValidation.validateNumerical("mergeMax", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Broadcasts parameters for evaluation on an N-D grid.
+ * + * @param inputs (NUMERIC type) + * @param cartesian + */ + public SDVariable[] meshgrid(SDVariable[] inputs, boolean cartesian) { + SDValidation.validateNumerical("meshgrid", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd,inputs, cartesian).outputVariables(); + } + + /** + * Broadcasts parameters for evaluation on an N-D grid.
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param inputs (NUMERIC type) + * @param cartesian + */ + public SDVariable[] meshgrid(String[] names, SDVariable[] inputs, boolean cartesian) { + SDValidation.validateNumerical("meshgrid", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd,inputs, cartesian).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * + * @param input Input to calculate moments for (NUMERIC type) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + */ + public SDVariable[] moments(SDVariable input, int... axes) { + SDValidation.validateNumerical("moments", "input", input); + Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + return new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd,input, axes).outputVariables(); + } + + /** + * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param input Input to calculate moments for (NUMERIC type) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + */ + public SDVariable[] moments(String[] names, SDVariable input, int... axes) { + SDValidation.validateNumerical("moments", "input", input); + Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd,input, axes).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Elementwise negative operation: out = -x
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable neg(SDVariable x) { + SDValidation.validateNumerical("neg", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd,x).outputVariable(); + } + + /** + * Elementwise negative operation: out = -x
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable neg(String name, SDVariable x) { + SDValidation.validateNumerical("neg", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Calculate the mean and variance from the sufficient statistics
+ * + * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type) + * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) + * @param variances Variaance sufficient statistics: this is the squared sum of all data values (NUMERIC type) + * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) + */ + public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, + double shift) { + SDValidation.validateNumerical("normalizeMoments", "counts", counts); + SDValidation.validateNumerical("normalizeMoments", "means", means); + SDValidation.validateNumerical("normalizeMoments", "variances", variances); + return new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd,counts, means, variances, shift).outputVariables(); + } + + /** + * Calculate the mean and variance from the sufficient statistics
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type) + * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) + * @param variances Variaance sufficient statistics: this is the squared sum of all data values (NUMERIC type) + * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) + */ + public SDVariable[] normalizeMoments(String[] names, SDVariable counts, SDVariable means, + SDVariable variances, double shift) { + SDValidation.validateNumerical("normalizeMoments", "counts", counts); + SDValidation.validateNumerical("normalizeMoments", "means", means); + SDValidation.validateNumerical("normalizeMoments", "variances", variances); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd,counts, means, variances, shift).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + + /** + * Boolean OR operation: elementwise (x != 0) || (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable or(SDVariable x, SDVariable y) { + SDValidation.validateBool("or", "x", x); + SDValidation.validateBool("or", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd,x, y).outputVariable(); + } + + /** + * Boolean OR operation: elementwise (x != 0) || (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable or(String name, SDVariable x, SDVariable y) { + SDValidation.validateBool("or", "x", x); + SDValidation.validateBool("or", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise power function: out = x^value
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable pow(SDVariable x, double value) { + SDValidation.validateNumerical("pow", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd,x, value).outputVariable(); + } + + /** + * Element-wise power function: out = x^value
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable pow(String name, SDVariable x, double value) { + SDValidation.validateNumerical("pow", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise (broadcastable) power function: out = x[i]^y[i]
+ * + * @param x Input variable (NUMERIC type) + * @param y Power (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable pow(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("pow", "x", x); + SDValidation.validateNumerical("pow", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd,x, y).outputVariable(); + } + + /** + * Element-wise (broadcastable) power function: out = x[i]^y[i]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Power (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable pow(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("pow", "x", x); + SDValidation.validateNumerical("pow", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reciprocal(SDVariable x) { + SDValidation.validateNumerical("reciprocal", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd,x).outputVariable(); + } + + /** + * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reciprocal(String name, SDVariable x) { + SDValidation.validateNumerical("reciprocal", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise round function: out = round(x).
+ * Rounds (up or down depending on value) to the nearest integer value.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable round(SDVariable x) { + SDValidation.validateNumerical("round", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd,x).outputVariable(); + } + + /** + * Element-wise round function: out = round(x).
+ * Rounds (up or down depending on value) to the nearest integer value.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable round(String name, SDVariable x) { + SDValidation.validateNumerical("round", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsqrt(SDVariable x) { + SDValidation.validateNumerical("rsqrt", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd,x).outputVariable(); + } + + /** + * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsqrt(String name, SDVariable x) { + SDValidation.validateNumerical("rsqrt", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Set the diagonal value to the specified values
+ * If input is
+ * [ a, b, c]
+ * [ d, e, f]
+ * [ g, h, i]
+ * and diag = [ 1, 2, 3] then output is
+ * [ 1, b, c]
+ * [ d, 2, f]
+ * [ g, h, 3]
+ * + * @param in Input variable (NUMERIC type) + * @param diag Diagonal (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable setDiag(SDVariable in, SDVariable diag) { + SDValidation.validateNumerical("setDiag", "in", in); + SDValidation.validateNumerical("setDiag", "diag", diag); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd,in, diag).outputVariable(); + } + + /** + * Set the diagonal value to the specified values
+ * If input is
+ * [ a, b, c]
+ * [ d, e, f]
+ * [ g, h, i]
+ * and diag = [ 1, 2, 3] then output is
+ * [ 1, b, c]
+ * [ d, 2, f]
+ * [ g, h, 3]
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param diag Diagonal (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable setDiag(String name, SDVariable in, SDVariable diag) { + SDValidation.validateNumerical("setDiag", "in", in); + SDValidation.validateNumerical("setDiag", "diag", diag); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd,in, diag).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Shannon Entropy reduction: -sum(x * log2(x))
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable shannonEntropy(SDVariable in, int... dimensions) { + SDValidation.validateNumerical("shannonEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd,in, dimensions).outputVariable(); + } + + /** + * Shannon Entropy reduction: -sum(x * log2(x))
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public SDVariable shannonEntropy(String name, SDVariable in, int... dimensions) { + SDValidation.validateNumerical("shannonEntropy", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd,in, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise sign (signum) function:
+ * out = -1 if in < 0
+ * out = 0 if in = 0
+ * out = 1 if in > 0
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sign(SDVariable x) { + SDValidation.validateNumerical("sign", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd,x).outputVariable(); + } + + /** + * Element-wise sign (signum) function:
+ * out = -1 if in < 0
+ * out = 0 if in = 0
+ * out = 1 if in > 0
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sign(String name, SDVariable x) { + SDValidation.validateNumerical("sign", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise sine operation: out = sin(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sin(SDVariable x) { + SDValidation.validateNumerical("sin", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd,x).outputVariable(); + } + + /** + * Elementwise sine operation: out = sin(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sin(String name, SDVariable x) { + SDValidation.validateNumerical("sin", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sinh(SDVariable x) { + SDValidation.validateNumerical("sinh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd,x).outputVariable(); + } + + /** + * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sinh(String name, SDVariable x) { + SDValidation.validateNumerical("sinh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise square root function: out = sqrt(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sqrt(SDVariable x) { + SDValidation.validateNumerical("sqrt", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd,x).outputVariable(); + } + + /** + * Element-wise square root function: out = sqrt(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sqrt(String name, SDVariable x) { + SDValidation.validateNumerical("sqrt", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise square function: out = x^2
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable square(SDVariable x) { + SDValidation.validateNumerical("square", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd,x).outputVariable(); + } + + /** + * Element-wise square function: out = x^2
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable square(String name, SDVariable x) { + SDValidation.validateNumerical("square", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Standardize input variable along given axis
+ *


+ * out = (x - mean) / stdev
+ *


+ * with mean and stdev being calculated along the given dimension.
+ *


+ * For example: given x as a mini batch of the shape [numExamples, exampleLength]:
+ *


    + *
  • use dimension 1 too use the statistics (mean, stdev) for each example

  • + *
  • use dimension 0 if you want to use the statistics for each column across all examples

  • + *
  • use dimensions 0,1 if you want to use the statistics across all columns and examples

  • + *

+ * + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable standardize(SDVariable x, int... dimensions) { + SDValidation.validateNumerical("standardize", "x", x); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd,x, dimensions).outputVariable(); + } + + /** + * Standardize input variable along given axis
+ *


+ * out = (x - mean) / stdev
+ *


+ * with mean and stdev being calculated along the given dimension.
+ *


+ * For example: given x as a mini batch of the shape [numExamples, exampleLength]:
+ *


    + *
  • use dimension 1 too use the statistics (mean, stdev) for each example

  • + *
  • use dimension 0 if you want to use the statistics for each column across all examples

  • + *
  • use dimensions 0,1 if you want to use the statistics across all columns and examples

  • + *

+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable standardize(String name, SDVariable x, int... dimensions) { + SDValidation.validateNumerical("standardize", "x", x); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd,x, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise step function:
+ * out(x) = 1 if x >= cutoff
+ * out(x) = 0 otherwise
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable step(SDVariable x, double value) { + SDValidation.validateNumerical("step", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.Step(sd,x, value).outputVariable(); + } + + /** + * Elementwise step function:
+ * out(x) = 1 if x >= cutoff
+ * out(x) = 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable step(String name, SDVariable x, double value) { + SDValidation.validateNumerical("step", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Step(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise tangent operation: out = tan(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tan(SDVariable x) { + SDValidation.validateNumerical("tan", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd,x).outputVariable(); + } + + /** + * Elementwise tangent operation: out = tan(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tan(String name, SDVariable x) { + SDValidation.validateNumerical("tan", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tanh(SDVariable x) { + SDValidation.validateNumerical("tanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tanh(String name, SDVariable x) { + SDValidation.validateNumerical("tanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Matrix trace operation
+ * For rank 2 matrices, the output is a scalar vith the trace - i.e., sum of the main diagonal.
+ * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
+ * + * @param in Input variable (NUMERIC type) + * @return output Trace (NUMERIC type) + */ + public SDVariable trace(SDVariable in) { + SDValidation.validateNumerical("trace", "in", in); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd,in).outputVariable(); + } + + /** + * Matrix trace operation
+ * For rank 2 matrices, the output is a scalar vith the trace - i.e., sum of the main diagonal.
+ * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
+ * + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @return output Trace (NUMERIC type) + */ + public SDVariable trace(String name, SDVariable in) { + SDValidation.validateNumerical("trace", "in", in); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd,in).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable xor(SDVariable x, SDVariable y) { + SDValidation.validateBool("xor", "x", x); + SDValidation.validateBool("xor", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd,x, y).outputVariable(); + } + + /** + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * + * @param name name May be null. Name for the output variable + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + */ + public SDVariable xor(String name, SDVariable x, SDVariable y) { + SDValidation.validateBool("xor", "x", x); + SDValidation.validateBool("xor", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
+ * + * @param input Input variable (NUMERIC type) + * @return output Reduced array of rank 0 (scalar) (NUMERIC type) + */ + public SDVariable zeroFraction(SDVariable input) { + SDValidation.validateNumerical("zeroFraction", "input", input); + return new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd,input).outputVariable(); + } + + /** + * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @return output Reduced array of rank 0 (scalar) (NUMERIC type) + */ + public SDVariable zeroFraction(String name, SDVariable input) { + SDValidation.validateNumerical("zeroFraction", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd,input).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 63aab3f33..7b18c3614 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,1054 +14,1139 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import lombok.NonNull; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; -import org.nd4j.linalg.api.ops.impl.transforms.Pad; -import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.base.Preconditions; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateFloatingPoint; - -/** - * SameDiff general neural network operations
- * Accessible via {@link SameDiff#math()}
- * See also {@link SDCNN} (accessible via {@link SameDiff#cnn()} for convolutional neural network ops.
- * See also {@link SDRNN} (accessible via {@link SameDiff#rnn()} for recurrent neural network ops.
- * - * @author Alex Black - */ public class SDNN extends SDOps { - public SDNN(SameDiff sameDiff) { - super(sameDiff); - } - - /** - * Batch norm operation. - * - * @see #batchNorm(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, double, int...) - */ - public SDVariable batchNorm(SDVariable input, SDVariable mean, - SDVariable variance, SDVariable gamma, - SDVariable beta, double epsilon, int... axis) { - return batchNorm(null, input, mean, variance, gamma, beta, true, true, epsilon, axis); - } - - /** - * Batch normalization with optional application of gamma/beta args. - * See {@link #batchNorm(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, double, int...)} - */ - public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, - SDVariable variance, SDVariable gamma, - SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon, int... axis) { - validateFloatingPoint("batchNorm", "input", input); - validateFloatingPoint("batchNorm", "mean", mean); - validateFloatingPoint("batchNorm", "variance", variance); - validateFloatingPoint("batchNorm", "gamma", gamma); - validateFloatingPoint("batchNorm", "beta", beta); - SDVariable res = f().batchNorm(input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon, axis); - return updateVariableNameAndReference(res, name); - } - - /** - * Neural network batch normalization operation.
- * For details, see https://arxiv.org/abs/1502.03167 - * - * @param name Name of the output variable - * @param input Input variable. - * @param mean Mean value. For 1d axis, this should match input.size(axis) - * @param variance Variance value. For 1d axis, this should match input.size(axis) - * @param gamma Gamma value. For 1d axis, this should match input.size(axis) - * @param beta Beta value. For 1d axis, this should match input.size(axis) - * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) - * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations.
- * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC
- * For 1d/RNN activations: 1 for NCW format, 2 for NWC - * @return Output variable for batch normalization - */ - public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, - SDVariable variance, SDVariable gamma, - SDVariable beta, double epsilon, int... axis) { - return batchNorm(name, input, mean, variance, gamma, beta, true, true, epsilon, axis); - } - - /** - * @see #biasAdd(String, SDVariable, SDVariable, boolean) - */ - public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { - return biasAdd(null, input, bias, nchw); - } - - /** - * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector - * - * @param name Name of the output variable - * @param input 4d input variable - * @param bias 1d bias - * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. - * Unused for 2d inputs - * @return Output variable - */ - public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) { - validateFloatingPoint("biasAdd", "input", input); - validateFloatingPoint("biasAdd", "bias", bias); - SDVariable ret = f().biasAdd(input, bias, nchw); - return updateVariableNameAndReference(ret, name); - } - - /** - * @param input Input - * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) - * @return - */ - public SDVariable dropout(SDVariable input, double inputRetainProbability) { - return dropout(null, input, inputRetainProbability); - } - - /** - * @param input Input - * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) - * @return - */ - public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) { - validateFloatingPoint("dropout", input); - SDVariable res = f().dropout(input, inputRetainProbability); - return updateVariableNameAndReference(res, name); - } - - /** - * Element-wise exponential linear unit (ELU) function:
- * out = x if x > 0
- * out = a * (exp(x) - 1) if x <= 0
- * with constant a = 1.0 - *

- * See: https://arxiv.org/abs/1511.07289 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable elu(SDVariable x) { - return elu(null, x); - } - - /** - * Element-wise exponential linear unit (ELU) function:
- * out = x if x > 0
- * out = a * (exp(x) - 1) if x <= 0
- * with constant a = 1.0 - *

- * See: https://arxiv.org/abs/1511.07289 - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable elu(String name, SDVariable x) { - validateFloatingPoint("elu", x); - SDVariable result = f().elu(x); - return updateVariableNameAndReference(result, name); - } - - /** - * GELU activation function - Gaussian Error Linear Units
- * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415 - * This method uses the sigmoid approximation - * - * @param x Input - * @return Output variable - GELU applied to the input - */ - public SDVariable gelu(SDVariable x) { - return gelu(null, x); - } - - /** - * GELU activation function - Gaussian Error Linear Units
- * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415 - * This method uses the sigmoid approximation - * - * @param name Name of the output variable. May be null. - * @param x Input - * @return Output variable - GELU applied to the input - */ - public SDVariable gelu(String name, SDVariable x) { - validateFloatingPoint("gelu", x); - SDVariable ret = f().gelu(x, false); //Defaults to si - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise hard sigmoid function:
- * out[i] = 0 if in[i] <= -2.5
- * out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
- * out[i] = 1 if in[i] >= 2.5
- * - * @param in Input variable - * @return Output variable - */ - public SDVariable hardSigmoid(SDVariable in) { - return hardSigmoid(null, in); - } - - /** - * Element-wise hard sigmoid function:
- * out[i] = 0 if in[i] <= -2.5
- * out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
- * out[i] = 1 if in[i] >= 2.5
- * - * @param name Name of the output variable - * @param in Input variable - * @return Output variable - */ - public SDVariable hardSigmoid(String name, SDVariable in) { - validateFloatingPoint("hard sigmoid", in); - SDVariable ret = f().hardSigmoid(in); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise hard tanh function:
- * out[i] = -1 if in[i] <= -1
- * out[1] = in[i] if -1 < in[i] < 1
- * out[i] = 1 if in[i] >= 1
- * - * @param in Input variable - * @return Output variable - */ - public SDVariable hardTanh(SDVariable in) { - return hardTanh(null, in); - } - - /** - * Element-wise hard tanh function:
- * out[i] = -1 if in[i] <= -1
- * out[1] = in[i] if -1 < in[i] < 1
- * out[i] = 1 if in[i] >= 1
- * - * @param name Output variable name - * @param in Input variable - * @return Output variable - */ - public SDVariable hardTanh(String name, SDVariable in) { - validateFloatingPoint("hard Tanh", in); - SDVariable result = f().hardTanh(in); - return updateVariableNameAndReference(result, name); - } - - /** - * Derivative (dOut/dIn) of the element-wise hard Tanh function - {@link #hardTanh(SDVariable)} - * - * @param x Input - * @return Output variable - */ - public SDVariable hardTanhDerivative(SDVariable x) { - return hardTanhDerivative(null, x); - } - - /** - * Derivative (dOut/dIn) of the element-wise hard Tanh function - {@link #hardTanh(SDVariable)} - * - * @param name Output variable name - * @param x Input - * @return Output variable - */ - public SDVariable hardTanhDerivative(String name, SDVariable x) { - validateFloatingPoint("hard Tanh derivative", x); - SDVariable result = f().hardTanhDerivative(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise leaky ReLU function:
- * out = x if x >= 0.0
- * out = alpha * x if x < cutoff
- * Alpha value is most commonly set to 0.01 - * - * @param x Input variable - * @param alpha Cutoff - usually 0.0 - * @return Output variable - */ - public SDVariable leakyRelu(SDVariable x, double alpha) { - return leakyRelu(null, x, alpha); - } - - /** - * Element-wise leaky ReLU function:
- * out = x if x >= 0.0
- * out = alpha * x if x < cutoff
- * Alpha value is most commonly set to 0.01 - * - * @param x Input variable - * @param alpha Cutoff - usually 0.0 - * @return Output variable - */ - public SDVariable leakyRelu(String name, SDVariable x, double alpha) { - validateFloatingPoint("leaky ReLU", x); - SDVariable result = f().leakyRelu(x, alpha); - return updateVariableNameAndReference(result, name); - } - - /** - * Leaky ReLU derivative: dOut/dIn given input.
- * See {@link #leakyRelu(String, SDVariable, double)} - * - * @param x Input variable - * @param alpha Alpha value - * @return Output variable - */ - public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) { - validateFloatingPoint("leaky ReLU derivative", x); - SDVariable result = f().leakyReluDerivative(x, alpha); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #linear(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) { - return linear(null, input, weights, bias); - } - - /** - * Linear layer operation: out = mmul(in,w) + bias
- * Note that bias array is optional - * - * @param name Name of the output variable - * @param input Input data - * @param weights Weights variable - * @param bias Optional bias variable (may be null) - * @return Output variable - */ - public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) { - validateFloatingPoint("linear", "input", input); - validateFloatingPoint("linear", "weights", weights); - validateFloatingPoint("linear", "bias", bias); - SDVariable res = f().xwPlusB(input, weights, bias); - return updateVariableNameAndReference(res, name); - } - - /** - * Element-wise sigmoid function: out[i] = log(sigmoid(in[i])) - * - * @param x Input Variable - * @return Output variable - */ - public SDVariable logSigmoid(SDVariable x) { - return logSigmoid(null, x); - } - - /** - * Element-wise sigmoid function: out[i] = log(sigmoid(in[i])) - * - * @param name Name of the output variable - * @param x Input Variable - * @return Output variable - */ - public SDVariable logSigmoid(String name, SDVariable x) { - validateFloatingPoint("log sigmoid", x); - SDVariable ret = f().logSigmoid(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Log softmax activation - * - * @param x Input variable - * @return Output variable - */ - public SDVariable logSoftmax(SDVariable x) { - return logSoftmax(null, x); - } - - /** - * Log softmax activation - * - * @param name Variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable logSoftmax(String name, SDVariable x) { - validateFloatingPoint("log softmax", x); - SDVariable ret = f().logSoftmax(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Log softmax activation - * - * @param x Input variable - * @return Output variable - */ - public SDVariable logSoftmax(SDVariable x, int dimension) { - return logSoftmax(null, x, dimension); - } - - /** - * Log softmax activation - * - * @param name Variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable logSoftmax(String name, SDVariable x, int dimension) { - validateFloatingPoint("log softmax", x); - SDVariable ret = f().logSoftmax(x, dimension); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise rectified linear function with specified cutoff:
- * out[i] = in[i] if in[i] >= cutoff - * out[i] = 0 otherwise - * - * @param x Input variable - * @param cutoff Cutoff value. Usually 0 - * @return Output variable - */ - public SDVariable relu(SDVariable x, double cutoff) { - return relu(null, x, cutoff); - } - - /** - * Element-wise rectified linear function with specified cutoff:
- * out[i] = in[i] if in[i] >= cutoff - * out[i] = 0 otherwise - * - * @param name Output variable name - * @param x Input variable - * @param cutoff Cutoff value. Usually 0 - * @return Output variable - */ - public SDVariable relu(String name, SDVariable x, double cutoff) { - validateFloatingPoint("ReLU", x); - SDVariable result = f().relu(x, cutoff); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise "rectified linear 6" function with specified cutoff:
- * out[i] = min(max(in, cutoff), 6) - * - * @param x Input variable - * @param cutoff Cutoff value. Usually 0 - * @return Output variable - */ - public SDVariable relu6(SDVariable x, double cutoff) { - return relu6(null, x, cutoff); - } - - /** - * Element-wise "rectified linear 6" function with specified cutoff:
- * out[i] = min(max(in, cutoff), 6) - * - * @param name Output variable name - * @param x Input variable - * @param cutoff Cutoff value. Usually 0 - * @return Output variable - */ - public SDVariable relu6(String name, SDVariable x, double cutoff) { - validateFloatingPoint("ReLU6", x); - SDVariable result = f().relu6(x, cutoff); - return updateVariableNameAndReference(result, name); - } - - /** - * @see #reluLayer(String, SDVariable, SDVariable, SDVariable) - */ - public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { - return reluLayer(null, input, weights, bias); - } - - /** - * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
- * Note that bias array is optional - * - * @param name Name of the output variable - * @param input Input data - * @param weights Weights variable - * @param bias Optional bias variable (may be null) - * @return Output variable - */ - public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) { - validateFloatingPoint("reluLayer", "input", input); - validateFloatingPoint("reluLayer", "weights", weights); - validateFloatingPoint("reluLayer", "bias", bias); - SDVariable res = f().reluLayer(input, weights, bias); - return updateVariableNameAndReference(res, name); - } - - /** - * See {@link #prelu(String, SDVariable, SDVariable, int...)}. - */ - public SDVariable prelu(@NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){ - return f().prelu(input, alpha, sharedAxes); - } - - /** - * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
- * out[i] = in[i] if in[i] >= 0
- * out[i] = in[i] * alpha[i] otherwise
- * - * sharedAxes allows you to share learnable parameters along axes. - * For example, if the input has shape [batchSize, channels, height, width] - * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an - * alpha with shape [channels]. - * - * @param name Name of the output variable - * @param input Input data - * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. - * @param sharedAxes Which axes to share cutoff parameters along. - * @return Output variable - */ - public SDVariable prelu(String name, @NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){ - SDVariable res = f().prelu(input, alpha, sharedAxes); - return updateVariableNameAndReference(res, name); - } - - /** - * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks - *
- * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
- * Uses default lcale and alpha values. - * - * @param x Input variable - * @return Output variable - */ - public SDVariable selu(SDVariable x) { - return selu(null, x); - } - - /** - * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks - *
- * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
- * Uses default lcale and alpha values. - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable selu(String name, SDVariable x) { - validateFloatingPoint("selu", x); - SDVariable ret = f().selu(x); - return updateVariableNameAndReference(ret, name); - } - - /** - * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i])) - * - * @param x Input Variable - * @return Output variable - */ - public SDVariable sigmoid(SDVariable x) { - return sigmoid(null, x); - } - - /** - * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i])) - * - * @param name Output variable name - * @param x Input Variable - * @return Output variable - */ - public SDVariable sigmoid(String name, SDVariable x) { - validateFloatingPoint("sigmoid", x); - SDVariable result = f().sigmoid(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut - * - * @param x Input Variable - * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input - * @return Output variable - */ - public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) { - return sigmoidDerivative(null, x, wrt); - } - - /** - * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut - * - * @param name Output variable name - * @param x Input Variable - * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input - * @return Output variable - */ - public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) { - validateFloatingPoint("sigmoidDerivative", x); - SDVariable result = f().sigmoidDerivative(x, wrt); - return updateVariableNameAndReference(result, name); - } - - /** - * Softmax activation on dimension 1. - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softmax(SDVariable x) { - return softmax(null, x); - } - - /** - * Softmax activation on dimension 1. - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softmax(String name, SDVariable x) { - validateFloatingPoint("softmax", x); - SDVariable result = f().softmax(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Softmax activation - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softmax(SDVariable x, int dimension) { - return softmax(null, x, dimension); - } - - /** - * Softmax activation - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softmax(String name, SDVariable x, int dimension) { - validateFloatingPoint("softmax", x); - SDVariable result = f().softmax(x, dimension); - return updateVariableNameAndReference(result, name); - } - - /** - * @param x - * @return - */ - public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt) { - return softmaxDerivative(name, x, wrt, null); - } - - public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, Integer dimension) { - validateFloatingPoint("softmaxDerivative", x); - SDVariable result = f().softmaxDerivative(x, wrt, dimension); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise softplus function: out = log(exp(x) + 1) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softplus(SDVariable x) { - return softplus(null, x); - } - - /** - * Element-wise softplus function: out = log(exp(x) + 1) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable softplus(String name, SDVariable x) { - validateFloatingPoint("softplus", x); - SDVariable result = f().softplus(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise softsign function: out = x / (abs(x) + 1) - * - * @param x Input variable - * @return Output variable - */ - public SDVariable softsign(SDVariable x) { - return softsign(null, x); - } - - /** - * Element-wise softsign function: out = x / (abs(x) + 1) - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable softsign(String name, SDVariable x) { - validateFloatingPoint("softsign", x); - SDVariable result = f().softsign(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise derivative (dOut/dIn) of the softsign function {@link #softsign(SDVariable)} - * - * @param x Input variable - * @return Output varible - */ - public SDVariable softsignDerivative(SDVariable x) { - return softsignDerivative(null, x); - } - - /** - * Element-wise derivative (dOut/dIn) of the softsign function {@link #softsign(SDVariable)} - * - * @param name Output variable name - * @param x Input variable - * @return Output varible - */ - public SDVariable softsignDerivative(String name, SDVariable x) { - validateFloatingPoint("softsignDerivative", x); - SDVariable result = f().softsignDerivative(x); - return updateVariableNameAndReference(result, name); - } - - /** - * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
- * See: https://arxiv.org/abs/1710.05941 - * - * @param x Input variable - * @return Output variable - */ - public SDVariable swish(SDVariable x) { - return swish(null, x); - } - - /** - * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
- * See: https://arxiv.org/abs/1710.05941 - * - * @param name Name of the output variable - * @param x Input variable - * @return Output variable - */ - public SDVariable swish(String name, SDVariable x) { - validateFloatingPoint("swish", x); - SDVariable ret = f().swish(x); - return updateVariableNameAndReference(ret, name); - } - - public SDVariable tanh(String name, SDVariable x) { - return sd.math().tanh(name, x); - } - - public SDVariable tanh(SDVariable x) { - return sd.math().tanh(x); - } - - /** - * Apply Layer Normalization - * - * y = gain * standardize(x) + bias - * - * @return Output variable - */ - public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) { - return layerNorm(null, input, gain, bias, channelsFirst, dimensions); - } - - /** - * Apply Layer Normalization - * - * y = gain * standardize(x) + bias - * - * @param name Name of the output variable - * @param input Input variable - * @param gain gain - * @param bias bias - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs - * @return Output variable - */ - public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) { - validateFloatingPoint("layerNorm", "input", input); - validateFloatingPoint("layerNorm", "gain", gain); - validateFloatingPoint("layerNorm", "bias", bias); - SDVariable result = f().layerNorm(input, gain, bias, channelsFirst, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * Apply Layer Normalization without bias - * - * y = gain * standardize(x) - * - * @return Output variable - */ - public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { - return layerNorm((String)null, input, gain, channelsFirst, dimensions); - } - - /** - * Apply Layer Normalization - * - * y = gain * standardize(x) - * - * @param name Name of the output variable - * @param input Input variable - * @param gain gain - * @return Output variable - */ - public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { - validateFloatingPoint("layerNorm", "input", input); - validateFloatingPoint("layerNorm", "gain", gain); - SDVariable result = f().layerNorm(input, gain, channelsFirst, dimensions); - return updateVariableNameAndReference(result, name); - } - - /** - * See {@link #pad(SDVariable, SDVariable, double)} - */ - public SDVariable pad(SDVariable input, int[][] padding, double constant){ - return pad(input, sd.constant(Nd4j.createFromArray(padding)), constant); - } - - /** - * Perform padding on the given array, where padded values are the specified constant.
- * Example:
- * Input array:
- * [1, 2]
- * [3, 4]
- * Padding array:
- * [2, 0]
- * [1, 1]
- * Contant = 0
- * Result:
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
- * [0, 1, 2, 0]
- * [0, 3, 4, 0]
- *
- * - * - * @param input Input array to pad - * @param padding Padding array - * @param constant Constant to use for padded values - * @return Padded array - */ - public SDVariable pad(SDVariable input, SDVariable padding, double constant){ - return pad(null, input, padding, Pad.Mode.CONSTANT, constant); - } - - /** - * As per {@link #pad(SDVariable, SDVariable, double)} but also supports multiple {@link Pad.Mode} modes.
- * Example: - * Input array:
- * [1, 2]
- * [3, 4]
- * [5, 6]
- * Padding array:
- * [2, 0]
- * [1, 1]
- * Contant = 0
- * Result: CONSTANT mode
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
- * [0, 1, 2, 0]
- * [0, 3, 4, 0]
- * [0, 5, 6, 0]
- *
- * Result: SYMMETRIC mode
- * [3, 3, 4, 4]
- * [1, 1, 2, 2]
- * [1, 1, 2, 2]
- * [3, 3, 4, 4]
- * [5, 5, 6, 6]
- *
- * Result: REFLECT:
- * [6, 5, 6, 0]
- * [2, 3, 4, 3]
- * [2, 1, 2, 1]
- * [4, 3, 4, 3]
- * [6, 5, 6, 5]
- *
- * @param outputName - * @param input - * @param padding - * @param mode - * @param constant - * @return - */ - public SDVariable pad(String outputName, SDVariable input, SDVariable padding, Pad.Mode mode, double constant){ - SDVariable out = f().pad(input, padding, mode, constant); - return updateVariableNameAndReference(out, outputName); - } - - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * @see #dotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled){ - return dotProductAttention(null, queries, keys, values, mask, scaled); - } - - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * @see #dotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public SDVariable dotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled){ - final SDVariable result = f().dotProductAttention(queries, keys, values, mask, scaled); - return updateVariableNameAndReference(result, name); - } - - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * @see #dotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public List dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights){ - return dotProductAttention(null, queries, keys, values, mask, scaled, withWeights); - } - - - /** - * This operation performs dot product attention on the given timeseries input with the given queries - * out = sum(similarity(k_i, q) * v_i) - * - * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q - * - * Optionally with normalization step: - * similarity(k, q) = softmax(k * q / sqrt(size(q)) - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1) - * - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to - * be 3D but can have queryCount = 1 - * - * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for - * both. - * - * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The - * output rank will depend on the input rank. - * - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] - * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @param withWeights return attention weights as well, false -> only one output, true -> two outputs - * - * Output Arrays: - * @return [ Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], - * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount]] - */ - public List dotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights){ - List result = f().dotProductAttention(queries, keys, values, mask, scaled, withWeights); - if(withWeights){ - return Collections.singletonList(updateVariableNameAndReference(result.get(0), name)); - }else{ - return Arrays.asList( - updateVariableNameAndReference(result.get(0), name), - updateVariableNameAndReference(result.get(1), name+":weights") - ); - } - } - - /** - * This performs multi-headed dot product attention on the given timeseries input - * @see #multiHeadDotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled){ - return multiHeadDotProductAttention(null, queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled); - } - - /** - * This performs multi-headed dot product attention on the given timeseries input - * @see #multiHeadDotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public SDVariable multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled){ - final SDVariable result = f().multiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled); - return updateVariableNameAndReference(result, name); - } - - /** - * This performs multi-headed dot product attention on the given timeseries input - * @see #multiHeadDotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - */ - public List multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights){ - return multiHeadDotProductAttention(null, queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights); - } - - - /** - * This performs multi-headed dot product attention on the given timeseries input - * out = concat(head_1, head_2, ..., head_n) * Wo - * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v) - * - * Optionally with normalization when calculating the attention for each head. - * - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention") - * - * This makes use of dot_product_attention OP support for rank 4 inputs. - * @see #dotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean) - * - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] - * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] - * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] - * @param Wv: input value projection weights of shape [numHeads, projectedValues, featureValues] - * @param Wo: output projection weights of shape [numHeads * projectedValues, outSize] - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @param withWeights return attention weights as well, false -> only one output, true -> two outputs - * - * Output Arrays: - * @return [ Attention result arrays of shape [batchSize, outSize, queryCount] - * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] - */ - public List multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights){ - List result = f().multiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights); - if(withWeights){ - return Collections.singletonList(updateVariableNameAndReference(result.get(0), name)); - }else{ - return Arrays.asList( - updateVariableNameAndReference(result.get(0), name), - updateVariableNameAndReference(result.get(1), name+":weights") - ); - } - } - - /** - * Max pooling on the input and outputs both max values and indices - * - * @param name Name of the output variable - * @param x input array - * @return output array and argmax array - */ - public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable x, Pooling2DConfig pooling2DConfig) { - SDVariable[] res = f().maxPoolWithArgmaxs(x, pooling2DConfig); - return sd.updateVariableNamesAndReferences(res, names); - } - - /** - * Batch normalization - * - * @param name Name of the output variable - * @param x 4D array - * @param scale vector for scaling factor of normalized x - * @param offset vector to shift to the normalized x - * @param dataFormat integer scalar - data format - * @param isTraining boolean scalar - is training mode - * @return y: 4D array - * batch_mean: vector - * batch_var: vector - */ - public SDVariable[] fusedBatchNorm(String[] names, SDVariable x, SDVariable scale, SDVariable offset, - SDVariable dataFormat, SDVariable isTraining) { - SDVariable[] res = f().fusedBatchNorm(x,scale,offset,dataFormat,isTraining); - return sd.updateVariableNamesAndReferences(res, names); - } + public SDNN(SameDiff sameDiff) { + super(sameDiff); + } + + /** + * Neural network batch normalization operation.
+ * For details, see https://arxiv.org/abs/1502.03167
+ * + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. + * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC + * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @return output variable for batch normalization (NUMERIC type) + */ + public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, + SDVariable gamma, SDVariable beta, double epsilon, int... axis) { + SDValidation.validateNumerical("batchNorm", "input", input); + SDValidation.validateNumerical("batchNorm", "mean", mean); + SDValidation.validateNumerical("batchNorm", "variance", variance); + SDValidation.validateNumerical("batchNorm", "gamma", gamma); + SDValidation.validateNumerical("batchNorm", "beta", beta); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd,input, mean, variance, gamma, beta, epsilon, axis).outputVariable(); + } + + /** + * Neural network batch normalization operation.
+ * For details, see https://arxiv.org/abs/1502.03167
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. + * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC + * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @return output variable for batch normalization (NUMERIC type) + */ + public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, + SDVariable gamma, SDVariable beta, double epsilon, int... axis) { + SDValidation.validateNumerical("batchNorm", "input", input); + SDValidation.validateNumerical("batchNorm", "mean", mean); + SDValidation.validateNumerical("batchNorm", "variance", variance); + SDValidation.validateNumerical("batchNorm", "gamma", gamma); + SDValidation.validateNumerical("batchNorm", "beta", beta); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd,input, mean, variance, gamma, beta, epsilon, axis).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
+ * + * @param input 4d input variable (NUMERIC type) + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. + * Unused for 2d inputs + * @return output Output variable, after applying bias add operation (NUMERIC type) + */ + public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { + SDValidation.validateNumerical("biasAdd", "input", input); + SDValidation.validateNumerical("biasAdd", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd,input, bias, nchw).outputVariable(); + } + + /** + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
+ * + * @param name name May be null. Name for the output variable + * @param input 4d input variable (NUMERIC type) + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. + * Unused for 2d inputs + * @return output Output variable, after applying bias add operation (NUMERIC type) + */ + public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) { + SDValidation.validateNumerical("biasAdd", "input", input); + SDValidation.validateNumerical("biasAdd", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd,input, bias, nchw).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * This operation performs dot product attention on the given timeseries input with the given queries
+ * out = sum(similarity(k_i, q) * v_i)
+ *
+ * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
+ *
+ * Optionally with normalization step:
+ * similarity(k, q) = softmax(k * q / sqrt(size(q))
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ *
+ * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
+ * be 3D but can have queryCount = 1
+ *
+ * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
+ * both.
+ *
+ * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
+ * output rank will depend on the input rank.
+ * + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] + * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], + * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, + SDVariable mask, boolean scaled) { + SDValidation.validateNumerical("dotProductAttention", "queries", queries); + SDValidation.validateNumerical("dotProductAttention", "keys", keys); + SDValidation.validateNumerical("dotProductAttention", "values", values); + SDValidation.validateNumerical("dotProductAttention", "mask", mask); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd,queries, keys, values, mask, scaled, false).outputVariable(); + } + + /** + * This operation performs dot product attention on the given timeseries input with the given queries
+ * out = sum(similarity(k_i, q) * v_i)
+ *
+ * similarity(k, q) = softmax(k * q) where x * q is the dot product of x and q
+ *
+ * Optionally with normalization step:
+ * similarity(k, q) = softmax(k * q / sqrt(size(q))
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
+ *
+ * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
+ * be 3D but can have queryCount = 1
+ *
+ * Note: keys and values usually is the same array. If you want to use it as the same array, simply pass it for
+ * both.
+ *
+ * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
+ * output rank will depend on the input rank.
+ * + * @param name name May be null. Name for the output variable + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] + * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] + * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] + * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], + * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public SDVariable dotProductAttention(String name, SDVariable queries, SDVariable keys, + SDVariable values, SDVariable mask, boolean scaled) { + SDValidation.validateNumerical("dotProductAttention", "queries", queries); + SDValidation.validateNumerical("dotProductAttention", "keys", keys); + SDValidation.validateNumerical("dotProductAttention", "values", values); + SDValidation.validateNumerical("dotProductAttention", "mask", mask); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd,queries, keys, values, mask, scaled, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Dropout operation
+ * + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @return output Output (NUMERIC type) + */ + public SDVariable dropout(SDVariable input, double inputRetainProbability) { + SDValidation.validateNumerical("dropout", "input", input); + return new org.nd4j.linalg.api.ops.random.impl.DropOut(sd,input, inputRetainProbability).outputVariable(); + } + + /** + * Dropout operation
+ * + * @param name name May be null. Name for the output variable + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @return output Output (NUMERIC type) + */ + public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) { + SDValidation.validateNumerical("dropout", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.DropOut(sd,input, inputRetainProbability).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise exponential linear unit (ELU) function:
+ * out = x if x > 0
+ * out = a * (exp(x) - 1) if x <= 0
+ * with constant a = 1.0
+ *


+ * See: https://arxiv.org/abs/1511.07289
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable elu(SDVariable x) { + SDValidation.validateNumerical("elu", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd,x).outputVariable(); + } + + /** + * Element-wise exponential linear unit (ELU) function:
+ * out = x if x > 0
+ * out = a * (exp(x) - 1) if x <= 0
+ * with constant a = 1.0
+ *


+ * See: https://arxiv.org/abs/1511.07289
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable elu(String name, SDVariable x) { + SDValidation.validateNumerical("elu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the sigmoid approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable gelu(SDVariable x) { + SDValidation.validateNumerical("gelu", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd,x).outputVariable(); + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the sigmoid approximation
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable gelu(String name, SDVariable x) { + SDValidation.validateNumerical("gelu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise hard sigmoid function:
+ * out[i] = 0 if in[i] <= -2.5
+ * out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
+ * out[i] = 1 if in[i] >= 2.5
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardSigmoid(SDVariable x) { + SDValidation.validateNumerical("hardSigmoid", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd,x).outputVariable(); + } + + /** + * Element-wise hard sigmoid function:
+ * out[i] = 0 if in[i] <= -2.5
+ * out[1] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5
+ * out[i] = 1 if in[i] >= 2.5
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardSigmoid(String name, SDVariable x) { + SDValidation.validateNumerical("hardSigmoid", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise hard tanh function:
+ * out[i] = -1 if in[i] <= -1
+ * out[1] = in[i] if -1 < in[i] < 1
+ * out[i] = 1 if in[i] >= 1
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardTanh(SDVariable x) { + SDValidation.validateNumerical("hardTanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd,x).outputVariable(); + } + + /** + * Element-wise hard tanh function:
+ * out[i] = -1 if in[i] <= -1
+ * out[1] = in[i] if -1 < in[i] < 1
+ * out[i] = 1 if in[i] >= 1
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardTanh(String name, SDVariable x) { + SDValidation.validateNumerical("hardTanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardTanhDerivative(SDVariable x) { + SDValidation.validateNumerical("hardTanhDerivative", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd,x).outputVariable(); + } + + /** + * Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable hardTanhDerivative(String name, SDVariable x) { + SDValidation.validateNumerical("hardTanhDerivative", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, + boolean channelsFirst, int... dimensions) { + SDValidation.validateNumerical("layerNorm", "input", input); + SDValidation.validateNumerical("layerNorm", "gain", gain); + SDValidation.validateNumerical("layerNorm", "bias", bias); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, bias, channelsFirst, dimensions).outputVariable(); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, + boolean channelsFirst, int... dimensions) { + SDValidation.validateNumerical("layerNorm", "input", input); + SDValidation.validateNumerical("layerNorm", "gain", gain); + SDValidation.validateNumerical("layerNorm", "bias", bias); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, bias, channelsFirst, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, + int... dimensions) { + SDValidation.validateNumerical("layerNorm", "input", input); + SDValidation.validateNumerical("layerNorm", "gain", gain); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, null, channelsFirst, dimensions).outputVariable(); + } + + /** + * Apply Layer Normalization
+ *
+ * y = gain * standardize(x) + bias
+ * + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, + int... dimensions) { + SDValidation.validateNumerical("layerNorm", "input", input); + SDValidation.validateNumerical("layerNorm", "gain", gain); + Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, null, channelsFirst, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise leaky ReLU function:
+ * out = x if x >= 0.0
+ * out = alpha * x if x < cutoff
+ * Alpha value is most commonly set to 0.01
+ * + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 + * @return output Output variable (NUMERIC type) + */ + public SDVariable leakyRelu(SDVariable x, double alpha) { + SDValidation.validateNumerical("leakyRelu", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd,x, alpha).outputVariable(); + } + + /** + * Element-wise leaky ReLU function:
+ * out = x if x >= 0.0
+ * out = alpha * x if x < cutoff
+ * Alpha value is most commonly set to 0.01
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 + * @return output Output variable (NUMERIC type) + */ + public SDVariable leakyRelu(String name, SDVariable x, double alpha) { + SDValidation.validateNumerical("leakyRelu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd,x, alpha).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Leaky ReLU derivative: dOut/dIn given input.
+ * + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 + * @return output Output variable (NUMERIC type) + */ + public SDVariable leakyReluDerivative(SDVariable x, double alpha) { + SDValidation.validateNumerical("leakyReluDerivative", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd,x, alpha).outputVariable(); + } + + /** + * Leaky ReLU derivative: dOut/dIn given input.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param alpha Cutoff - commonly 0.01 + * @return output Output variable (NUMERIC type) + */ + public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) { + SDValidation.validateNumerical("leakyReluDerivative", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd,x, alpha).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Linear layer operation: out = mmul(in,w) + bias
+ * Note that bias array is optional
+ * + * @param input Input data (NUMERIC type) + * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) { + SDValidation.validateNumerical("linear", "input", input); + SDValidation.validateNumerical("linear", "weights", weights); + SDValidation.validateNumerical("linear", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd,input, weights, bias).outputVariable(); + } + + /** + * Linear layer operation: out = mmul(in,w) + bias
+ * Note that bias array is optional
+ * + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) + * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) { + SDValidation.validateNumerical("linear", "input", input); + SDValidation.validateNumerical("linear", "weights", weights); + SDValidation.validateNumerical("linear", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd,input, weights, bias).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise sigmoid function: out[i] = log(sigmoid(in[i]))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable logSigmoid(SDVariable x) { + SDValidation.validateNumerical("logSigmoid", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd,x).outputVariable(); + } + + /** + * Element-wise sigmoid function: out[i] = log(sigmoid(in[i]))
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable logSigmoid(String name, SDVariable x) { + SDValidation.validateNumerical("logSigmoid", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Log softmax activation
+ * + * @param x (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable logSoftmax(SDVariable x) { + SDValidation.validateNumerical("logSoftmax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x).outputVariable(); + } + + /** + * Log softmax activation
+ * + * @param name name May be null. Name for the output variable + * @param x (NUMERIC type) + * @return output (NUMERIC type) + */ + public SDVariable logSoftmax(String name, SDVariable x) { + SDValidation.validateNumerical("logSoftmax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Log softmax activation
+ * + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply log softmax + * @return output Output - log(softmax(input)) (NUMERIC type) + */ + public SDVariable logSoftmax(SDVariable x, int dimension) { + SDValidation.validateNumerical("logSoftmax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x, dimension).outputVariable(); + } + + /** + * Log softmax activation
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply log softmax + * @return output Output - log(softmax(input)) (NUMERIC type) + */ + public SDVariable logSoftmax(String name, SDVariable x, int dimension) { + SDValidation.validateNumerical("logSoftmax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * This performs multi-headed dot product attention on the given timeseries input
+ * out = concat(head_1, head_2, ..., head_n) * Wo
+ * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
+ *
+ * Optionally with normalization when calculating the attention for each head.
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
+ *
+ * This makes use of dot_product_attention OP support for rank 4 inputs.
+ * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
+ * + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type) + * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type) + * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, outSize, queryCount] + * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, + SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, + SDVariable mask, boolean scaled) { + SDValidation.validateNumerical("multiHeadDotProductAttention", "queries", queries); + SDValidation.validateNumerical("multiHeadDotProductAttention", "keys", keys); + SDValidation.validateNumerical("multiHeadDotProductAttention", "values", values); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", Wq); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", Wk); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); + SDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd,queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); + } + + /** + * This performs multi-headed dot product attention on the given timeseries input
+ * out = concat(head_1, head_2, ..., head_n) * Wo
+ * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
+ *
+ * Optionally with normalization when calculating the attention for each head.
+ *
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
+ *
+ * This makes use of dot_product_attention OP support for rank 4 inputs.
+ * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
+ * + * @param name name May be null. Name for the output variable + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type) + * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) + * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type) + * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, outSize, queryCount] + * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + */ + public SDVariable multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, + SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, + SDVariable mask, boolean scaled) { + SDValidation.validateNumerical("multiHeadDotProductAttention", "queries", queries); + SDValidation.validateNumerical("multiHeadDotProductAttention", "keys", keys); + SDValidation.validateNumerical("multiHeadDotProductAttention", "values", values); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", Wq); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", Wk); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); + SDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); + SDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd,queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Padding operation
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(SDVariable input, SDVariable padding, double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + } + + /** + * Padding operation
+ * + * @param name name May be null. Name for the output variable + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(String name, SDVariable input, SDVariable padding, double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
+ * out[i] = in[i] if in[i] >= 0
+ * out[i] = in[i] * alpha[i] otherwise
+ *
+ * sharedAxes allows you to share learnable parameters along axes.
+ * For example, if the input has shape [batchSize, channels, height, width]
+ * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
+ * alpha with shape [channels].
+ * + * @param input Input data (NUMERIC type) + * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type) + * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1)) + * @return output Output (NUMERIC type) + */ + public SDVariable prelu(SDVariable input, SDVariable alpha, int... sharedAxes) { + SDValidation.validateNumerical("prelu", "input", input); + SDValidation.validateNumerical("prelu", "alpha", alpha); + Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); + return new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd,input, alpha, sharedAxes).outputVariable(); + } + + /** + * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
+ * out[i] = in[i] if in[i] >= 0
+ * out[i] = in[i] * alpha[i] otherwise
+ *
+ * sharedAxes allows you to share learnable parameters along axes.
+ * For example, if the input has shape [batchSize, channels, height, width]
+ * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
+ * alpha with shape [channels].
+ * + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) + * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type) + * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1)) + * @return output Output (NUMERIC type) + */ + public SDVariable prelu(String name, SDVariable input, SDVariable alpha, int... sharedAxes) { + SDValidation.validateNumerical("prelu", "input", input); + SDValidation.validateNumerical("prelu", "alpha", alpha); + Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd,input, alpha, sharedAxes).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise rectified linear function with specified cutoff:
+ * out[i] = in[i] if in[i] >= cutoff
+ * out[i] = 0 otherwise
+ * + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 + * @return output Output (NUMERIC type) + */ + public SDVariable relu(SDVariable x, double cutoff) { + SDValidation.validateNumerical("relu", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd,x, cutoff).outputVariable(); + } + + /** + * Element-wise rectified linear function with specified cutoff:
+ * out[i] = in[i] if in[i] >= cutoff
+ * out[i] = 0 otherwise
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 + * @return output Output (NUMERIC type) + */ + public SDVariable relu(String name, SDVariable x, double cutoff) { + SDValidation.validateNumerical("relu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd,x, cutoff).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise "rectified linear 6" function with specified cutoff:
+ * out[i] = min(max(in, cutoff), 6)
+ * + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation. Usually 0 + * @return output Output (NUMERIC type) + */ + public SDVariable relu6(SDVariable x, double cutoff) { + SDValidation.validateNumerical("relu6", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd,x, cutoff).outputVariable(); + } + + /** + * Element-wise "rectified linear 6" function with specified cutoff:
+ * out[i] = min(max(in, cutoff), 6)
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation. Usually 0 + * @return output Output (NUMERIC type) + */ + public SDVariable relu6(String name, SDVariable x, double cutoff) { + SDValidation.validateNumerical("relu6", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd,x, cutoff).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
+ * Note that bias array is optional
+ * + * @param input Input data (NUMERIC type) + * @param weights Weights variable (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { + SDValidation.validateNumerical("reluLayer", "input", input); + SDValidation.validateNumerical("reluLayer", "weights", weights); + SDValidation.validateNumerical("reluLayer", "bias", bias); + return new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd,input, weights, bias).outputVariable(); + } + + /** + * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
+ * Note that bias array is optional
+ * + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) + * @param weights Weights variable (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) { + SDValidation.validateNumerical("reluLayer", "input", input); + SDValidation.validateNumerical("reluLayer", "weights", weights); + SDValidation.validateNumerical("reluLayer", "bias", bias); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd,input, weights, bias).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
+ *
+ * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
+ * Uses default scale and alpha values.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable selu(SDVariable x) { + SDValidation.validateNumerical("selu", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd,x).outputVariable(); + } + + /** + * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
+ *
+ * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
+ * Uses default scale and alpha values.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable selu(String name, SDVariable x) { + SDValidation.validateNumerical("selu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sigmoid(SDVariable x) { + SDValidation.validateNumerical("sigmoid", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd,x).outputVariable(); + } + + /** + * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sigmoid(String name, SDVariable x) { + SDValidation.validateNumerical("sigmoid", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
+ * + * @param x Input Variable (NUMERIC type) + * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) + * @return output Output (gradient at input of sigmoid) (NUMERIC type) + */ + public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) { + SDValidation.validateNumerical("sigmoidDerivative", "x", x); + SDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd,x, wrt).outputVariable(); + } + + /** + * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
+ * + * @param name name May be null. Name for the output variable + * @param x Input Variable (NUMERIC type) + * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) + * @return output Output (gradient at input of sigmoid) (NUMERIC type) + */ + public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) { + SDValidation.validateNumerical("sigmoidDerivative", "x", x); + SDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd,x, wrt).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Softmax activation, along the specified dimension
+ * + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply softmax + * @return output Output variable (NUMERIC type) + */ + public SDVariable softmax(SDVariable x, int dimension) { + SDValidation.validateNumerical("softmax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, dimension).outputVariable(); + } + + /** + * Softmax activation, along the specified dimension
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @param dimension Dimension along which to apply softmax + * @return output Output variable (NUMERIC type) + */ + public SDVariable softmax(String name, SDVariable x, int dimension) { + SDValidation.validateNumerical("softmax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Softmax activation, along the specified dimension
+ * + * @param x Input (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softmax(SDVariable x) { + SDValidation.validateNumerical("softmax", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, -1).outputVariable(); + } + + /** + * Softmax activation, along the specified dimension
+ * + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softmax(String name, SDVariable x) { + SDValidation.validateNumerical("softmax", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, -1).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Softmax derivative function
+ * + * @param x Softmax input (NUMERIC type) + * @param wrt Gradient at output, dL/dx (NUMERIC type) + * @param dimension Softmax dimension + * @return output (NUMERIC type) + */ + public SDVariable softmaxDerivative(SDVariable x, SDVariable wrt, int dimension) { + SDValidation.validateNumerical("softmaxDerivative", "x", x); + SDValidation.validateNumerical("softmaxDerivative", "wrt", wrt); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd,x, wrt, dimension).outputVariable(); + } + + /** + * Softmax derivative function
+ * + * @param name name May be null. Name for the output variable + * @param x Softmax input (NUMERIC type) + * @param wrt Gradient at output, dL/dx (NUMERIC type) + * @param dimension Softmax dimension + * @return output (NUMERIC type) + */ + public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, int dimension) { + SDValidation.validateNumerical("softmaxDerivative", "x", x); + SDValidation.validateNumerical("softmaxDerivative", "wrt", wrt); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd,x, wrt, dimension).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise softplus function: out = log(exp(x) + 1)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softplus(SDVariable x) { + SDValidation.validateNumerical("softplus", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd,x).outputVariable(); + } + + /** + * Element-wise softplus function: out = log(exp(x) + 1)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softplus(String name, SDVariable x) { + SDValidation.validateNumerical("softplus", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise softsign function: out = x / (abs(x) + 1)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softsign(SDVariable x) { + SDValidation.validateNumerical("softsign", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd,x).outputVariable(); + } + + /** + * Element-wise softsign function: out = x / (abs(x) + 1)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable softsign(String name, SDVariable x) { + SDValidation.validateNumerical("softsign", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable softsignDerivative(SDVariable x) { + SDValidation.validateNumerical("softsignDerivative", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd,x).outputVariable(); + } + + /** + * Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable softsignDerivative(String name, SDVariable x) { + SDValidation.validateNumerical("softsignDerivative", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
+ * See: https://arxiv.org/abs/1710.05941
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable swish(SDVariable x) { + SDValidation.validateNumerical("swish", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd,x).outputVariable(); + } + + /** + * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
+ * See: https://arxiv.org/abs/1710.05941
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable swish(String name, SDVariable x) { + SDValidation.validateNumerical("swish", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tanh(SDVariable x) { + SDValidation.validateNumerical("tanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + } + + /** + * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable tanh(String name, SDVariable x) { + SDValidation.validateNumerical("tanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java index e5cfef684..88792bddb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java @@ -27,17 +27,21 @@ import org.nd4j.autodiff.samediff.SameDiff; */ public abstract class SDOps { - protected final SameDiff sd; + protected final SameDiff sd; - public SDOps(SameDiff sameDiff) { - this.sd = sameDiff; - } + public SDOps() { + sd = null; + } - protected DifferentialFunctionFactory f() { - return sd.f(); - } + public SDOps(SameDiff sameDiff) { + this.sd = sameDiff; + } - protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { - return sd.updateVariableNameAndReference(varToUpdate, newVarName); - } + protected DifferentialFunctionFactory f() { + return sd.f(); + } + + protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { + return sd.updateVariableNameAndReference(varToUpdate, newVarName); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java index de0114b92..6b1831de7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,198 +14,232 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; +import java.lang.String; + import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.*; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*; - -import java.util.Arrays; -import java.util.List; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs; import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights; -import org.nd4j.linalg.primitives.Pair; -/** - * SameDiff Recurrent Neural Network operations
- * Accessible via {@link SameDiff#rnn()}
- * See also {@link SDNN} (accessible via {@link SameDiff#nn()} for general neural network ops.
- * See also {@link SDCNN} (accessible via {@link SameDiff#cnn()} for convolutional neural network ops.
- * - * @author Alex Black - */ public class SDRNN extends SDOps { - public SDRNN(SameDiff sameDiff) { - super(sameDiff); - } + public SDRNN(SameDiff sameDiff) { + super(sameDiff); + } + /** + * The GRU cell. Does a single time step operation
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) + * @param GRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public SDVariable gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { + SDValidation.validateNumerical("gru", "x", x); + SDValidation.validateNumerical("gru", "hLast", hLast); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariable(); + } - /** - * See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}. - */ - public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) { - GRUCell c = new GRUCell(sd, x, hLast, weights); - return new GRUCellOutputs(c.outputVariables()); - } + /** + * The GRU cell. Does a single time step operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) + * @param GRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public GRUCellOutputs gru(String name, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { + SDValidation.validateNumerical("gru", "x", x); + SDValidation.validateNumerical("gru", "hLast", hLast); + GRUCell c = new GRUCell(sd,x, hLast, GRUWeights); + return new GRUCellOutputs(c.outputVariables(name)); + } - /** - * The GRU cell. Does a single time step operation. - * - * @param baseName The base name for the gru cell - * @param x Input, with shape [batchSize, inSize] - * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] - * @param weights The cell's weights. - * @return The cell's outputs. - */ - public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) { - GRUCell c = new GRUCell(sd, x, hLast, weights); - return new GRUCellOutputs(c.outputVariables(baseName)); - } + /** + * The LSTM cell. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The cell's outputs (NUMERIC type) + */ + public LSTMCellOutputs lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast, + LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmCell", "x", x); + SDValidation.validateNumerical("lstmCell", "cLast", cLast); + SDValidation.validateNumerical("lstmCell", "yLast", yLast); + LSTMBlockCell c = new LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration); + return new LSTMCellOutputs(c.outputVariables()); + } - /** - * See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}. - */ - public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - LSTMWeights weights, LSTMConfiguration config){ - LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config); - return new LSTMCellOutputs(c.outputVariables()); - } + /** + * The LSTM cell. Does a single time step operation.
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The cell's outputs (NUMERIC type) + */ + public LSTMCellOutputs lstmCell(String name, SDVariable x, SDVariable cLast, SDVariable yLast, + LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmCell", "x", x); + SDValidation.validateNumerical("lstmCell", "cLast", cLast); + SDValidation.validateNumerical("lstmCell", "yLast", yLast); + LSTMBlockCell c = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration); + return new LSTMCellOutputs(c.outputVariables(name)); + } - /** - * The LSTM cell. Does a single time step operation. - * - * @param baseName The base name for the lstm cell - * @param x Input, with shape [batchSize, inSize] - * @param cLast Previous cell state, with shape [batchSize, numUnits] - * @param yLast Previous cell output, with shape [batchSize, numUnits] - * @param weights The cell's weights. - * @param config The cell's config. - * @return The cell's outputs. - */ - public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config); - return new LSTMCellOutputs(c.outputVariables(baseName)); - } + /** + * The LSTM layer. Does multiple time steps.
+ * + * @param maxTSLength (NUMERIC type) + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public SDVariable lstmLayer(SDVariable maxTSLength, SDVariable x, SDVariable cLast, + SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + SDValidation.validateNumerical("lstmLayer", "x", x); + SDValidation.validateNumerical("lstmLayer", "cLast", cLast); + SDValidation.validateNumerical("lstmLayer", "yLast", yLast); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + } - /** - * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} - */ - public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength, - @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config); - return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat()); - } + /** + * The LSTM layer. Does multiple time steps.
+ * + * @param name name May be null. Name for the output variable + * @param maxTSLength (NUMERIC type) + * @param x Input, with shape dependent on the data format (in config). (NUMERIC type) + * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type) + * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type) + * @param LSTMWeights Configuration Object + * @param LSTMConfiguration Configuration Object + * @return output The layer's outputs. (NUMERIC type) + */ + public SDVariable lstmLayer(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast, + SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) { + SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength); + SDValidation.validateNumerical("lstmLayer", "x", x); + SDValidation.validateNumerical("lstmLayer", "cLast", cLast); + SDValidation.validateNumerical("lstmLayer", "yLast", yLast); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} - */ - public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - return lstmLayer( - sd.scalar("lstm_max_ts_length", maxTSLength), - x, cLast, yLast, weights, config); - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs.. (NUMERIC type) + */ + public SDVariable sru(SDVariable x, SDVariable initialC, SDVariable mask, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sru", "x", x); + SDValidation.validateNumerical("sru", "initialC", initialC); + SDValidation.validateNumerical("sru", "mask", mask); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, mask, SRUWeights).outputVariable(); + } - /** - * See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)} - */ - public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - if(baseName != null) { - return lstmLayer(baseName, - sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength), - x, cLast, yLast, weights, config); - } else { - return lstmLayer(maxTSLength, x, cLast, yLast, weights, config); - } - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs.. (NUMERIC type) + */ + public SDVariable sru(String name, SDVariable x, SDVariable initialC, SDVariable mask, + SRUWeights SRUWeights) { + SDValidation.validateNumerical("sru", "x", x); + SDValidation.validateNumerical("sru", "initialC", initialC); + SDValidation.validateNumerical("sru", "mask", mask); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, mask, SRUWeights).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * The LSTM layer. Does multiple time steps. - * - * Input shape depends on data format (in config):
- * TNS -> [timeSteps, batchSize, inSize]
- * NST -> [batchSize, inSize, timeSteps]
- * NTS -> [batchSize, timeSteps, inSize]
- * - * @param baseName The base name for the lstm layer - * @param x Input, with shape dependent on the data format (in config). - * @param cLast Previous/initial cell state, with shape [batchSize, numUnits] - * @param yLast Previous/initial cell output, with shape [batchSize, numUnits] - * @param weights The layer's weights. - * @param config The layer's config. - * @return The layer's outputs. - */ - public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength, - @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast, - @NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){ - LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config); - return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat()); - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs.. (NUMERIC type) + */ + public SDVariable sru(SDVariable x, SDVariable initialC, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sru", "x", x); + SDValidation.validateNumerical("sru", "initialC", initialC); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, null, SRUWeights).outputVariable(); + } - /** - * See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}. - */ - public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) { - return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables()); - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs.. (NUMERIC type) + */ + public SDVariable sru(String name, SDVariable x, SDVariable initialC, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sru", "x", x); + SDValidation.validateNumerical("sru", "initialC", initialC); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, null, SRUWeights).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * The SRU cell. Does a single time step operation. - * - * @param baseName The base name for the sru cell - * @param x Input, with shape [batchSize, inSize] - * @param cLast Previous cell state, with shape [batchSize, inSize] - * @param weights The cell's weights. - * @return The cell's outputs. - */ - public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) { - return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName)); - } - - /** - * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)} - */ - public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) { - return sru(x, initialC, null, weights); - } - - /** - * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)} - */ - public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) { - return sru(baseName, x, initialC, null, weights); - } - - /** - * See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)} - */ - public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { - return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables()); - } - - /** - * The SRU layer. Does a single time step operation. - * - * @param baseName The base name for the sru layer - * @param x Input, with shape [batchSize, inSize, timeSeriesLength] - * @param initialC Initial cell state, with shape [batchSize, inSize] - * @param mask An optional dropout mask, with shape [batchSize, inSize] - * @param weights The layer's weights. - * @return The layer's outputs. - */ - public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) { - return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName)); - } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public SDVariable sruCell(SDVariable x, SDVariable cLast, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sruCell", "x", x); + SDValidation.validateNumerical("sruCell", "cLast", cLast); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(sd,x, cLast, SRUWeights).outputVariable(); + } + /** + * The SRU layer. Does a single time step operation.
+ * + * @param name name May be null. Name for the output variable + * @param x Input, with shape [batchSize, inSize] (NUMERIC type) + * @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type) + * @param SRUWeights Configuration Object + * @return output The cell's outputs. (NUMERIC type) + */ + public SDVariable sruCell(String name, SDVariable x, SDVariable cLast, SRUWeights SRUWeights) { + SDValidation.validateNumerical("sruCell", "x", x); + SDValidation.validateNumerical("sruCell", "cLast", cLast); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(sd,x, cLast, SRUWeights).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java index cabf41103..cd986d7bd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,324 +14,253 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; -import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; - -/** - * SameDiff random number generator operations
- * Accessible via {@link SameDiff#random()} - * - * @author Alex Black - */ public class SDRandom extends SDOps { + public SDRandom(SameDiff sameDiff) { + super(sameDiff); + } - public SDRandom(SameDiff sd) { - super(sd); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,
+ * with the specified probability. Array values will have value 1 with probability P and value 0 with probability
+ * 1-P.
+ * + * @param p Probability of value 1 + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable bernoulli(double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(sd,p, datatype, shape).outputVariable(); + } - /** - * @see #bernoulli(String, double, SDVariable) - */ - public SDVariable bernoulli(double p, SDVariable shape) { - return bernoulli(null, p, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,
+ * with the specified probability. Array values will have value 1 with probability P and value 0 with probability
+ * 1-P.
+ * + * @param name name May be null. Name for the output variable + * @param p Probability of value 1 + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable bernoulli(String name, double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(sd,p, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Bernoulli distribution, - * with the specified probability. Array values will have value 1 with probability P and value 0 with probability - * 1-P.
- * See {@link #bernoulli(String, double, long...)} for the equivalent function where the shape is - * specified as a long[] instead - * - * @param name Name of the new SDVariable - * @param p Probability of value 1 - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable bernoulli(String name, double p, SDVariable shape) { - validateInteger("bernoulli random", shape); - SDVariable ret = f().randomBernoulli(p, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,
+ * with the specified number of trials and probability.
+ * + * @param nTrials Number of trials parameter for the binomial distribution + * @param p Probability of success for each trial + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable binomial(int nTrials, double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(sd,nTrials, p, datatype, shape).outputVariable(); + } - /** - * @see #bernoulli(String, double, long...) - */ - public SDVariable bernoulli(double p, long... shape) { - return bernoulli(null, p, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,
+ * with the specified number of trials and probability.
+ * + * @param name name May be null. Name for the output variable + * @param nTrials Number of trials parameter for the binomial distribution + * @param p Probability of success for each trial + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable binomial(String name, int nTrials, double p, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(sd,nTrials, p, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Bernoulli distribution, - * with the specified probability. Array values will have value 1 with probability P and value 0 with probability - * 1-P.
- * See {@link #bernoulli(String, double, SDVariable)} for the equivalent function where the shape is - * specified as a SDVarible instead - * - * @param name Name of the new SDVariable - * @param p Probability of value 1 - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable bernoulli(String name, double p, long... shape) { - SDVariable ret = f().randomBernoulli(p, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:
+ * P(x) = lambda * exp(-lambda * x)
+ * + * Inputs must satisfy the following constraints:
+ * Must be positive: lambda > 0
+ * + * @param lambda lambda parameter + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable exponential(double lambda, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + Preconditions.checkArgument(lambda > 0, "Must be positive"); + return new org.nd4j.linalg.api.ops.random.custom.RandomExponential(sd,lambda, datatype, shape).outputVariable(); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, - * with the specified number of trials and probability. - * - * @param nTrials Number of trials parameter for the binomial distribution - * @param p Probability of success for each trial - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable binomial(int nTrials, double p, long... shape) { - return binomial(null, nTrials, p, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:
+ * P(x) = lambda * exp(-lambda * x)
+ * + * Inputs must satisfy the following constraints:
+ * Must be positive: lambda > 0
+ * + * @param name name May be null. Name for the output variable + * @param lambda lambda parameter + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable exponential(String name, double lambda, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + Preconditions.checkArgument(lambda > 0, "Must be positive"); + SDVariable out = new org.nd4j.linalg.api.ops.random.custom.RandomExponential(sd,lambda, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, - * with the specified number of trials and probability. - * - * @param name Name of the new SDVariable - * @param nTrials Number of trials parameter for the binomial distribution - * @param p Probability of success for each trial - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable binomial(String name, int nTrials, double p, long... shape) { - SDVariable ret = f().randomBinomial(nTrials, p, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,
+ * i.e., {@code log(x) ~ N(mean, stdev)}
+ * + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable logNormal(double mean, double stddev, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution: - * P(x) = lambda * exp(-lambda * x) - * - * @param lambda Must be > 0 - * @param shape Shape of the output - * @return new SDVariable - */ - public SDVariable exponential(double lambda, SDVariable shape) { - return exponential(null, lambda, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,
+ * i.e., {@code log(x) ~ N(mean, stdev)}
+ * + * @param name name May be null. Name for the output variable + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable logNormal(String name, double mean, double stddev, DataType datatype, + long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution: - * P(x) = lambda * exp(-lambda * x) - * - * @param name Name of the output variable - * @param lambda Must be > 0 - * @param shape Shape of the new variable - * @return new SDVaribale - */ - public SDVariable exponential(String name, double lambda, SDVariable shape) { - validateInteger("exponential random", shape); - SDVariable ret = f().randomExponential(lambda, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev)
+ * + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable normal(double mean, double stddev, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + } - /** - * @see #logNormal(String, double, double, long...) - */ - public SDVariable logNormal(double mean, double stddev, long... shape) { - return logNormal(null, mean, stddev, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev)
+ * + * @param name name May be null. Name for the output variable + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable normal(String name, double mean, double stddev, DataType datatype, + long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Log Normal distribution, - * i.e., {@code log(x) ~ N(mean, stdev)}
- * - * @param name Name of the new SDVariable - * @param mean Mean value for the random array - * @param stddev Standard deviation for the random array - * @param shape Shape of the new random SDVariable - * @return New SDVariable - */ - public SDVariable logNormal(String name, double mean, double stddev, long... shape) { - SDVariable ret = f().randomLogNormal(mean, stddev, shape); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled
+ * + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable normalTruncated(double mean, double stddev, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + } - /** - * @see #normal(String, double, double, SDVariable) - */ - public SDVariable normal(double mean, double stddev, SDVariable shape) { - return normal(null, mean, stddev, shape); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,
+ * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled
+ * + * @param name name May be null. Name for the output variable + * @param mean Mean value for the random array + * @param stddev Standard deviation for the random array + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable normalTruncated(String name, double mean, double stddev, DataType datatype, + long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution, - * N(mean, stdev)
- * See {@link #normal(String, double, double, long...)} for the equivalent function where the shape is - * specified as a long[] instead - * - * @param name Name of the new SDVariable - * @param mean Mean value for the random array - * @param stddev Standard deviation for the random array - * @param shape Shape of the new random SDVariable, as a 1D array - * @return New SDVariable - */ - public SDVariable normal(String name, double mean, double stddev, SDVariable shape) { - validateInteger("normal (Gaussian) random", shape); - SDVariable ret = f().randomNormal(mean, stddev, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #normal(String, double, double, long...) - */ - public SDVariable normal(double mean, double stddev, long... shape) { - return normal(null, mean, stddev, shape); - } - - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution, - * N(mean, stdev)
- * See {@link #normal(String, double, double, SDVariable)} for the equivalent function where the shape is - * specified as a long[] instead - * - * @param name Name of the new SDVariable - * @param mean Mean value for the random array - * @param stddev Standard deviation for the random array - * @param shape Shape of the new random SDVariable - * @return New SDVariable - */ - public SDVariable normal(String name, double mean, double stddev, long... shape) { - SDVariable ret = f().randomNormal(mean, stddev, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #normalTruncated(String, double, double, long...) - */ - public SDVariable normalTruncated(double mean, double stddev, long... shape) { - return normalTruncated(null, mean, stddev, shape); - } - - /** - * Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution, - * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled
- * - * @param name Name of the new SDVariable - * @param mean Mean value for the random array - * @param stddev Standard deviation for the random array - * @param shape Shape of the new random SDVariable - * @return New SDVariable - */ - public SDVariable normalTruncated(String name, double mean, double stddev, long... shape) { - SDVariable ret = f().randomNormalTruncated(mean, stddev, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #uniform(String, double, double, SDVariable) - */ - public SDVariable uniform(double min, double max, SDVariable shape) { - return uniform(null, min, max, shape); - } - - /** - * @see #uniform(String, double, double, SDVariable) - */ - public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) { - return uniform(null, min, max, shape, dataType); - } - - /** - * As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output - */ - public SDVariable uniform(String name, double min, double max, SDVariable shape) { - return uniform(name, min, max, shape, null); - } - - /** - * Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution, - * U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned
- * See {@link #uniform(double, double, long...)} for the equivalent function where the shape is - * specified as a long[] instead - * - * @param name Name of the new SDVariable - * @param min Minimum value - * @param max Maximum value. Must satisfy max >= min - * @param shape Shape of the new random SDVariable, as a 1D array - * @param dataType Data type of the output array (if null: Float32 output is returned) - * @return New SDVariable, of the specified data type - */ - public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) { - validateInteger("uniform random", shape); - SDVariable ret = f().randomUniform(min, max, shape, dataType); - return updateVariableNameAndReference(ret, name); - } - - /** - * @see #uniform(String, double, double, long...) - */ - public SDVariable uniform(double min, double max, long... shape) { - return uniform(null, min, max, shape); - } - - /** - * Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution, - * U(min,max)
- * See {@link #uniform(double, double, long...)} for the equivalent function where the shape is - * specified as a SDVariable instead - * - * @param name Name of the new SDVariable - * @param min Minimum value - * @param max Maximum value. Must satisfy max >= min - * @param shape Shape of the new random SDVariable - * @return New SDVariable - */ - public SDVariable uniform(String name, double min, double max, long... shape) { - SDVariable ret = f().randomUniform(min, max, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * Generate a new random SDVariable with Gamma distribution - * - * @param name Name of the output variable - * @param alpha distribution parameter - * @param beta distribution parameter - * @param shape Shape of the new variable - * @return new SDVariable - */ - public SDVariable gamma(String name, SDVariable shape, SDVariable alpha, SDVariable beta) { - SDVariable ret = f().randomGamma(alpha, beta, shape); - return updateVariableNameAndReference(ret, name); - } - - /** - * Generate a new random SDVariable with Poission distribution - * - * @param name Name of the output variable - * @param lambda rate distribution parameter - * @param shape Shape of the new variable - * @return new SDVariable - */ - public SDVariable poisson(String name, SDVariable lambda, SDVariable shape, int... seeds) { - SDVariable ret = f().randomPoisson(shape, lambda, seeds); - return updateVariableNameAndReference(ret, name); - } - - /** - * Generate a new random SDVariable by random shuffle - * - * @param name Name of the output variable - * @param value array to shuffle - * @return new SDVariable - */ - public SDVariable shuffle(String name, SDVariable value, int... seeds) { - SDVariable ret = f().randomShuffle(value, seeds); - return updateVariableNameAndReference(ret, name); - } + /** + * Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,
+ * U(min,max)
+ * + * @param min Minimum value + * @param max Maximum value. + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable uniform(double min, double max, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + return new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(sd,min, max, datatype, shape).outputVariable(); + } + /** + * Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,
+ * U(min,max)
+ * + * @param name name May be null. Name for the output variable + * @param min Minimum value + * @param max Maximum value. + * @param datatype Data type of the output variable + * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) + */ + public SDVariable uniform(String name, double min, double max, DataType datatype, long... shape) { + Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(sd,min, max, datatype, shape).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java index f6434a56f..93999f0fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDValidation.java @@ -55,6 +55,15 @@ public class SDValidation { v.name() + "\" with non-integer data type " + v.dataType()); } + protected static void validateNumerical(String opName, String inputName, SDVariable[] vars) { + for (SDVariable v : vars) { + if (v == null) continue; + if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an numerical type type; got variable \"" + + v.name() + "\" with non-integer data type " + v.dataType()); + } + } + /** * Validate that the operation is being applied on numerical SDVariables (not boolean or utf8). * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays @@ -97,6 +106,16 @@ public class SDValidation { v.name() + "\" with non-integer data type " + v.dataType()); } + protected static void validateInteger(String opName, String inputName, SDVariable[] vars) { + for (SDVariable v : vars) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" + + v.name() + "\" with non-integer data type " + v.dataType()); + } + } + /** * Validate that the operation is being applied on an floating point type SDVariable * @@ -200,4 +219,18 @@ public class SDValidation { } } + public static boolean isSameType(SDVariable x, SDVariable y) { + return x.dataType() == y.dataType(); + } + + public static boolean isSameType(SDVariable[] x) { + DataType firstDataType = x[0].dataType(); + if (x.length > 1) { + for (int i = 1; i < x.length; ++i) { + if (firstDataType != x[i].dataType()) + return false; + } + } + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/enums/DataFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/DataFormat.java similarity index 95% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/enums/DataFormat.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/DataFormat.java index fb3fc9c67..c42795070 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/enums/DataFormat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/DataFormat.java @@ -16,7 +16,7 @@ //================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== -package org.nd4j.linalg.factory.enums; +package org.nd4j.enums; /** * Data format: "NCHW" or "NHWC" */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index ebe27bd85..62edb778f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -633,7 +633,9 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.Lu.class, org.nd4j.linalg.api.ops.custom.TriangularSolve.class, org.nd4j.linalg.api.ops.custom.LinearSolve.class, - org.nd4j.linalg.api.ops.custom.Lstsq.class + org.nd4j.linalg.api.ops.custom.Lstsq.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.Qr.class, + org.nd4j.linalg.api.ops.custom.Logdet.class ); static { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java index 2a72fc76e..8b598242c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java @@ -85,6 +85,12 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum this(x, null, dimensions); } + public BaseIndexAccumulation(INDArray x, boolean keepDims, int[] dimensions) { + this(x, null, dimensions); + this.keepDims = keepDims; + defineDimensions(dimensions); + } + public BaseIndexAccumulation(INDArray x, INDArray z, int[] dimensions) { super(x, z); defineDimensions(dimensions); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index 55d551369..f842303ca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -29,12 +29,17 @@ public class AdjustContrast extends BaseAdjustContrast { super(in, factor, out); } + public AdjustContrast(@NonNull INDArray in, double factor) { + this(in, factor, null); + } + public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) { super(sameDiff,new SDVariable[]{in,factor}); } - public AdjustContrast(@NonNull INDArray in, double factor) { - this(in, factor, null); + public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) { + super(sameDiff,new SDVariable[]{in}); + addTArgument(factor); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java index e1a5b0a7a..bd0c88792 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustHue.java @@ -50,6 +50,11 @@ public class AdjustHue extends DynamicCustomOp { super(sameDiff,new SDVariable[]{in,factor}); } + public AdjustHue(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) { + super(sameDiff,new SDVariable[]{in}); + addTArgument(factor); + } + @Override public String opName() { return "adjust_hue"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java index e9f1f90c8..3c98f2149 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustSaturation.java @@ -49,6 +49,11 @@ public class AdjustSaturation extends DynamicCustomOp { super(sameDiff, new SDVariable[]{in, factor}); } + public AdjustSaturation(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) { + super(sameDiff, new SDVariable[]{in}); + addTArgument(factor); + } + @Override public String opName() { return "adjust_saturation"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Logdet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Logdet.java new file mode 100644 index 000000000..81b8cbc08 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Logdet.java @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Logdet extends DynamicCustomOp { + + public Logdet(INDArray input) { + addInputArgument(input); + } + + public Logdet(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "logdet"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java index 20751164f..b7c0e4092 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java @@ -17,9 +17,17 @@ package org.nd4j.linalg.api.ops.custom; import lombok.NoArgsConstructor; import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + @NoArgsConstructor public class Lstsq extends DynamicCustomOp { @@ -33,8 +41,21 @@ public class Lstsq extends DynamicCustomOp { this(matrix, rhs, 0.0, true); } + public Lstsq(@NonNull SameDiff sameDiff, @NonNull SDVariable matrix, @NonNull SDVariable rhs, double l2_regularizer, boolean fast) { + super(sameDiff, new SDVariable[]{matrix,rhs}); + addTArgument(l2_regularizer); + addBArgument(fast); + } + @Override public String opName() { return "lstsq"; } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java index 554781958..40d50afe3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/MatrixBandPart.java @@ -15,6 +15,7 @@ ******************************************************************************/ package org.nd4j.linalg.api.ops.custom; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -26,10 +27,9 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; import java.util.List; +@NoArgsConstructor public class MatrixBandPart extends DynamicCustomOp { - public MatrixBandPart() {} - public MatrixBandPart(@NonNull INDArray input, int minLower, int maxUpper) { Preconditions.checkArgument(input.rank() >= 2, "MatrixBandPart: Input rank should be 2 or higher"); long N = input.size(-2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java index 97b826064..da0896a46 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java @@ -1,6 +1,5 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -37,7 +36,6 @@ import java.util.*; */ @NoArgsConstructor public class CropAndResize extends DynamicCustomOp { - public enum Method {BILINEAR, NEAREST}; protected Method method = Method.BILINEAR; protected double extrapolationValue = 0.0; @@ -50,6 +48,10 @@ public class CropAndResize extends DynamicCustomOp { addArgs(); } + public CropAndResize(@NonNull SameDiff sameDiff, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, + SDVariable cropOutSize, double extrapolationValue) { + this(sameDiff, image, cropBoxes, boxIndices, cropOutSize, null, extrapolationValue); + } public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices, @NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue, @@ -65,12 +67,10 @@ public class CropAndResize extends DynamicCustomOp { outputArguments.add(output); } - public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices, - @NonNull INDArray cropOutSize, double extrapolationValue) { - this(image, cropBoxes, boxIndices, cropOutSize, Method.BILINEAR, extrapolationValue, null); + public CropAndResize(INDArray image, INDArray cropBoxes, INDArray boxIndices, INDArray cropOutSize, double extrapolationValue ) { + this(image, cropBoxes, boxIndices, cropOutSize, null, extrapolationValue, null); } - @Override public String opName() { return "crop_and_resize"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java index 71b8b1fb2..5e6362d67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ExtractImagePatches.java @@ -46,6 +46,12 @@ public class ExtractImagePatches extends DynamicCustomOp { public ExtractImagePatches(){ } + public ExtractImagePatches(@NonNull SameDiff samediff, @NonNull SDVariable input, + int kH, int kW, int sH, int sW, int rH, int rW, + boolean sameMode) { + this(samediff, input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH,rW}, sameMode); + + } public ExtractImagePatches(@NonNull SameDiff samediff, @NonNull SDVariable input, @NonNull int[] kSizes, @NonNull int[] strides, @NonNull int[] rates, boolean sameMode){ super(samediff, input); @@ -72,16 +78,8 @@ public class ExtractImagePatches extends DynamicCustomOp { addArgs(); } - public ExtractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { - super(new INDArray[]{input},null); - int[] kSises = {kH,kW}; - int[] strides = {sH,sW}; - int[] rates = {rH, rW}; - this.kSizes = kSises; - this.strides = strides; - this.rates = rates; - this.isSameMode = sameMode; - addArgs(); + public ExtractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) { + this(input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH, rW}, sameMode); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java index f8763c41a..f7ab95d77 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java @@ -42,6 +42,13 @@ public class NonMaxSuppression extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false); } + public NonMaxSuppression(SameDiff sameDiff, SDVariable boxes, SDVariable scores, int maxOutSize, + double iouThreshold, double scoreThreshold) { + super(null, sameDiff, new SDVariable[]{boxes, scores}, false); + addIArgument(maxOutSize); + addTArgument(iouThreshold, scoreThreshold); + } + public NonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize, double iouThreshold, double scoreThreshold) { addInputArgument(boxes,scores); addIArgument(maxOutSize); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java index 60b278ed7..dd61c03e4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java @@ -54,10 +54,18 @@ public class FirstIndex extends BaseIndexAccumulation { this.extraArgs = new Object[] {compare, eps, (double) mode}; } + public FirstIndex(SameDiff sameDiff, SDVariable i_v, boolean keepDims, Condition condition, int... dimensions) { + this(sameDiff, i_v, condition, keepDims, dimensions); + } + public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) { this(x, condition, false, dimension); } + public FirstIndex(INDArray x, boolean keepDims, @NonNull Condition condition, int... dimension) { + this(x,condition,keepDims,dimension); + } + public FirstIndex(INDArray x, @NonNull Condition condition, boolean keepDims, int... dimension) { this(x, condition, Nd4j.EPS_THRESHOLD, dimension); this.keepDims = keepDims; @@ -72,7 +80,6 @@ public class FirstIndex extends BaseIndexAccumulation { this.extraArgs = new Object[] {compare, eps, (double) mode}; } - @Override public int opNum() { return 4; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java index 7280d7adf..8b7872b49 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java @@ -45,6 +45,11 @@ public class IMax extends BaseIndexAccumulation { super(x, z, dimensions); } + public IMax(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + + } + public IMax(INDArray x, int... dimensions) { super(x, null, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java index 449ea36a0..06b3deb1c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java @@ -44,6 +44,10 @@ public class IMin extends BaseIndexAccumulation { super(x, dimensions); } + public IMin(INDArray x, boolean keepDims, int... dimensions) { + super(x, keepDims, dimensions); + } + public IMin(INDArray x, INDArray z, int... dimensions) { super(x, z, dimensions); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java index e77d42398..1325d33c5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; import lombok.Data; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -38,12 +39,16 @@ import java.util.Map; * @author raver119@gmail.com */ @Data +@NoArgsConstructor public class LastIndex extends BaseIndexAccumulation { protected Condition condition; protected double compare; protected double eps; protected int mode; + public LastIndex(SameDiff sameDiff, SDVariable i_v, boolean keepDims, Condition condition, int... dimensions) { + this(sameDiff, i_v, condition, keepDims, dimensions); + } public LastIndex(SameDiff sameDiff, SDVariable i_v, Condition condition, boolean keepDims, int... dimensions) { super(sameDiff, i_v, keepDims, dimensions); this.condition = condition; @@ -53,13 +58,19 @@ public class LastIndex extends BaseIndexAccumulation { this.extraArgs = new Object[] {compare, eps, (double) mode}; } - public LastIndex() {} - + public LastIndex(SameDiff sameDiff, SDVariable x, @NonNull Condition condition, int... dimensions) { + super(sameDiff, x, false, dimensions); + this.condition = condition; + } public LastIndex(INDArray x, @NonNull Condition condition, int... dimensions) { this(x, condition, Nd4j.EPS_THRESHOLD, dimensions); } + public LastIndex(INDArray in, boolean keepDim, Condition condition, int... dimensions) { + this(in, condition, keepDim, dimensions); + } + public LastIndex(INDArray x, @NonNull Condition condition, boolean keepDim, int... dimensions) { this(x, condition, Nd4j.EPS_THRESHOLD, dimensions); this.keepDims = keepDim; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java index 79bcacab0..7c6b5186c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java @@ -47,10 +47,6 @@ public class AvgPooling3D extends Pooling3D { super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.AVG); } - public AvgPooling3D(SameDiff sameDiff,INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) { - super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.AVG); - } - public AvgPooling3D(@NonNull INDArray input, Pooling3DConfig pooling3DConfig) { super(null,null,new INDArray[]{input},null,false, pooling3DConfig, Pooling3DType.AVG); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java index e3716bc24..b4f881fac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java @@ -76,6 +76,19 @@ public class BatchNorm extends DynamicCustomOp { addArgs(); } + public BatchNorm(SameDiff sameDiff, SDVariable input, SDVariable mean, SDVariable variance, + SDVariable gamma, SDVariable beta, double epsilon, int[] axis) { + super(null,sameDiff, wrapFilterNull(input, mean, variance, gamma, beta), false); + Preconditions.checkState(axis != null && axis.length > 0, "Invalid axis argument: axis must be specified" + + "and length > 0. Got %s", axis); + this.sameDiff = sameDiff; + this.applyBeta = beta != null; + this.applyGamma = gamma != null; + this.epsilon = epsilon; + this.jaxis = axis; + addArgs(); + } + public BatchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, INDArray beta, double epsilon, int... axis){ super(wrapFilterNull(input, mean, variance, gamma, beta), null); this.jaxis = axis; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java index 819d1d10c..e33037738 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java @@ -46,6 +46,10 @@ public class Conv1D extends DynamicCustomOp { protected Conv1DConfig config; private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s "; + public Conv1D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) { + this(sameDiff, wrapFilterNull(input, weights, bias), conv1DConfig); + } + @Builder(builderMethodName = "sameDiffBuilder") public Conv1D(SameDiff sameDiff, SDVariable[] inputFunctions, @@ -64,12 +68,8 @@ public class Conv1D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public Conv1D( @NonNull INDArray input, @NonNull INDArray weights, INDArray bias, Conv1DConfig conv1DConfig) { - this(wrapFilterNull(input, weights, bias), null, conv1DConfig); - } - - public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, Conv1DConfig conv1DConfig) { - this(new INDArray[]{input, weights}, null, conv1DConfig); + public Conv1D(INDArray input, INDArray weights, INDArray bias, Conv1DConfig config) { + this(input, weights, bias, null, config); } private void initConfig(Conv1DConfig config){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 60bdfbfcc..9635c6f36 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -56,6 +56,11 @@ public class Conv2D extends DynamicCustomOp { protected Conv2DConfig config; private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s "; + public Conv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, + SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { + this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); + } + @Builder(builderMethodName = "sameDiffBuilder") public Conv2D(SameDiff sameDiff, SDVariable[] inputFunctions, @@ -75,12 +80,8 @@ public class Conv2D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, @NonNull Conv2DConfig conv2DConfig) { - this(new INDArray[]{layerInput, weights}, null, conv2DConfig); - } - - public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, INDArray bias, @NonNull Conv2DConfig conv2DConfig) { - this(wrapFilterNull(layerInput, weights,bias), null, conv2DConfig); + public Conv2D(INDArray layerInput, INDArray weights, INDArray bias, Conv2DConfig config) { + this(layerInput, weights, bias, null, config); } protected void initConfig(Conv2DConfig config){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java index 94fb897b0..bb30930d7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv3D.java @@ -55,6 +55,11 @@ public class Conv3D extends DynamicCustomOp { public Conv3D() { } + public Conv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, + SDVariable bias, @NonNull Conv3DConfig config) { + this(sameDiff, wrapFilterNull(input, weights, bias), config); + } + @Builder(builderMethodName = "sameDiffBuilder") public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) { super(sameDiff, inputFunctions); @@ -70,12 +75,12 @@ public class Conv3D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public Conv3D(@NonNull INDArray input,@NonNull INDArray weights, @NonNull Conv3DConfig conv3DConfig) { - this(new INDArray[]{input, weights}, null, conv3DConfig); + public Conv3D(INDArray input, INDArray weights, INDArray bias, Conv3DConfig config) { + this(wrapFilterNull(input, weights, bias), null, config); } - public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull Conv3DConfig conv3DConfig) { - this(wrapFilterNull(input, weights, bias) , null, conv3DConfig); + public Conv3D(INDArray input, INDArray weights, Conv3DConfig config) { + this(wrapFilterNull(input, weights), null, config); } private void initConfig(Conv3DConfig config){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index f3500bec0..74b448dc8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -52,6 +52,11 @@ public class DeConv2D extends DynamicCustomOp { protected DeConv2DConfig config; + public DeConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, + SDVariable bias, DeConv2DConfig config) { + this(sameDiff, wrapFilterNull(input, weights, bias), config); + } + @Builder(builderMethodName = "sameDiffBuilder") public DeConv2D(SameDiff sameDiff, SDVariable[] inputs, @@ -73,15 +78,10 @@ public class DeConv2D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public DeConv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, DeConv2DConfig deConv2DConfig) { - this(wrapFilterNull(layerInput, weights), null, deConv2DConfig); + public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig config) { + this(layerInput, weights, bias, null, config); } - public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig deConv2DConfig) { - this(wrapFilterNull(layerInput, weights, bias), null, deConv2DConfig); - } - - @Override public long[] iArgs() { if (iArguments.size() == 0) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index a4652850c..436659443 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -48,12 +48,18 @@ public class DeConv3D extends DynamicCustomOp { protected DeConv3DConfig config; - public DeConv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { + public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) { super(sameDiff, toArr(input, weights, bias)); this.config = config; addArgs(); } + public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) { + super(sameDiff, toArr(input, weights, null)); + this.config = config; + addArgs(); + } + public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){ super(inputs, outputs); @@ -65,12 +71,8 @@ public class DeConv3D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, @NonNull DeConv3DConfig deConv3DConfig) { - this(new INDArray[]{input, weights}, null, deConv3DConfig); - } - - public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull DeConv3DConfig deConv3DConfig) { - this(wrapFilterNull(input, weights, bias), null, deConv3DConfig); + public DeConv3D(INDArray input, INDArray weights, INDArray bias, DeConv3DConfig config) { + this(input, weights, bias, null, config); } private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java index 3becef510..20808dff5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java @@ -16,16 +16,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; -import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.enums.DataFormat; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.enums.DataFormat; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -46,45 +45,48 @@ import java.util.*; * @author raver119@gmail.com, Max Pumperla */ public class DepthToSpace extends DynamicCustomOp { - private String dataFormat = "NHWC"; + private DataFormat dataFormat = DataFormat.NHWC; private int blockSize; public DepthToSpace() { } - public DepthToSpace(SameDiff sameDiff, SDVariable[] args, int blockSize, String dataFormat) { + public DepthToSpace(SameDiff sameDiff, SDVariable args, int blockSize, DataFormat dataFormat) { + this(sameDiff, new SDVariable[]{args}, blockSize, dataFormat); + } + + public DepthToSpace(SameDiff sameDiff, SDVariable[] args, int blockSize, DataFormat dataFormat) { super(null, sameDiff, args, false); this.blockSize = blockSize; this.dataFormat = dataFormat; - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } - public DepthToSpace(@NonNull INDArray in, INDArray out, int blockSize, @NonNull String dataFormat) { + public DepthToSpace(INDArray in, INDArray out, int blockSize, DataFormat dataFormat) { super(null, in, out, null, null); this.blockSize = blockSize; this.dataFormat = dataFormat; - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } - public DepthToSpace(@NonNull INDArray x, int blockSize, DataFormat dataFormat) { - this(x, null, blockSize, dataFormat.toString()); + public DepthToSpace(INDArray in, int blockSize, DataFormat dataFormat) { + this(in, null, blockSize, dataFormat); } - @Override public List doDiff(List i_v) { // Gradient to DepthToSpace is just SpaceToDepth of same block size and data format. SDVariable gradient = i_v.get(0); - SDVariable ret = sameDiff.cnn().spaceToDepth(gradient, blockSize, dataFormat); + SDVariable ret = new SpaceToDepth(sameDiff, new SDVariable[]{gradient}, blockSize, dataFormat).outputVariable(); return Arrays.asList(ret); } @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index ab42c3c5a..afb51af58 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -16,8 +16,11 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; -import lombok.*; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -49,11 +52,15 @@ import java.util.*; */ @Slf4j @Getter -@NoArgsConstructor public class DepthwiseConv2D extends DynamicCustomOp { protected Conv2DConfig config; + public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, + @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { + this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); + } + @Builder(builderMethodName = "sameDiffBuilder") public DepthwiseConv2D(SameDiff sameDiff, SDVariable[] inputFunctions, @@ -75,16 +82,11 @@ public class DepthwiseConv2D extends DynamicCustomOp { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } - public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, Conv2DConfig conv2DConfig) { - this(wrapFilterNull(layerInput, depthWeights), null, conv2DConfig); + public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig config) { + this(layerInput, depthWeights, bias, null, config); } - public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig conv2DConfig) { - this(wrapFilterNull(layerInput, depthWeights, bias), null, conv2DConfig); - } - - public DepthwiseConv2D(INDArray inputs, Conv2DConfig conv2DConfig) { - this(wrapFilterNull(inputs), null, conv2DConfig); + public DepthwiseConv2D() { } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java index cc5780a7c..10108d87c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java @@ -58,6 +58,10 @@ public class LocalResponseNormalization extends DynamicCustomOp { addArgs(); } + public LocalResponseNormalization(SameDiff sameDiff, SDVariable input, LocalResponseNormalizationConfig config) { + this(sameDiff, new SDVariable[]{input}, false, config); + } + public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){ super(new INDArray[]{input}, wrapOrNull(output)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java index 9f7c9bfb7..c54b63aa7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java @@ -60,15 +60,16 @@ public class MaxPooling2D extends DynamicCustomOp { addArgs(); } - public MaxPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){ + public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){ super(null, new INDArray[]{input}, wrapOrNull(output)); config.setType(Pooling2D.Pooling2DType.MAX); + this.config = config; addArgs(); } - public MaxPooling2D(@NonNull INDArray input, @NonNull Pooling2DConfig pooling2DConfig) { - this(input, null, pooling2DConfig); + public MaxPooling2D(INDArray input, @NonNull Pooling2DConfig config){ + this(input, null, config); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java index 6c4ccaa9a..6c7aec888 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java @@ -47,8 +47,12 @@ public class MaxPooling3D extends Pooling3D { super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.MAX); } - public MaxPooling3D(SameDiff sameDiff, INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) { - super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.MAX); + public MaxPooling3D(INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) { + addInputArgument(arrayInput); + if (arrayOutput != null) + addOutputArgument(arrayOutput); + this.config = config; + addArgs(); } public MaxPooling3D(INDArray input, Pooling3DConfig pooling3DConfig) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java index b28b9a987..cf4e87814 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java @@ -44,18 +44,23 @@ public class SConv2D extends Conv2D { super(sameDiff, inputFunctions, conv2DConfig); } + public SConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights, + @NonNull SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { + this(sameDiff, wrapFilterNull(layerInput, depthWeights, pointWeights, bias), conv2DConfig); + } + public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ super(inputs, outputs, config); } - public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, @NonNull Conv2DConfig Conv2DConfig){ - this(wrapFilterNull(layerInput, depthWeights, pointWeights, bias), null, Conv2DConfig); - } - public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, @NonNull Conv2DConfig Conv2DConfig){ this(wrapFilterNull(layerInput, depthWeights, pointWeights), null, Conv2DConfig); } + public SConv2D(INDArray layerInput, INDArray depthWeights, INDArray pointWeights, INDArray bias, Conv2DConfig config) { + this(wrapFilterNull(layerInput, depthWeights, pointWeights, bias), null, config); + } + public SConv2D() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java index 5ae281ae2..700824512 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SpaceToDepth.java @@ -16,16 +16,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; -import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.enums.DataFormat; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.enums.DataFormat; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -45,47 +44,48 @@ import java.util.*; * @author raver119@gmail.com, Max Pumperla */ public class SpaceToDepth extends DynamicCustomOp { - private String dataFormat; + private DataFormat dataFormat; private int blockSize; public SpaceToDepth() { } - public SpaceToDepth(SameDiff sameDiff, SDVariable[] args, int blockSize, String dataFormat) { + public SpaceToDepth(SameDiff sameDiff, SDVariable[] args, int blockSize, DataFormat dataFormat) { super(null, sameDiff, args, false); this.blockSize = blockSize; this.dataFormat = dataFormat; - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } - public SpaceToDepth(INDArray in, INDArray out, int blockSize, String dataFormat){ + public SpaceToDepth(SameDiff sameDiff, SDVariable x, int blockSize, DataFormat dataFormat) { + this(sameDiff, new SDVariable[]{x}, blockSize, dataFormat); + } + + public SpaceToDepth(INDArray in, INDArray out, int blockSize, DataFormat dataFormat){ super(null, in, out, null, null); this.blockSize = blockSize; this.dataFormat = dataFormat; - boolean isNHWC = dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } - - - public SpaceToDepth(@NonNull INDArray x, int blockSize, @NonNull DataFormat dataFormat) { - this(x, null, blockSize,dataFormat.toString()); + public SpaceToDepth(INDArray x, int blockSize, DataFormat dataFormat) { + this(x, null, blockSize, dataFormat); } - @Override public List doDiff(List i_v) { // Gradient to SpaceToDepth is just DepthToSpace of same block size and data format. SDVariable gradient = i_v.get(0); - SDVariable ret = sameDiff.cnn().depthToSpace(gradient, blockSize, dataFormat); + SDVariable ret = new DepthToSpace(sameDiff, gradient, blockSize, dataFormat).outputVariable(); return Arrays.asList(ret); } @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC"); + boolean isNHWC = dataFormat == null ? true : dataFormat.equals(DataFormat.NHWC); addIArgument(blockSize, isNHWC ? 1 : 0); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java index 574682a36..df345a2f3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java @@ -56,6 +56,14 @@ public class Upsampling2d extends DynamicCustomOp { addIArgument(nchw ? 1 : 0); } + public Upsampling2d(SameDiff sameDiff, SDVariable input, int scaleH, int scaleW, boolean nchw) { + this(sameDiff, input, nchw, scaleH, scaleW); + } + + public Upsampling2d(SameDiff sameDiff, SDVariable input, int scale) { + super(null,sameDiff, new SDVariable[]{input}); + addIArgument(scale); + } public Upsampling2d(INDArray input, int scale) { this(input, scale, scale, true); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java index 3a3caa787..adc59e4e0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java @@ -38,6 +38,11 @@ public class AbsoluteDifferenceLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public AbsoluteDifferenceLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + this(sameDiff, lossReduce, predictions, weights, label); + } + public AbsoluteDifferenceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ super(lossReduce, predictions, weights, labels); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java index 9794c7c8b..3f890da0b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/BaseLoss.java @@ -33,9 +33,9 @@ public abstract class BaseLoss extends DynamicCustomOp { protected LossReduce lossReduce; - public BaseLoss(@NonNull SameDiff sameDiff, @NonNull LossReduce lossReduce, @NonNull SDVariable predictions, @NonNull SDVariable weights, + public BaseLoss(@NonNull SameDiff sameDiff, @NonNull LossReduce lossReduce, @NonNull SDVariable predictions, SDVariable weights, @NonNull SDVariable labels){ - super(null, sameDiff, new SDVariable[]{predictions, weights, labels}); + super(null, sameDiff, new SDVariable[]{predictions, getWeights(sameDiff, weights, predictions), labels}); this.lossReduce = lossReduce; addArgs(); } @@ -50,6 +50,10 @@ public abstract class BaseLoss extends DynamicCustomOp { return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0); } + protected static SDVariable getWeights(SameDiff sd, SDVariable weights, SDVariable predictions){ + return weights != null ? weights : sd.constant(Nd4j.scalar(predictions.dataType(), 1.0)); + } + protected BaseLoss(){ } protected void addArgs(){ @@ -62,7 +66,7 @@ public abstract class BaseLoss extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() >= 2, "Expected exactly 2 or more input datatypes for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java index 241404492..7faa5f6b0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java @@ -39,6 +39,11 @@ public class CosineDistanceLoss extends BaseLoss { this.addIArgument(dimension); } + public CosineDistanceLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, int dimension) { + this(sameDiff, lossReduce, predictions, weights, labels, dimension); + } + public CosineDistanceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, int dimension){ super(lossReduce, predictions, weights, labels); this.dimension = dimension; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java index f2998064f..5d85e4933 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java @@ -36,6 +36,11 @@ public class HingeLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public HingeLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + this(sameDiff, lossReduce, predictions, weights, labels); + } + public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ super(lossReduce, predictions, weights, labels); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java index 18803cd9f..f08d90566 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java @@ -41,6 +41,11 @@ public class HuberLoss extends BaseLoss { tArguments.add(delta); } + public HuberLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, double delta) { + this(sameDiff, lossReduce, predictions, weights, labels, delta); + } + public HuberLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double delta){ super(lossReduce, predictions, weights, labels); this.delta = delta; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java index 01aa283ed..a7a15f1b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java @@ -41,6 +41,11 @@ public class LogLoss extends BaseLoss { addTArgument(epsilon); } + public LogLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, double epsilon) { + this(sameDiff, lossReduce, predictions, weights, labels, epsilon); + } + public LogLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double epsilon){ super(lossReduce, predictions, weights, labels); this.epsilon = epsilon; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java index 0e0d4f7dd..a893e3f4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java @@ -38,6 +38,11 @@ public class LogPoissonLoss extends BaseLoss { this(sameDiff, lossReduce, predictions, weights, labels, false); } + public LogPoissonLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce, boolean full) { + this(sameDiff, lossReduce, predictions, weights, labels, full); + } + public LogPoissonLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels, boolean full){ super(sameDiff, lossReduce, predictions, weights, labels); this.full = full; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java index 8e7bb9276..6c3c5d01b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java @@ -34,6 +34,11 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public MeanPairwiseSquaredErrorLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, + SDVariable weights, LossReduce lossReduce) { + this(sameDiff, lossReduce, predictions, weights, labels); + } + public MeanPairwiseSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ super(lossReduce, predictions, weights, labels); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java index c38faf29a..a9cf27584 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java @@ -36,6 +36,11 @@ public class MeanSquaredErrorLoss extends BaseLoss { super(sameDiff, lossReduce, predictions, weights, labels); } + public MeanSquaredErrorLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights, + LossReduce lossReduce) { + this(sameDiff, lossReduce, predictions, weights, labels); + } + public MeanSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){ super(lossReduce, predictions, weights, labels); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java index 32b176cfd..214380a8c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java @@ -44,6 +44,11 @@ public class SigmoidCrossEntropyLoss extends BaseLoss { public static final double DEFAULT_LABEL_SMOOTHING = 0.0; private double labelSmoothing = 0.0; + public SigmoidCrossEntropyLoss(SameDiff sameDiff, SDVariable labels, SDVariable logits, SDVariable weights, + LossReduce lossReduce, double labelSmoothing) { + this(sameDiff, lossReduce, logits, weights, labels, labelSmoothing); + } + public SigmoidCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights, SDVariable labels, double labelSmoothing) { super(sameDiff, lossReduce, logits, weights, labels); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java index c8a40b805..57576b78f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java @@ -45,6 +45,11 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss { private double labelSmoothing = 0.0; + public SoftmaxCrossEntropyLoss(SameDiff sameDiff, SDVariable labels, SDVariable logits, + SDVariable weights, LossReduce lossReduce, double labelSmoothing) { + this(sameDiff, lossReduce, logits, weights, labels, labelSmoothing); + } + public SoftmaxCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights, SDVariable labels, double labelSmoothing) { super(sameDiff, lossReduce, logits, weights, labels); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index ba3d53e45..d22478e71 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -93,6 +93,24 @@ public class Mmul extends DynamicCustomOp { } } + public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) { + addInputArgument(x, y); + addIArgument(ArrayUtil.fromBoolean(transposeX), + ArrayUtil.fromBoolean(transposeY), + ArrayUtil.fromBoolean(transposeZ)); + } + + public Mmul(INDArray x, INDArray y) { + this(x,y,null,null); + } + + public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + super(null,sameDiff,new SDVariable[]{x,y}); + addIArgument(ArrayUtil.fromBoolean(transposeX), + ArrayUtil.fromBoolean(transposeY), + ArrayUtil.fromBoolean(transposeZ)); + } public Mmul() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index eca14e9f4..c613f107f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -77,6 +77,18 @@ public class TensorMmul extends DynamicCustomOp { addIArgument(dimensions[1]); } + public TensorMmul(SameDiff sameDiff, SDVariable x, SDVariable y, int[] dimensionsX, + int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) { + super(null, sameDiff, new SDVariable[]{x,y}); + this.sameDiff = sameDiff; + this.axes = new int[][]{dimensionsX, dimensionsY}; + addIArgument(dimensionsX.length); + addIArgument(dimensionsX[0]); + addIArgument(dimensionsY.length); + addIArgument(dimensionsY[0]); + addBArgument(transposeX, transposeY, transposeZ); + } + @Override public List calculateOutputShape() { List ret = new ArrayList<>(1); @@ -242,6 +254,13 @@ public class TensorMmul extends DynamicCustomOp { this.axes = axes; } + public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, + boolean transposeX, boolean transposeY, boolean transposeZ) { + super(null,new INDArray[]{x, y},null); + this.axes = new int[][]{dimensionsX, dimensionsY}; + addBArgument(transposeX, transposeY, transposeZ); + } + @Override public String opName() { return "tensordot"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java index 7daebd4cf..d4522ca69 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java @@ -41,6 +41,10 @@ public class Any extends BaseReduceBoolOp { super(x); } + public Any(INDArray x, int... dimensions) { + super(x, dimensions); + } + @Override public int opNum() { return 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java index 44cef710f..26eabf0ff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java @@ -45,6 +45,10 @@ public class LogSumExp extends DynamicCustomOp { this.keepDims = keepDims; } + public LogSumExp(SameDiff sameDiff, SDVariable i_v, int[] dimensions) { + this(sameDiff, i_v, false, dimensions); + } + public LogSumExp() {} public LogSumExp(INDArray x, int... dimensions) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java index e6b2b064d..b11fe5b1f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java @@ -41,6 +41,10 @@ public class SquaredNorm extends BaseReduceFloatOp { super(input, output, keepDims, dimensions); } + public SquaredNorm(INDArray input, boolean keepDims, int... dimensions){ + this(input, null, keepDims, dimensions); + } + public SquaredNorm(){} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java index f366dc0cd..0fb4db830 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/MatchCondition.java @@ -38,6 +38,10 @@ public class MatchCondition extends BaseReduceLongOp { private double eps; private int mode; + public MatchCondition(SameDiff sameDiff, SDVariable in, Condition condition) { + this(sameDiff, in, condition, false, null); + } + public MatchCondition(SameDiff sameDiff, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { super(sameDiff, in, dimensions, keepDims); this.compare = condition.getValue(); @@ -64,6 +68,10 @@ public class MatchCondition extends BaseReduceLongOp { defineDimensions(dimensions); } + public MatchCondition(INDArray in, Condition condition, boolean keepDim, int... dimensions) { + this(in, condition, dimensions); + } + @Override public int opNum() { return 2; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java index ae70e44d9..859b89dac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java @@ -56,6 +56,10 @@ public class Sum extends BaseReduceSameOp { super(x, z, keepDims, dimensions); } + public Sum(INDArray x, boolean keepDims, int... dimensions) { + this(x, null, keepDims, dimensions); + } + @Override public int opNum() { return 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java index b9a98dc6e..000b0414c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java @@ -50,6 +50,10 @@ public class LeakyReLU extends BaseScalarOp { } + public LeakyReLU(SameDiff sameDiff, SDVariable i_v, double alpha) { + this(sameDiff, i_v, false, alpha); + } + public LeakyReLU(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double alpha) { super(sameDiff, i_v, alpha, extraArgs); this.alpha = alpha; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java index 572e22087..5cfab3768 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java @@ -42,6 +42,10 @@ public class Pow extends BaseScalarOp { this.extraArgs = new Object[]{pow}; } + public Pow(SameDiff sameDiff, SDVariable i_v, double pow) { + this(sameDiff, i_v, false, pow); + } + public Pow(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double pow) { super(sameDiff, i_v, pow, extraArgs); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java index ca8cee2f1..944d4d095 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java @@ -35,6 +35,10 @@ public class RectifiedLinear extends BaseScalarOp { super(sameDiff, i_v, cutoff, inPlace); } + public RectifiedLinear(SameDiff sameDiff, SDVariable i_v, double cutoff) { + this(sameDiff, i_v, false, cutoff); + } + public RectifiedLinear() { super(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java index f593dc663..c80d3c8f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java @@ -42,6 +42,10 @@ public class Relu6 extends BaseScalarOp { super(sameDiff, i_v, cutoff, inPlace); } + public Relu6(SameDiff sameDiff, SDVariable i_v, double cutoff) { + this(sameDiff, i_v, false, cutoff); + } + public Relu6() { // } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java index 98e08b010..65f653d64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java @@ -41,6 +41,10 @@ public class Step extends BaseScalarOp { this.extraArgs = new Object[] {cutoff}; } + public Step(SameDiff sameDiff, SDVariable i_v, double cutoff) { + this(sameDiff, i_v, false, cutoff); + } + public Step() { cutoff = 0.0; this.extraArgs = new Object[] {cutoff}; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java index 412b024f1..6f72490a1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java @@ -46,6 +46,9 @@ public class ScalarLessThan extends BaseScalarBoolOp { super(sameDiff, i_v, scalar, inPlace); } + public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, double scalar) { + super(sameDiff, i_v, scalar, false); + } @Override public int opNum() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java index 73f74665a..160556867 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -43,6 +44,10 @@ public class ScatterAdd extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterAdd(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterAdd(){} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java index 4e7563e4a..5d6b60c88 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -43,6 +44,10 @@ public class ScatterDiv extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterDiv(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterDiv() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java index 65162aad3..7f814d928 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -41,6 +42,10 @@ public class ScatterMax extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterMax(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterMax() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java index 8d8fe4e33..2539a3d56 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -41,6 +42,10 @@ public class ScatterMin extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterMin(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterMin() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java index 2790667cd..411c59188 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -43,6 +44,10 @@ public class ScatterMul extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterMul(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterMul() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java index 382806779..83c4cc222 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -43,6 +44,10 @@ public class ScatterSub extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterSub(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterSub() {} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java index dd9c52891..93e1e5995 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -53,6 +54,10 @@ public class ScatterUpdate extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false); } + public ScatterUpdate(INDArray ref, INDArray indices, INDArray updates) { + addInputArgument(ref, indices, updates); + } + public ScatterUpdate(){} @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index 69a62e493..bddcef970 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -49,6 +49,14 @@ public class Concat extends DynamicCustomOp { addIArgument(concatDimension); } + public Concat(INDArray[] arrays, int concatDimension) { + this(concatDimension, arrays); + } + + public Concat(SameDiff sameDiff, SDVariable[] inputs, int concatDimension){ + this(sameDiff, concatDimension, inputs); + } + public Concat(SameDiff sameDiff, int concatDimension, SDVariable... inputs){ super(null, sameDiff, inputs); addIArgument(concatDimension); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java index 337c1a936..2bf94021a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ConfusionMatrix.java @@ -68,6 +68,12 @@ public class ConfusionMatrix extends DynamicCustomOp { } } + + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, DataType dataType){ + this(sameDiff, labels, pred, weights); + this.outputType = dataType; + } + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){ super(null, sameDiff, new SDVariable[]{labels, pred}); this.outputType = dataType; @@ -82,6 +88,11 @@ public class ConfusionMatrix extends DynamicCustomOp { addIArgument(numClasses); } + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, Integer numClasses){ + super(null, sameDiff, new SDVariable[]{labels, pred, weights}); + addIArgument(numClasses); + } + public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights){ super(null, sameDiff, new SDVariable[]{labels, pred, weights}); if(numClasses != null) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java index 3e94cb126..616d4d1aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Cross.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -39,15 +40,17 @@ import java.util.List; * * @author Max Pumperla */ +@NoArgsConstructor public class Cross extends DynamicCustomOp { - public Cross() { - } - public Cross(SameDiff sameDiff, SDVariable[] args) { super(null, sameDiff, args, false); } + public Cross(SameDiff sameDiff, SDVariable a, SDVariable b) { + this(sameDiff, new SDVariable[]{a,b}); + } + public Cross(INDArray a, INDArray b){ this(a,b,null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java index 54cefce73..94516ec54 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NoArgsConstructor; import lombok.NonNull; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; @@ -39,11 +40,9 @@ import java.util.Map; * * @author Max Pumperla */ +@NoArgsConstructor public class Diag extends DynamicCustomOp { - public Diag() { - } - public Diag(@NonNull INDArray input) { this(input, null); } @@ -52,6 +51,10 @@ public class Diag extends DynamicCustomOp { super(null, new INDArray[]{input}, wrapOrNull(output)); } + public Diag(SameDiff sameDiff, SDVariable input) { + this(sameDiff, new SDVariable[]{input}, false); + } + public Diag(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java index 9162b8935..b498157b6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java @@ -50,6 +50,10 @@ public class DiagPart extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } + public DiagPart(SameDiff sameDiff, SDVariable in) { + this(sameDiff, new SDVariable[]{in}, false); + } + public DiagPart(INDArray in){ this(in, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java index a13a03184..f0a6f436a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java @@ -46,6 +46,10 @@ public class ExpandDims extends DynamicCustomOp { public ExpandDims() { } + public ExpandDims(SameDiff sameDiff, SDVariable args, int axis) { + this(sameDiff, new SDVariable[]{args}, axis); + } + public ExpandDims(SameDiff sameDiff, SDVariable[] args, int axis) { super(null, sameDiff, args); if (axis == Integer.MAX_VALUE) { @@ -63,6 +67,11 @@ public class ExpandDims extends DynamicCustomOp { super(null, inputs, outputs); } + public ExpandDims(INDArray input, int axis) { + addInputArgument(input); + addIArgument(axis); + } + public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java index 3a8bb8f15..1e8409244 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java @@ -122,6 +122,13 @@ public class Eye extends DynamicCustomOp { addArgs(); } + public Eye(SameDiff sameDiff, SDVariable numRows, SDVariable numCols, DataType dataType, int[] batchDimension) { + super(null, sameDiff, new SDVariable[] {numRows, numCols}, false); + this.batchDimension = batchDimension; + this.dataType = dataType; + addArgs(); + } + protected void addArgs() { iArguments.clear(); tArguments.clear(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java index fd6ec5240..b4f690ba5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java @@ -24,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -40,6 +41,13 @@ public class Gather extends DynamicCustomOp { protected int[] indices; protected int jaxis = 0; + public Gather(SameDiff sameDiff, SDVariable df, SDVariable indices, int axis) { + this(sameDiff, df, indices, axis, false); + } + + public Gather(SameDiff sameDiff, SDVariable df, int[] indices, int axis) { + this(sameDiff, df, indices, axis, false); + } public Gather(SameDiff sameDiff, SDVariable input, int[] indices, int axis, boolean inPlace) { super(null, sameDiff, new SDVariable[] {input}, inPlace); @@ -56,6 +64,21 @@ public class Gather extends DynamicCustomOp { this.jaxis = axis; } + public Gather(INDArray df, int[] indexes, int axis) { + addInputArgument(df); + addIArgument(axis); + addIArgument(indexes); + this.jaxis = axis; + this.indices = indices; + } + + public Gather(INDArray df, INDArray indexes, int axis) { + addInputArgument(df, indexes); + addIArgument(axis); + this.jaxis = axis; + this.indices = indices; + } + @Override public String onnxName() { return "Gather"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index b8ef51d57..a239bd9ec 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -17,10 +17,13 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; +import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.util.ArrayUtil; import java.util.Collections; import java.util.List; @@ -31,11 +34,19 @@ import java.util.List; @NoArgsConstructor public class GatherNd extends DynamicCustomOp { + public GatherNd(SameDiff sameDiff, SDVariable[] inputs, SDVariable[] indices) { + super(null, sameDiff, ArrayUtils.addAll(inputs, indices), false); + } public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices, boolean inPlace) { super(null, sameDiff, new SDVariable[] {input, indices}, inPlace); } + public GatherNd(INDArray[] df, INDArray[] indices) { + addInputArgument(df); + addInputArgument(indices); + } + @Override public String opName() { return "gather_nd"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index fab4a0066..6fca99eae 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import org.apache.commons.lang3.NotImplementedException; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -39,11 +40,24 @@ public class Linspace extends DynamicCustomOp { private DataType dataType; + public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) { + super(sameDiff, new SDVariable[0]); + addTArgument(start,stop); + addIArgument(number); + addDArgument(dataType); + } + public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){ super(sameDiff, new SDVariable[]{from, to, length}); this.dataType = dataType; } + public Linspace(DataType dataType, double start, double stop, long number) { + addDArgument(dataType); + addTArgument(start, stop); + addIArgument(number); + } + public Linspace(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java index 84eb47fc8..f2c11f1ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MeshGrid.java @@ -37,6 +37,10 @@ public class MeshGrid extends DynamicCustomOp { addIArgument(cartesian ? 1 : 0); } + public MeshGrid(SameDiff sd, SDVariable[] inputs, boolean cartesian) { + this(sd, cartesian, inputs); + } + public MeshGrid(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java index beb9d09b9..affc603e9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OneHot.java @@ -66,6 +66,11 @@ public class OneHot extends DynamicCustomOp { this(indices, output, depth, -1, 1, 0); } + public OneHot(INDArray indices, int depth) { + addInputArgument(indices); + addIArgument(depth); + } + public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) { super(null, indices, output, null, null); this.depth = depth; @@ -75,6 +80,12 @@ public class OneHot extends DynamicCustomOp { addArgs(); } + public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { + addInputArgument(indices); + addIArgument(depth, axis); + addTArgument(on, off); + addDArgument(dataType); + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java index 4b4b3e578..8da18be06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java @@ -48,10 +48,18 @@ public class OnesLike extends DynamicCustomOp { public OnesLike() { } + public OnesLike(SameDiff sameDiff, SDVariable input) { + this(null, sameDiff, input); + } + public OnesLike(String name, SameDiff sameDiff, SDVariable input) { this(name, sameDiff, input, input.dataType()); } + public OnesLike(SameDiff sameDiff, SDVariable input, DataType dataType) { + this(null, sameDiff, input, dataType); + } + public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) { super(name, sameDiff, new SDVariable[]{input}, false); this.outputType = dataType; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java index cd78e5d12..cfd0bd7ed 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java @@ -55,6 +55,11 @@ public class Permute extends Transpose { addIArgument(permuteDims); } + public Permute(INDArray input, int... permuteDims){ + addInputArgument(input); + addIArgument(permuteDims); + } + public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){ super(sd, input, permuteDims); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java index 8201e6075..5f1448d06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java @@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.*; @@ -39,10 +40,18 @@ public class Rank extends DynamicCustomOp { public Rank() { } + public Rank(SameDiff sameDiff, SDVariable input) { + this(sameDiff, input, false); + } + public Rank(SameDiff sameDiff, SDVariable input, boolean inPlace) { super(null, sameDiff, new SDVariable[] {input}, inPlace); } + public Rank(INDArray indArray) { + addInputArgument(indArray); + } + @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 44d9b79fe..ddf0224db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -59,6 +59,10 @@ public class Reshape extends DynamicCustomOp { super(null, new INDArray[]{in, shape}, new INDArray[]{out}, null, (List)null); } + public Reshape(INDArray in, INDArray shape) { + addInputArgument(in, shape); + } + public Reshape() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index cc5b28bba..a2ca91c65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -69,7 +69,13 @@ public class SequenceMask extends DynamicCustomOp { addIArgument(maxLen); this.dataType = dataType; addDArgument(dataType); - } + } + + public SequenceMask(INDArray input, DataType dataType) { + addInputArgument(input); + this.dataType = dataType; + addDArgument(dataType); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java index 6cd2eec06..62bc5714e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java @@ -48,6 +48,10 @@ public class Shape extends DynamicCustomOp { public Shape() {} + public Shape(SameDiff sameDiff, SDVariable input) { + this(sameDiff, input, false); + } + public Shape(SameDiff sameDiff, SDVariable input, boolean inPlace) { super(null, sameDiff, new SDVariable[] {input}, inPlace); } @@ -56,6 +60,10 @@ public class Shape extends DynamicCustomOp { super(null, in, out, null, null); } + public Shape(INDArray in){ + this(in, null); + } + @Override public String onnxName() { throw new NoOpNameFoundException("No onnx name found for shape " + opName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java index 27989e878..ce3ce9cae 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.tensorflow.framework.AttrValue; @@ -47,6 +48,11 @@ public class Size extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {input}, false); } + public Size(INDArray in) { + addInputArgument(in); + } + + @Override public String onnxName() { throw new NoOpNameFoundException("No onnx name found for shape " + opName()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java index 379e6515e..c5f7cdd70 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java @@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.*; @@ -52,6 +53,11 @@ public class Slice extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{input, begin, end}); } + public Slice(INDArray in, int[] begin, int... size) { + addInputArgument(in); + addIArgument(begin); + addIArgument(size); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java index 1ffd0820b..2734d68b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -36,12 +37,21 @@ public class Squeeze extends DynamicCustomOp { public Squeeze() { } + public Squeeze(SameDiff sameDiff, SDVariable arg, int squeezeDims) { + this(sameDiff, arg, new int[] {squeezeDims}); + } + public Squeeze(SameDiff sameDiff, SDVariable arg, int[] squeezeDims) { super(null, sameDiff, new SDVariable[]{arg}); this.squeezeDims = squeezeDims; addIArgument(squeezeDims); } + public Squeeze(INDArray x, int axis) { + addInputArgument(x); + addIArgument(axis); + } + @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { nodeDef.getAttrMap().get("squeeze_dims"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index d2bf9d71b..89c459be3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -50,6 +50,16 @@ public class Stack extends DynamicCustomOp { addArgs(); } + public Stack(INDArray input, int axis) { + addInputArgument(input); + this.jaxis = axis; + addArgs(); + } + + public Stack(SameDiff sameDiff, SDVariable values, int axis) { + this(sameDiff, new SDVariable[]{values}, axis); + } + public Stack(SameDiff sameDiff, SDVariable[] values, int axis) { super(null, sameDiff, values, false); this.jaxis = axis; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java index 2208d3a36..a053403af 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java @@ -25,6 +25,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.util.ArrayUtil; @@ -95,6 +96,20 @@ public class StridedSlice extends DynamicCustomOp { } + public StridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, + int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { + addInputArgument(in); + this.begin = ArrayUtil.toLongArray(begin); + this.end = ArrayUtil.toLongArray(end); + this.strides = ArrayUtil.toLongArray(strides); + this.beginMask = beginMask; + this.endMask = endMask; + this.ellipsisMask = ellipsisMask; + this.newAxisMask = newAxisMask; + this.shrinkAxisMask = shrinkAxisMask; + addArguments(); + } + private void addArguments(){ addIArgument(beginMask); addIArgument(ellipsisMask); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java index e1fb02be9..c2e476f60 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java @@ -66,6 +66,14 @@ public class Tile extends DynamicCustomOp { this(inputs,outputs,axis,false); } + public Tile(INDArray x, INDArray repeat) { + addInputArgument(x, repeat); + } + + public Tile(INDArray x, int... repeat) { + addInputArgument(x); + addIArgument(repeat); + } public Tile() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 9ab0ad58c..95215b686 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -60,6 +60,10 @@ public class Transpose extends DynamicCustomOp { super(null, new INDArray[]{input}, result == null ? null : new INDArray[]{result}, null, (List) null); } + public Transpose(INDArray input) { + addInputArgument(input); + } + public Transpose() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java index 7225ac355..d71200e4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ZerosLike.java @@ -45,6 +45,10 @@ public class ZerosLike extends DynamicCustomOp { protected DataType outputType; //Allow customizing dtype for TF import + public ZerosLike(SameDiff sameDiff, SDVariable input) { + this(null, sameDiff, input, false, input.dataType()); + } + public ZerosLike(String name, SameDiff sameDiff, SDVariable input) { this(name, sameDiff, input, false, input.dataType()); } @@ -66,6 +70,10 @@ public class ZerosLike extends DynamicCustomOp { this(in, out, in.dataType()); } + public ZerosLike(INDArray in){ + addInputArgument(in); + } + public ZerosLike(INDArray in, INDArray out, DataType dataType) { super(null, in, out, null, null); if (dataType != null) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index b44b11cf6..adc92549b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -67,6 +67,10 @@ public class Variance extends BaseReduceOp { this.biasCorrected = biasCorrected; } + public Variance(INDArray x, boolean biasCorrected, boolean keepDims, int... dimensions) { + this(x, null, biasCorrected, keepDims, dimensions); + } + public Variance(INDArray x, boolean biasCorrected, int... dimensions) { super(x); this.biasCorrected = biasCorrected; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java index 75a250049..dd45af037 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Cholesky.java @@ -16,11 +16,14 @@ package org.nd4j.linalg.api.ops.impl.transforms; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.ops.SDValidation; import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -34,8 +37,17 @@ import java.util.Map; * Cholesky op wrapper * @author raver119@gmail.com */ +@NoArgsConstructor public class Cholesky extends DynamicCustomOp { + public Cholesky(INDArray input) { + addInputArgument(input); + } + + public Cholesky(SameDiff sameDiff, SDVariable sdInput) { + super(sameDiff, new SDVariable[]{sdInput}); + } + @Override public String opName() { return "cholesky"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index b07f52ce1..8d0a9d0d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -54,6 +54,10 @@ public class Pad extends DynamicCustomOp { addTArgument(padValue); } + public Pad(SameDiff sd, SDVariable in, SDVariable padding, double padValue) { + this(sd, in, padding, Mode.CONSTANT, padValue); + } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){ super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out}); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java index 03256a81a..8df844943 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.bool; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,12 +31,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class IsFinite extends BaseTransformBoolOp { public IsFinite(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public IsFinite() {} + public IsFinite(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public IsFinite(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java index efefaa1d9..44cb362a4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.bool; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,12 +31,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class IsInf extends BaseTransformBoolOp { public IsInf(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public IsInf() {} + public IsInf(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public IsInf(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java index 206bc32b3..daf9b0ea3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.bool; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,12 +32,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class IsNaN extends BaseTransformBoolOp { public IsNaN(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public IsNaN() {} + public IsNaN(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public IsNaN(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java index 89a9ddc64..78b995669 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java @@ -53,6 +53,14 @@ public class BatchToSpace extends DynamicCustomOp { public BatchToSpace() { } + public BatchToSpace(SameDiff sameDiff, SDVariable x, int[] blocks, int[] croppingTop, int... croppingBottom) { + this(sameDiff, x, blocks, new int[][]{croppingTop, croppingBottom}, false); + } + + public BatchToSpace(SameDiff sameDiff, SDVariable x, int[] blocks, int[][] crops, boolean inPlace) { + this(sameDiff, new SDVariable[]{x}, blocks, crops, inPlace); + } + public BatchToSpace(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) { super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(crops))}, inPlace); @@ -63,15 +71,14 @@ public class BatchToSpace extends DynamicCustomOp { addIArgument(b); } - public BatchToSpace(INDArray x, int[] blocks, int[] croppingTop, int[] croppingBottom) { - super(null,x,null,null,null); + public BatchToSpace(INDArray x, int[] blocks, int[] croppingTop, int... croppingBottom) { + addInputArgument(x); + int[][] crops = new int[][]{croppingTop, croppingBottom}; this.blocks = blocks; - this.crops = new int[][]{croppingTop,croppingBottom}; + this.crops = crops; + for (val b : blocks) addIArgument(b); - - for (int e = 0; e < crops.length; e++) - addIArgument(crops[e][0], crops[e][1]); } @@ -94,7 +101,7 @@ public class BatchToSpace extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of batch to space is space to batch with same blocks and padding as crops SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops)); + return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops[0], crops[1])); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java index ef07c7cc6..6622b134e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java @@ -83,7 +83,7 @@ public class BatchToSpaceND extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of batch to space is space to batch with same blocks and padding as crops SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops)); + return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops[0], crops[1])); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java index 3874c040b..0be0b08ad 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java @@ -59,7 +59,7 @@ public class CumProd extends DynamicCustomOp { } public CumProd(INDArray in, INDArray result, boolean exclusive, boolean reverse, int... axis) { - super(null, new INDArray[]{in}, new INDArray[]{result}, null, (List)null); + super(null, new INDArray[]{in}, result != null ? new INDArray[]{result} : null, null, (List)null); this.exclusive = exclusive; this.reverse = reverse; this.jaxis = axis; @@ -69,6 +69,10 @@ public class CumProd extends DynamicCustomOp { addArgs(); } + public CumProd(INDArray in, boolean exclusive, boolean reverse, int... axis) { + this(in, null, exclusive, reverse, axis); + } + @Override public String opName() { return "cumprod"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java index 6720b5a75..c24693b01 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java @@ -64,13 +64,16 @@ public class CumSum extends DynamicCustomOp { } public CumSum(INDArray in, INDArray result, boolean exclusive, boolean reverse, int... axis) { - super(null, new INDArray[]{in}, new INDArray[]{result}, null, (List)null); + super(null, new INDArray[]{in}, wrapOrNull(result), null, (List)null); this.exclusive = exclusive; this.reverse = reverse; this.jaxis = axis; addArgs(); } + public CumSum(INDArray in, boolean exclusive, boolean reverse, int... axis) { + this(in, null, exclusive, reverse, axis); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java index c4de19cfa..3bc812596 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Dilation2D.java @@ -16,8 +16,6 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; -import lombok.NoArgsConstructor; -import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -42,7 +40,6 @@ import java.util.*; * * @author raver119@gmail.com */ -@NoArgsConstructor public class Dilation2D extends DynamicCustomOp { protected boolean isSameMode; @@ -52,11 +49,21 @@ public class Dilation2D extends DynamicCustomOp { // strides protected int s0, s1, s2, s3; + + public Dilation2D() { + } + + public Dilation2D(SameDiff sameDiff, SDVariable df, SDVariable weights, int[] strides, int[] rates, boolean isSameMode) { + this(sameDiff, new SDVariable[]{df, weights}, strides, rates, isSameMode, false); + } + public Dilation2D(SameDiff sameDiff, SDVariable[] inputAndWeights, int[] strides, int[] rates, boolean isSameMode, boolean inPlace ) { super(null, sameDiff, inputAndWeights, inPlace); - Preconditions.checkArgument(rates.length == 4, "Dilation rate length must be 4, got an array with length %s with values %s", rates.length, rates); - Preconditions.checkArgument(strides.length == 4, "Dilation strides length must be 4, got an array with length %s with values %s", strides.length, strides); + Preconditions.checkArgument(rates.length == 4, + "Dilation rate length must be 4, got an array with length %s with values %s", rates.length, rates); + Preconditions.checkArgument(strides.length == 4, + "Dilation strides length must be 4, got an array with length %s with values %s", strides.length, strides); r0 = rates[0]; r1 = rates[1]; @@ -69,18 +76,21 @@ public class Dilation2D extends DynamicCustomOp { this.isSameMode = isSameMode; addArgs(); + } public Dilation2D(INDArray[] inputArrays, INDArray[] outputs) { super(null, inputArrays, outputs); + } - public Dilation2D(@NonNull INDArray df, @NonNull INDArray weights, @NonNull int[] strides, @NonNull int[] rates, boolean isSameMode) { - super(null, new INDArray[]{df, weights},null); - Preconditions.checkArgument(rates.length == 4, "Dilation rate length must be 4, got an array with length %s with values %s", rates.length, rates); - Preconditions.checkArgument(strides.length == 4, "Dilation strides length must be 4, got an array with length %s with values %s", strides.length, strides); + public Dilation2D(INDArray df, INDArray weights, int[] strides, int[] rates, boolean isSameMode) { + addInputArgument(df, weights); - this.isSameMode = isSameMode; + if (rates.length < 4) + throw new IllegalArgumentException("Dilation rate length must be 4."); + if (strides.length < 4) + throw new IllegalArgumentException("Strides length must be 4."); r0 = rates[0]; r1 = rates[1]; @@ -90,10 +100,11 @@ public class Dilation2D extends DynamicCustomOp { s1 = strides[1]; s2 = strides[2]; s3 = strides[3]; + this.isSameMode = isSameMode; + addArgs(); } - protected void addArgs() { addIArgument(isSameMode ? 1 : 0, r0, r1, r2, r3, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index 0e5232896..b64581b49 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -53,6 +54,10 @@ public class DynamicPartition extends DynamicCustomOp { public DynamicPartition() { } + public DynamicPartition(SameDiff sameDiff, SDVariable input, SDVariable[] partitions, int numPartitions) { + this(sameDiff, input, partitions[0], numPartitions); + } + public DynamicPartition(SameDiff sameDiff, SDVariable input, SDVariable partitions, int numPartitions) { super(null, sameDiff, new SDVariable[] {input, partitions}, false); @@ -61,6 +66,14 @@ public class DynamicPartition extends DynamicCustomOp { addArgs(); } + public DynamicPartition(INDArray input, INDArray[] partitions, int numPartitions) { + addInputArgument(input); + for (INDArray part : partitions) + addInputArgument(part); + + addIArgument(numPartitions); + } + @Override public List doDiff(List i_v) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java index 72aebe1e2..94c34d108 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -60,6 +61,16 @@ public class DynamicStitch extends DynamicCustomOp { this.numPartitions = inputs.length; } + public DynamicStitch(INDArray[] inputs, INDArray[] indices) { + for (INDArray input : inputs) { + addInputArgument(input); + } + + for (INDArray index : indices) { + addInputArgument(index); + } + } + @Override public List doDiff(List i_v) { // DynamicPartition and DynamicStitch are mutually inverse diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java index c4d7b2469..0d1214c9a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/EqualTo.java @@ -36,6 +36,10 @@ import java.util.List; public class EqualTo extends BaseDynamicTransformOp { public EqualTo() {} + public EqualTo( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public EqualTo( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -44,6 +48,10 @@ public class EqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } + public EqualTo( INDArray x, INDArray y) { + addInputArgument(x, y); + } + public EqualTo(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java index a5ffbced5..73f221f35 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java @@ -63,6 +63,11 @@ public class Fill extends DynamicCustomOp { this.value = value; } + public Fill(INDArray shape, DataType dataType, double value) { + super(null, shape, null, Collections.singletonList(value), null); + this.value = value; + } + public Fill(INDArray shape, INDArray value, INDArray result) { super(null, new INDArray[]{shape, value}, new INDArray[]{result}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java index 4c7fce72c..6a1ecc2cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThan.java @@ -37,6 +37,10 @@ import java.util.List; public class GreaterThan extends BaseDynamicTransformOp { public GreaterThan() {} + public GreaterThan( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y},false); + } + public GreaterThan( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -45,6 +49,10 @@ public class GreaterThan extends BaseDynamicTransformOp { super(inputs, outputs); } + public GreaterThan( INDArray x, INDArray y) { + addInputArgument(x,y); + } + public GreaterThan(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java index 6326870ec..dfb7fe8dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/GreaterThanOrEqual.java @@ -36,6 +36,10 @@ import java.util.List; public class GreaterThanOrEqual extends BaseDynamicTransformOp { public GreaterThanOrEqual() {} + public GreaterThanOrEqual( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public GreaterThanOrEqual( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -48,6 +52,11 @@ public class GreaterThanOrEqual extends BaseDynamicTransformOp { this(new INDArray[]{x, y}, new INDArray[]{z}); } + public GreaterThanOrEqual(INDArray x, INDArray y) { + + this(new INDArray[]{x,y}, null); + } + @Override public int opNum() { return 11; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java index 387f484f2..6048c9dff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import java.util.Arrays; @@ -35,10 +36,18 @@ import java.util.List; @NoArgsConstructor public class InvertPermutation extends BaseDynamicTransformOp { + public InvertPermutation(SameDiff sameDiff, SDVariable input) { + this(sameDiff, input, false); + } + public InvertPermutation(SameDiff sameDiff, SDVariable input, boolean inPlace) { super( sameDiff, new SDVariable[] {input}, inPlace); } + public InvertPermutation(INDArray input) { + addInputArgument(input); + } + @Override public String opName() { return "invert_permutation"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java index 96ad104af..95640fead 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNonDecreasing.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -32,13 +33,21 @@ import java.util.List; * and returns true if for every adjacent pair we have x[i] <= x[i+1]. * */ +@NoArgsConstructor public class IsNonDecreasing extends DynamicCustomOp { - public IsNonDecreasing() {} public IsNonDecreasing( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); } + public IsNonDecreasing( SameDiff sameDiff, SDVariable[] args) { + super(null, sameDiff, args, false); + } + + public IsNonDecreasing( SameDiff sameDiff, SDVariable input) { + super(null, sameDiff, new SDVariable[]{input}, false); + } + public IsNonDecreasing(@NonNull INDArray input){ this(input, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java index f25372b58..88c0a84ba 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsNumericTensor.java @@ -33,6 +33,10 @@ import java.util.List; public class IsNumericTensor extends DynamicCustomOp { public IsNumericTensor() {} + public IsNumericTensor( SameDiff sameDiff, SDVariable args) { + this(sameDiff, new SDVariable[]{args}, false); + } + public IsNumericTensor( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(null, sameDiff, args, inPlace); } @@ -41,6 +45,9 @@ public class IsNumericTensor extends DynamicCustomOp { super(null, inputs, outputs); } + public IsNumericTensor(INDArray input) { + addInputArgument(input); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java index 55b866cad..f6701c4ca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/IsStrictlyIncreasing.java @@ -39,6 +39,10 @@ public class IsStrictlyIncreasing extends DynamicCustomOp { super(null, sameDiff, args, inPlace); } + public IsStrictlyIncreasing( SameDiff sameDiff, SDVariable input) { + super(null, sameDiff, new SDVariable[]{input}); + } + public IsStrictlyIncreasing(@NonNull INDArray input){ this(input, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java index 61fbe2bee..b1a38e0ff 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThan.java @@ -37,6 +37,10 @@ import java.util.List; public class LessThan extends BaseDynamicTransformOp { public LessThan() {} + public LessThan( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public LessThan( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -45,6 +49,10 @@ public class LessThan extends BaseDynamicTransformOp { super(inputs, outputs); } + public LessThan( INDArray x, INDArray y) { + addInputArgument(x,y); + } + public LessThan(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java index 9f471f8dc..0ca6bf7e6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LessThanOrEqual.java @@ -36,6 +36,10 @@ import java.util.List; public class LessThanOrEqual extends BaseDynamicTransformOp { public LessThanOrEqual() {} + public LessThanOrEqual( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public LessThanOrEqual( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -44,6 +48,10 @@ public class LessThanOrEqual extends BaseDynamicTransformOp { super(inputs, outputs); } + public LessThanOrEqual( INDArray x, INDArray y) { + addInputArgument(x,y); + } + public LessThanOrEqual(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java index 7fd707507..37fe652d8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixDeterminant.java @@ -48,6 +48,10 @@ public class MatrixDeterminant extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{in}, inPlace); } + public MatrixDeterminant(SameDiff sameDiff, SDVariable in) { + this(sameDiff, in, false); + } + @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java index 4ff0f942b..475f3c6a8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixInverse.java @@ -46,6 +46,9 @@ public class MatrixInverse extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{in}, inPlace); } + public MatrixInverse(SameDiff sameDiff, SDVariable in) { + this(sameDiff, in, false); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java index 9bbf6c50f..19d139cbb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java @@ -34,6 +34,10 @@ public class MatrixSetDiag extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{in, diag}, inPlace); } + public MatrixSetDiag(SameDiff sameDiff, SDVariable in, SDVariable diag) { + this(sameDiff, in, diag, false); + } + public MatrixSetDiag(@NonNull INDArray in, @NonNull INDArray diag){ super(new INDArray[]{in, diag}, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java index 6c877f96d..e8653d4c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java @@ -52,6 +52,10 @@ public class Max extends BaseDynamicTransformOp { super(inputs, outputs); } + public Max( INDArray x, INDArray y) { + addInputArgument(x,y); + } + @Override public String opName() { return "maximum"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java index 73bfbacc7..c195178c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Min.java @@ -52,6 +52,9 @@ public class Min extends BaseDynamicTransformOp { super(inputs, outputs); } + public Min( INDArray x, INDArray y) { + addInputArgument(x,y); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java index c2c245979..69d724a7e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/NotEqualTo.java @@ -36,6 +36,10 @@ import java.util.List; public class NotEqualTo extends BaseDynamicTransformOp { public NotEqualTo() {} + public NotEqualTo( SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, new SDVariable[]{x,y}, false); + } + public NotEqualTo( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } @@ -44,6 +48,10 @@ public class NotEqualTo extends BaseDynamicTransformOp { super(inputs, outputs); } + public NotEqualTo( INDArray x, INDArray y) { + addInputArgument(x,y); + } + public NotEqualTo(INDArray x, INDArray y, INDArray z){ this(new INDArray[]{x, y}, new INDArray[]{z}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Qr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Qr.java new file mode 100644 index 000000000..409b0bd8e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Qr.java @@ -0,0 +1,56 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.List; + +@NoArgsConstructor +public class Qr extends DynamicCustomOp { + + public Qr(INDArray input) { + this(input, false); + } + public Qr(INDArray input, boolean fullMatrices) { + addInputArgument(input); + addBArgument(fullMatrices); + } + + public Qr(SameDiff sameDiff, SDVariable input, boolean fullMatrices) { + super(sameDiff, new SDVariable[]{input}); + addBArgument(fullMatrices); + } + + @Override + public String opName() { + return "qr"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java index 3e7276def..11897fef8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java @@ -23,6 +23,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -66,6 +67,11 @@ public class ReverseSequence extends DynamicCustomOp { public ReverseSequence() { } + public ReverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) { + addInputArgument(x, seq_lengths); + addIArgument(seqDim, batchDim); + } + @Override public String opName() { return "reverse_sequence"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java index 712df46fc..24c2353c1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java @@ -50,6 +50,10 @@ public class SoftMax extends BaseDynamicTransformOp { super(sameDiff, args, false); } + public SoftMax(SameDiff sameDiff, SDVariable x, int dimension) { + this(sameDiff, new SDVariable[]{x}, dimension); + } + public SoftMax(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java index 1ce8a7889..c7a8c0cda 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java @@ -54,6 +54,10 @@ public class SpaceToBatch extends DynamicCustomOp { public SpaceToBatch() { } + public SpaceToBatch(SameDiff sameDiff, SDVariable x, int[] blocks, int[] paddingTop, int... paddingBottom) { + this(sameDiff, new SDVariable[]{x}, blocks, new int[][]{paddingBottom, paddingBottom}, false); + } + public SpaceToBatch(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) { super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(padding))}, inPlace); @@ -63,19 +67,14 @@ public class SpaceToBatch extends DynamicCustomOp { addIArgument(blocks[0]); } - public SpaceToBatch(INDArray x, int[] blocks, int[] paddingTop, int[] paddingBottom) { - super(null,x,null,null,null); + public SpaceToBatch(INDArray x, int[] blocks, int[] paddingTop, int... paddingBottom) { + addInputArgument(x); this.blocks = blocks; - this.padding = new int[][]{paddingTop,paddingBottom}; + this.padding = padding; - for (val b : blocks) - addIArgument(b); - - for (int e = 0; e < padding.length; e++) - addIArgument(padding[e][0], padding[e][1]); + addIArgument(blocks[0]); } - @Override public String opName() { return "space_to_batch"; @@ -95,7 +94,7 @@ public class SpaceToBatch extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of space to batch is batch to space with same blocks and crops as padding SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding)); + return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding[0], padding[1])); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java index 9eb72e54f..12009d955 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java @@ -84,7 +84,7 @@ public class SpaceToBatchND extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of space to batch is batch to space with same blocks and crops as padding SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding)); + return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding[0], padding[1])); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java index c66285dc2..60de8a665 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Svd.java @@ -71,6 +71,12 @@ public class Svd extends DynamicCustomOp { addIArgument(ArrayUtil.fromBoolean(fullUV), ArrayUtil.fromBoolean(computeUv), switchNum); } + public Svd(INDArray input, boolean fullUV, boolean computeUV, int switchNum) { + addInputArgument(input); + addBArgument(fullUV, computeUV); + addIArgument(switchNum); + } + @Override public String opName(){ return "svd"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java index 06ebbb5ef..24d79f234 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java @@ -53,7 +53,7 @@ public class Trace extends DynamicCustomOp { public List doDiff(List gradAtOutput){ SDVariable rows = f().reshape(f().sizeAt(arg(), -2), new long[]{1}); SDVariable cols = f().reshape(f().sizeAt(arg(), -1), new long[]{1}); - SDVariable eye = sameDiff.math().eye(f().shape(gradAtOutput.get(0)), rows, cols); + SDVariable eye = sameDiff.math().eye(/*f().shape(gradAtOutput.get(0)),*/ rows, cols); //Reshape gradient from [x,y,z] to [x,y,z,1,1] SDVariable reshapedGrad = f().expandDims(gradAtOutput.get(0), -1); reshapedGrad = f().expandDims(reshapedGrad, -1); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java index b0c007a29..5b6cd2517 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentMax extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } + public SegmentMax(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentMax(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java index 0b881ecbd..d0a9a6784 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentMean extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } + public SegmentMean(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentMean(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java index 7417ccb1d..2bc369f2a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentMin extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } + public SegmentMin(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentMin(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java index 4345b27ec..3be3625e7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentProd extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } + public SegmentProd(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentProd(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java index 236a74041..5de847162 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -37,6 +38,10 @@ public class SegmentSum extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {data, segmentIds}, false); } + public SegmentSum(INDArray data, INDArray segmentIds) { + addInputArgument(data, segmentIds); + } + public SegmentSum(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java index 87e5281ea..df5cdbcc7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.floating; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,13 +32,17 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class RSqrt extends BaseTransformFloatOp { + + public RSqrt(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public RSqrt(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public RSqrt() {} - public RSqrt(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Sqrt.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Sqrt.java index 454b27342..34d74beb8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Sqrt.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Sqrt.java @@ -36,6 +36,10 @@ public class Sqrt extends BaseTransformFloatOp { super(sameDiff, i_v, inPlace); } + public Sqrt(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sqrt() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java index e5322c02f..2ca198506 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -36,12 +37,15 @@ import java.util.List; * @author Adam Gibson */ @Deprecated +@NoArgsConstructor public class HardTanhDerivative extends BaseTransformStrictOp { public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public HardTanhDerivative() {} + public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public HardTanhDerivative(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java index 202f7e291..259180f5c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUDerivative.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,6 +32,7 @@ import java.util.List; /**Leaky ReLU derivative. Default alpha = 0.01. Cutoff = 0 */ +@NoArgsConstructor public class LeakyReLUDerivative extends BaseScalarOp { private double alpha = 0.01; @@ -40,14 +42,16 @@ public class LeakyReLUDerivative extends BaseScalarOp { this.extraArgs = new Object[] {alpha}; } + public LeakyReLUDerivative(SameDiff sameDiff, SDVariable i_v, double alpha) { + this(sameDiff, i_v, false, alpha); + } + public LeakyReLUDerivative(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double alpha) { super(sameDiff, i_v, alpha, extraArgs); this.alpha = alpha; this.extraArgs = new Object[] {alpha}; } - public LeakyReLUDerivative() {} - public LeakyReLUDerivative(INDArray x, INDArray z) { this(x, z, 0.01); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java index 4ae26e585..cd189c82d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -33,12 +34,14 @@ import java.util.List; * @deprecated Use {@link SoftSignBp} */ @Deprecated +@NoArgsConstructor public class SoftSignDerivative extends BaseTransformStrictOp { public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public SoftSignDerivative() { + public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public SoftSignDerivative(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java index 07bea9ae7..fc89333f4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -33,14 +34,17 @@ import java.util.List; * * @author Max Pumperla */ +@NoArgsConstructor public class MergeAddOp extends BaseDynamicTransformOp { - public MergeAddOp() {} - public MergeAddOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { super(sameDiff, args, inPlace); } + public MergeAddOp(SameDiff sameDiff, SDVariable[] args) { + this(sameDiff, args, false); + } + public MergeAddOp(@NonNull INDArray... inputs){ this(inputs, null); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java index 2a8cf1111..4c6bf0ad9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -32,13 +33,17 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Abs extends BaseTransformSameOp { + + public Abs(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Abs(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Abs() { - } public Abs(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java index d58ad8f3f..6422e8df8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java @@ -17,6 +17,8 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; import java.util.Collections; + +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,13 +34,17 @@ import java.util.List; * * @author Paul Dubs */ +@NoArgsConstructor public class Cube extends BaseTransformSameOp { + + public Cube(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Cube(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Cube() {} - public Cube(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java index 842c78929..ba6dd7171 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -30,12 +31,14 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Floor extends BaseTransformSameOp { public Floor(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Floor() { + public Floor(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public Floor(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java index bf3d28d71..dcee02131 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java @@ -41,6 +41,10 @@ public class Identity extends BaseDynamicTransformOp { super(new INDArray[]{x}, new INDArray[]{z}); } + public Identity(INDArray x){ + addInputArgument(x); + } + public Identity(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java index 9a1664cb6..37b370fe9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -31,12 +32,15 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Negative extends BaseTransformSameOp { public Negative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Negative() {} + public Negative(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public Negative(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java index 764aca29f..1e11fa34d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -30,13 +31,11 @@ import java.util.List; /** * Created by susaneraly on 3/28/18. */ +@NoArgsConstructor public class Reciprocal extends BaseTransformSameOp { - public Reciprocal(SameDiff sameDiff, SDVariable in, boolean inPlace) { - super(sameDiff, in, inPlace); - } - - public Reciprocal() { + public Reciprocal(SameDiff sameDiff, SDVariable in) { + super(sameDiff, in, false); } public Reciprocal(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java index 25f3120dc..375a8acb5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -31,13 +32,17 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Round extends BaseTransformSameOp { + + public Round(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Round(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Round() {} - public Round(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java index 8ab85ac18..58c5d9c20 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java @@ -36,6 +36,10 @@ public class Sign extends BaseTransformSameOp { super(sameDiff, i_v, inPlace); } + public Sign(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sign() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java index 9dbc77bac..c63e00114 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java @@ -35,6 +35,10 @@ public class Square extends BaseTransformSameOp { super(sameDiff, i_v, inPlace); } + public Square(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Square() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index 0e5426c3c..1506ac5f3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.*; @@ -39,6 +40,11 @@ public class UnsortedSegmentMax extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + public UnsortedSegmentMax(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index b0b7f4457..4338cf33d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -43,6 +44,11 @@ public class UnsortedSegmentMean extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_mean"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 5b7e1c7e0..2f8aab0b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -43,6 +44,11 @@ public class UnsortedSegmentMin extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_min"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index bca9e1788..7afd75fac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; @@ -43,6 +44,11 @@ public class UnsortedSegmentProd extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_prod"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java index b3a507435..336c756ac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -44,6 +45,11 @@ public class UnsortedSegmentSum extends DynamicCustomOp { addIArgument(numSegments); } + public UnsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) { + addInputArgument(data, segmentIds); + addIArgument(numSegments); + } + @Override public String opName(){ return "unsorted_segment_sum"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java index 21a7e5b38..3e0c60bb0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,12 +33,14 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class ACos extends BaseTransformStrictOp { - public ACos(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + public ACos(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } - public ACos() { + public ACos(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public ACos(INDArray x) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java index 2e51ea351..a8d9f12ad 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,11 +33,9 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class ACosh extends BaseTransformStrictOp { - public ACosh() { - } - public ACosh(INDArray x) { super(x); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java index 8716a8f7d..fc514a415 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -33,12 +34,15 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class ASin extends BaseTransformStrictOp { public ASin(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public ASin() { + public ASin(SameDiff sameDiff, SDVariable i_v) { + + this(sameDiff, i_v, false); } public ASin(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java index 458d9fad1..483896dfd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,12 +33,14 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class ATan extends BaseTransformStrictOp { public ATan(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public ATan() { + public ATan(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public ATan(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java index 3beb90343..21076ad6e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,15 +33,17 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Cos extends BaseTransformStrictOp { + public Cos(SameDiff sameDiff, SDVariable i_v){ + this(sameDiff, i_v, false); + } + public Cos(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Cos() { - } - public Cos(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java index b9ada31d0..dc08ead5f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,15 +33,17 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Cosh extends BaseTransformStrictOp { + public Cosh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Cosh(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Cosh() { - } - public Cosh(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erf.java index 3d49194ab..2769c95f8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erf.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -34,12 +35,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Erf extends BaseTransformStrictOp { - public Erf(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + + public Erf(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } - public Erf() { + public Erf(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public Erf(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erfc.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erfc.java index 857d87141..f31e71ee8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erfc.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Erfc.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -36,12 +37,15 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Erfc extends BaseTransformStrictOp { - public Erfc(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + + public Erfc(SameDiff sameDiff, SDVariable i_v){ + this(sameDiff, i_v, false); } - public Erfc() { + public Erfc(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public Erfc(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java index 05dc708a8..21aa49522 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,12 +32,15 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class Exp extends BaseTransformStrictOp { - public Exp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + + public Exp(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } - public Exp() { + public Exp(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public Exp(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java index 5aad7aebd..538f6a003 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -33,12 +34,14 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Expm1 extends BaseTransformStrictOp { - public Expm1(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + public Expm1(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } - public Expm1() { + public Expm1(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { + super(sameDiff, i_v, inPlace); } public Expm1(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java index b33ea8b8f..b784ddde0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -33,12 +34,18 @@ import java.util.List; * use precise=false; otherwise, use precise = true for the slower but marginally more accurate tanh version. * @author raver119@gmail.com */ +@NoArgsConstructor public class GELU extends BaseTransformStrictOp { public GELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace, boolean precise) { super(sameDiff, i_v, inPlace); } - public GELU() { + public GELU(SameDiff sameDiff, SDVariable i_v, boolean precise) { + this(sameDiff, i_v, false, precise); + } + + public GELU(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false, false); } public GELU(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java index ddca48d4c..ddaa8631f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,8 +33,8 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class HardSigmoid extends BaseTransformStrictOp { - public HardSigmoid() {} public HardSigmoid(INDArray x, INDArray z) { super(x, z); @@ -47,6 +48,10 @@ public class HardSigmoid extends BaseTransformStrictOp { super(sameDiff, in, inPlace); } + public HardSigmoid(SameDiff sameDiff, SDVariable in){ + this(sameDiff, in, false); + } + @Override public int opNum() { return 36; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java index 4237e72de..fa80bf880 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java @@ -17,6 +17,8 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; import java.util.Collections; + +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,12 +34,14 @@ import java.util.List; * * @author Adam Gibson */ +@NoArgsConstructor public class HardTanh extends BaseTransformStrictOp { public HardTanh(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public HardTanh() { + public HardTanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public HardTanh(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java index 0295b8e52..a937e1d63 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java @@ -36,6 +36,10 @@ public class Log extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Log(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Log() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java index 96892a9f0..131986d15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,11 +33,14 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Log1p extends BaseTransformStrictOp { public Log1p(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Log1p() {} + public Log1p(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } public Log1p(INDArray x, INDArray z) { super(x, z); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java index 0f4c7abcc..353ced004 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,12 +33,14 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class LogSigmoid extends BaseTransformStrictOp { public LogSigmoid(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public LogSigmoid() { + public LogSigmoid(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public LogSigmoid(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java index f72676f86..00592f0e2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java @@ -43,6 +43,10 @@ public class SELU extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public SELU(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public SELU() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java index 22d5b6302..37ef4b743 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java @@ -36,6 +36,10 @@ public class Sigmoid extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Sigmoid(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sigmoid() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java index 22357d386..0fa918c11 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java @@ -37,6 +37,10 @@ public class Sin extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Sin(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sin(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java index 84a8f522a..d5e3be988 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java @@ -37,6 +37,10 @@ public class Sinh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Sinh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Sinh(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java index abd1ce904..11ffb2ef8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java @@ -34,6 +34,10 @@ public class SoftPlus extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public SoftPlus(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public SoftPlus(INDArray x, INDArray z) { super(x, z); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java index c7c90b201..8be5ea2d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java @@ -40,6 +40,10 @@ public class SoftSign extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public SoftSign(SameDiff sameDiff, SDVariable i_v) { + super(sameDiff, i_v, false); + } + public SoftSign() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java index 029c7c5b4..0794e0b57 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -32,12 +33,14 @@ import java.util.List; * * @author raver119@gmail.com */ +@NoArgsConstructor public class Swish extends BaseTransformStrictOp { public Swish(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } - public Swish() { + public Swish(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); } public Swish(INDArray x, INDArray z) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java index 77954deec..3244925b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java @@ -38,6 +38,10 @@ public class Tan extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Tan(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public Tan() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java index 667bf6a93..136d0bbea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java @@ -36,6 +36,10 @@ public class Tanh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public Tanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v,false); + } + public Tanh() { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java index 752881c6e..0672f5d15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/BaseRandomOp.java @@ -38,6 +38,7 @@ import java.util.List; @NoArgsConstructor public abstract class BaseRandomOp extends BaseOp implements RandomOp { protected long[] shape; + protected DataType dataType = Nd4j.defaultFloatingPointType(); public BaseRandomOp(SameDiff sameDiff, SDVariable i_v) { Preconditions.checkNotNull(i_v, "Input variable can't be null with this constructor"); @@ -72,7 +73,7 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp { @Override public List calculateOutputShape(OpContext opContext) { if(shape != null){ - return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType())); + return Collections.singletonList(LongShapeDescriptor.fromShape(shape, dataType)); } else { return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType()))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java index 5b9faa005..dfedd6dcd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomExponential.java @@ -36,6 +36,7 @@ import java.util.List; @Slf4j public class RandomExponential extends DynamicCustomOp { private double lambda = 0.0; + private DataType dataType = DataType.DOUBLE; public RandomExponential() { // @@ -48,6 +49,15 @@ public class RandomExponential extends DynamicCustomOp { addTArgument(lambda); } + public RandomExponential(SameDiff sd, double lambda, DataType dataType, long... shape){ + super(null, sd, new SDVariable[]{sd.constant(Nd4j.createFromArray(shape))}); + this.lambda = lambda; + addTArgument(lambda); + this.dataType = dataType; + addDArgument(dataType); + addIArgument(shape); + } + public RandomExponential(double lambda, DataType datatype, long... shape){ this(Nd4j.createFromArray(shape), Nd4j.createUninitialized(datatype, shape), lambda); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java index 3f08d1619..ec4bf96a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.java @@ -44,6 +44,13 @@ public class BernoulliDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.prob}; } + public BernoulliDistribution(SameDiff sd, double prob, DataType dataType, long[] shape){ + this(sd, prob, shape); + this.prob = prob; + this.extraArgs = new Object[] {this.prob}; + super.dataType = dataType; + } + public BernoulliDistribution() { super(); } @@ -113,6 +120,6 @@ public class BernoulliDistribution extends BaseRandomOp { Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); //Input data type specifies the shape; output data type should be any float //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 - return Collections.singletonList(DataType.DOUBLE); + return Collections.singletonList(dataType); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java index b08f56be3..93e4e3c66 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java @@ -45,6 +45,10 @@ public class BinomialDistribution extends BaseRandomOp { this.extraArgs = new Object[] {(double) this.trials, this.probability}; } + public BinomialDistribution(SameDiff sd, int trials, double probability, DataType dataType, long[] shape){ + this(sd, trials, probability, shape); + } + public BinomialDistribution(int trials, double probability, DataType dt, long[] shape){ this(Nd4j.createUninitialized(dt, shape), trials, probability); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java index 1081e141b..1aa031ec0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/GaussianDistribution.java @@ -36,7 +36,7 @@ import java.util.List; */ public class GaussianDistribution extends BaseRandomOp { private double mean; - private double stddev; + private double stddev; public GaussianDistribution(SameDiff sd, double mean, double stddev, long[] shape){ super(sd, shape); @@ -45,6 +45,14 @@ public class GaussianDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public GaussianDistribution(SameDiff sd, double mean, double stddev, DataType dataType, long[] shape){ + super(sd, shape); + this.mean = mean; + this.stddev = stddev; + this.dataType = dataType; + this.extraArgs = new Object[] {this.mean, this.stddev}; + } + public GaussianDistribution() { super(); } @@ -134,9 +142,7 @@ public class GaussianDistribution extends BaseRandomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); - //Input data type specifies the shape; output data type should be any float - //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 - return Collections.singletonList(DataType.DOUBLE); + return Collections.singletonList(dataType); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java index c007d4e92..44545f8ab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/LogNormalDistribution.java @@ -36,7 +36,7 @@ import java.util.List; */ public class LogNormalDistribution extends BaseRandomOp { private double mean; - private double stddev; + private double stddev; public LogNormalDistribution() { super(); @@ -49,6 +49,11 @@ public class LogNormalDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public LogNormalDistribution(SameDiff sd, double mean, double stdev, DataType dataType, long... shape){ + this(sd, mean, stdev,shape); + this.dataType = dataType; + } + public LogNormalDistribution(double mean, double stddev, DataType datatype, long... shape){ this(Nd4j.createUninitialized(datatype, shape), mean, stddev); } @@ -131,9 +136,7 @@ public class LogNormalDistribution extends BaseRandomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); - //Input data type specifies the shape; output data type should be any float - //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 - return Collections.singletonList(DataType.DOUBLE); + return Collections.singletonList(dataType); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java index ba09a2d29..a95169d78 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/TruncatedNormalDistribution.java @@ -49,6 +49,13 @@ public class TruncatedNormalDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.mean, this.stddev}; } + public TruncatedNormalDistribution(SameDiff sd, double mean, double stddev, DataType dataType, long[] shape) { + super(sd, shape); + this.mean = mean; + this.stddev = stddev; + this.extraArgs = new Object[] {this.mean, this.stddev}; + } + public TruncatedNormalDistribution(double mean, double stddev, DataType datatype, long... shape){ this(Nd4j.createUninitialized(datatype, shape), mean, stddev); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java index 408af9ce2..e1b40a382 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/UniformDistribution.java @@ -47,6 +47,11 @@ public class UniformDistribution extends BaseRandomOp { this.extraArgs = new Object[] {this.from, this.to}; } + public UniformDistribution(SameDiff sd, double from, double to, DataType dataType, long[] shape) { + this(sd, from, to, shape); + this.dataType = dataType; + } + public UniformDistribution(double min, double max, DataType datatype, long... shape){ this(Nd4j.createUninitialized(datatype, shape), min, max); } @@ -111,6 +116,6 @@ public class UniformDistribution extends BaseRandomOp { Preconditions.checkState(inputDataTypes == null || inputDataTypes.isEmpty(), "Expected no input datatypes (no args) for %s, got %s", getClass(), inputDataTypes); //Input data type specifies the shape; output data type should be any float //TODO MAKE CONFIGUREABLE - https://github.com/deeplearning4j/deeplearning4j/issues/6854 - return Collections.singletonList(DataType.DOUBLE); + return Collections.singletonList(dataType); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java index f60726c36..4986b8277 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDValidation.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.factory; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -129,6 +130,16 @@ public class NDValidation { " type; got array with non-integer data type " + v.dataType()); } + public static void validateInteger(String opName, String inputName, INDArray[] vars) { + for (INDArray v : vars) { + if (v == null) + return; + if (!v.dataType().isIntType()) + throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer" + + " type; got array with non-integer data type " + v.dataType()); + } + } + /** * Validate that the operation is being applied on an floating point type INDArray * @@ -233,4 +244,15 @@ public class NDValidation { public static boolean isSameType(INDArray x, INDArray y) { return x.dataType() == y.dataType(); } + + public static boolean isSameType(INDArray[] x) { + DataType firstDataType = x[0].dataType(); + if (x.length > 1) { + for (int i = 1; i < x.length; ++i) { + if (firstDataType != x[i].dataType()) + return false; + } + } + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java new file mode 100644 index 000000000..cfaf00d18 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -0,0 +1,2056 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.conditions.Condition; + +public class NDBase { + public NDBase() { + } + + /** + * Boolean and array reduction operation, optionally along specified dimensions
+ * + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public INDArray all(INDArray x, int... dimensions) { + NDValidation.validateBool("all", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.All(x, dimensions)); + } + + /** + * Boolean or array reduction operation, optionally along specified dimensions
+ * + * @param x Input variable (BOOL type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (BOOL type) + */ + public INDArray any(INDArray x, int... dimensions) { + NDValidation.validateBool("any", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(x, dimensions)); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray argmax(INDArray in, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, keepDims, dimensions)); + } + + /** + * Argmax array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the maximum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or + * of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray argmax(INDArray in, int... dimensions) { + NDValidation.validateNumerical("argmax", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, false, dimensions)); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray argmin(INDArray in, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, keepDims, dimensions)); + } + + /** + * Argmin array reduction operation, optionally along specified dimensions.
+ * Output values are the index of the minimum value of each slice along the specified dimension.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray argmin(INDArray in, int... dimensions) { + NDValidation.validateNumerical("argmin", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions)); + } + + /** + * Concatenate a set of inputs along the specified dimension.
+ * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
+ * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c]
+ * + * Inputs must satisfy the following constraints:
+ * Input arrays must all be the same datatype: isSameType(inputs)
+ * + * @param inputs Input variables (NUMERIC type) + * @param dimension Dimension to concatenate on + * @return output (NUMERIC type) + */ + public INDArray concat(INDArray[] inputs, int dimension) { + NDValidation.validateNumerical("concat", "inputs", inputs); + Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Concat(inputs, dimension))[0]; + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false, [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray cumprod(INDArray in, boolean exclusive, boolean reverse, int... axis) { + NDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, axis))[0]; + } + + /** + * Cumulative product operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a*b, a*b*c]
+ * exclusive=true, reverse=false, [0, a, a*b]
+ * exclusive=false, reverse=true: [a*b*c, b*c, c]
+ * exclusive=true, reverse=true: [b*c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray cumprod(INDArray in, int... axis) { + NDValidation.validateNumerical("cumprod", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis))[0]; + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false, [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param exclusive If true: exclude the first value + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public INDArray cumsum(INDArray in, boolean exclusive, boolean reverse, int... axis) { + NDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis))[0]; + } + + /** + * Cumulative sum operation.
+ * For input: [ a, b, c], output is:
+ * exclusive=false, reverse=false: [a, a+b, a+b+c]
+ * exclusive=true, reverse=false, [0, a, a+b]
+ * exclusive=false, reverse=true: [a+b+c, b+c, c]
+ * exclusive=true, reverse=true: [b+c, c, 0]
+ * + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @return output (NUMERIC type) + */ + public INDArray cumsum(INDArray in, int... axis) { + NDValidation.validateNumerical("cumsum", "in", in); + Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis))[0]; + } + + /** + * Pairwise dot product reduction along dimension
+ * output = sum(i=0 ... size(dim)-1) x[i] * y[i]
+ * + * @param x first input (NUMERIC type) + * @param y second input (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output output variable (NUMERIC type) + */ + public INDArray dot(INDArray x, INDArray y, int... dimensions) { + NDValidation.validateNumerical("dot", "x", x); + NDValidation.validateNumerical("dot", "y", y); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.Dot(x, y, dimensions)); + } + + /** + * Dynamically partition the input variable values into the specified number of paritions, using the indices.
+ * Example:
+ *


+ * input = [1,2,3,4,5]
+ * numPartitions = 2
+ * partitions = [1,0,0,1,0]
+ * out[0] = [2,3,5]
+ * out[1] = [1,4] }
+ *

+ * + * @param x Input variable (NUMERIC type) + * @param partitions 1D input with values 0 to numPartitions-1 (INT type) + * @param numPartitions Number of partitions, >= 1 + * @return output Output variables (equal in number to numPartitions) (NUMERIC type) + */ + public INDArray dynamicPartition(INDArray x, INDArray[] partitions, int numPartitions) { + NDValidation.validateNumerical("dynamicPartition", "x", x); + NDValidation.validateInteger("dynamicPartition", "partitions", partitions); + Preconditions.checkArgument(partitions.length >= 1, "partitions has incorrect size/length. Expected: partitions.length >= 1, got %s", partitions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions))[0]; + } + + /** + * Dynamically merge the specified input arrays into a single array, using the specified indices
+ * + * @param x Input variables. (NUMERIC type) + * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @return output Merged output variable (NUMERIC type) + */ + public INDArray dynamicStitch(INDArray[] x, INDArray[] indices) { + NDValidation.validateNumerical("dynamicStitch", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + NDValidation.validateInteger("dynamicStitch", "indices", indices); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(x, indices))[0]; + } + + /** + * Equals operation: elementwise x == y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray eq(INDArray x, double y) { + NDValidation.validateNumerical("eq", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals(x, y)); + } + + /** + * Equal to operation: elementwise x == y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray eq(INDArray x, INDArray y) { + NDValidation.validateNumerical("eq", "x", x); + NDValidation.validateNumerical("eq", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo(x, y))[0]; + } + + /** + * Reshape the input by adding a 1 at the specified location.
+ * For example, if input has shape [a, b], then output shape is:
+ * axis = 0: [1, a, b]
+ * axis = 1: [a, 1, b]
+ * axis = 2: [a, b, 1]
+ * + * @param x Input variable (NDARRAY type) + * @param axis Axis to expand + * @return output Output variable (NUMERIC type) + */ + public INDArray expandDims(INDArray x, int axis) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ExpandDims(x, axis))[0]; + } + + /** + * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value
+ * + * @param shape Shape: must be a 1D array/variable (INT type) + * @param dataType Datatype of the output array + * @param value Value to set all elements to + * @return output Output variable (NUMERIC type) + */ + public INDArray fill(INDArray shape, DataType dataType, double value) { + NDValidation.validateInteger("fill", "shape", shape); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value))[0]; + } + + /** + * Gather slices from the input variable where the indices are specified as fixed int[] values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param df Input variable (NUMERIC type) + * @param indices Indices to get (Size: AtLeast(min=1)) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public INDArray gather(INDArray df, int[] indices, int axis) { + NDValidation.validateNumerical("gather", "df", df); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0]; + } + + /** + * Gather slices from the input variable where the indices are specified as dynamic array values.
+ * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
+ * + * @param df Input variable (NUMERIC type) + * @param indices Indices to get slices for. Rank 0 or 1 input (INT type) + * @param axis Axis that the indices refer to + * @return output Output variable with slices pulled from the specified axis (NUMERIC type) + */ + public INDArray gather(INDArray df, INDArray indices, int axis) { + NDValidation.validateNumerical("gather", "df", df); + NDValidation.validateInteger("gather", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0]; + } + + /** + * Gather slices from df with shape specified by indices.
+ * + * @param df (NUMERIC type) + * @param indices (NUMERIC type) + * @return output (NUMERIC type) + */ + public INDArray gatherNd(INDArray[] df, INDArray[] indices) { + NDValidation.validateNumerical("gatherNd", "df", df); + Preconditions.checkArgument(df.length >= 1, "df has incorrect size/length. Expected: df.length >= 1, got %s", df.length); + NDValidation.validateNumerical("gatherNd", "indices", indices); + Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.GatherNd(df, indices))[0]; + } + + /** + * Greater than operation: elementwise x > y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray gt(INDArray x, double y) { + NDValidation.validateNumerical("gt", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan(x, y)); + } + + /** + * Greater than operation: elementwise x > y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray gt(INDArray x, INDArray y) { + NDValidation.validateNumerical("gt", "x", x); + NDValidation.validateNumerical("gt", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan(x, y))[0]; + } + + /** + * Greater than or equals operation: elementwise x >= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray gte(INDArray x, double y) { + NDValidation.validateNumerical("gte", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(x, y)); + } + + /** + * Greater than or equal to operation: elementwise x >= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output (NUMERIC type) + */ + public INDArray gte(INDArray x, INDArray y) { + NDValidation.validateNumerical("gte", "x", x); + NDValidation.validateNumerical("gte", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y))[0]; + } + + /** + * Elementwise identity operation: out = x
+ * + * @param input Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray identity(INDArray input) { + NDValidation.validateNumerical("identity", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Identity(input))[0]; + } + + /** + * Compute the inverse permutation indices for a permutation operation
+ * Example: if input is [2, 0, 1] then output is [1, 2, 0]
+ * The idea is that x.permute(input).permute(invertPermutation(input)) == x
+ * + * @param input 1D indices for permutation (INT type) + * @return output 1D inverted permutation (INT type) + */ + public INDArray invertPermutation(INDArray input) { + NDValidation.validateInteger("invertPermutation", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input))[0]; + } + + /** + * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1
+ * + * @param x Input variable (NUMERIC type) + * @return output scalar boolean with value true or false (NDARRAY type) + */ + public INDArray isNumericTensor(INDArray x) { + NDValidation.validateNumerical("isNumericTensor", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor(x))[0]; + } + + /** + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
+ * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
+ * + * @param dataType Data type of the output array + * @param start Start value + * @param stop Stop value + * @param number Number of values to generate + * @return output INDArray with linearly spaced elements (NUMERIC type) + */ + public INDArray linspace(DataType dataType, double start, double stop, long number) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; + } + + /** + * Less than operation: elementwise x < y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray lt(INDArray x, double y) { + NDValidation.validateNumerical("lt", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan(x, y)); + } + + /** + * Less than operation: elementwise x < y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray lt(INDArray x, INDArray y) { + NDValidation.validateNumerical("lt", "x", x); + NDValidation.validateNumerical("lt", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan(x, y))[0]; + } + + /** + * Less than or equals operation: elementwise x <= y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray lte(INDArray x, double y) { + NDValidation.validateNumerical("lte", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(x, y)); + } + + /** + * Less than or equal to operation: elementwise x <= y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray lte(INDArray x, INDArray y) { + NDValidation.validateNumerical("lte", "x", x); + NDValidation.validateNumerical("lte", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual(x, y))[0]; + } + + /** + * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise
+ * + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Boolean mask (NUMERIC type) + */ + public INDArray matchCondition(INDArray in, Condition condition) { + NDValidation.validateNumerical("matchCondition", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(in, condition)); + } + + /** + * Returns a count of the number of elements that satisfy the condition
+ * + * @param in Input (NUMERIC type) + * @param condition Condition + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public INDArray matchConditionCount(INDArray in, Condition condition) { + NDValidation.validateNumerical("matchConditionCount", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition)); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public INDArray matchConditionCount(INDArray in, Condition condition, boolean keepDim, + int... dimensions) { + NDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, keepDim, dimensions)); + } + + /** + * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Number of elements that the condition is satisfied for (NUMERIC type) + */ + public INDArray matchConditionCount(INDArray in, Condition condition, int... dimensions) { + NDValidation.validateNumerical("matchConditionCount", "in", in); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, false, dimensions)); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray max(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Max(x, keepDims, dimensions)); + } + + /** + * Max array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray max(INDArray x, int... dimensions) { + NDValidation.validateNumerical("max", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Max(x, false, dimensions)); + } + + /** + * Element-wise maximum operation: out[i] = max(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray max(INDArray first, INDArray second) { + NDValidation.validateNumerical("max", "first", first); + NDValidation.validateNumerical("max", "second", second); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(first, second))[0]; + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray mean(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, keepDims, dimensions)); + } + + /** + * Mean (average) array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray mean(INDArray x, int... dimensions) { + NDValidation.validateNumerical("mean", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, false, dimensions)); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray min(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Min(x, keepDims, dimensions)); + } + + /** + * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray min(INDArray x, int... dimensions) { + NDValidation.validateNumerical("min", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Min(x, false, dimensions)); + } + + /** + * Element-wise minimum operation: out[i] = min(first[i], second[i])
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * @param first First input array (NUMERIC type) + * @param second Second input array (NUMERIC type) + * @return output Second input array (NUMERIC type) + */ + public INDArray min(INDArray first, INDArray second) { + NDValidation.validateNumerical("min", "first", first); + NDValidation.validateNumerical("min", "second", second); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(first, second))[0]; + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public INDArray mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + NDValidation.validateNumerical("mmul", "x", x); + NDValidation.validateNumerical("mmul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0]; + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public INDArray mmul(INDArray x, INDArray y) { + NDValidation.validateNumerical("mmul", "x", x); + NDValidation.validateNumerical("mmul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, false, false, false))[0]; + } + + /** + * Not equals operation: elementwise x != y
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input array (NUMERIC type) + * @param y Double value argument to use in operation + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray neq(INDArray x, double y) { + NDValidation.validateNumerical("neq", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals(x, y)); + } + + /** + * Not equal to operation: elementwise x != y
+ * If x and y arrays have equal shape, the output shape is the same as these inputs.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * + * Return boolean array with values true where satisfied, or false otherwise.
+ * + * @param x Input 1 (NUMERIC type) + * @param y Input 2 (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + */ + public INDArray neq(INDArray x, INDArray y) { + NDValidation.validateNumerical("neq", "x", x); + NDValidation.validateNumerical("neq", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo(x, y))[0]; + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray norm1(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, keepDims, dimensions)); + } + + /** + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
+ * out = sum_i abs(x[i])
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray norm1(INDArray x, int... dimensions) { + NDValidation.validateNumerical("norm1", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, false, dimensions)); + } + + /** + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray norm2(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("norm2", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, keepDims, dimensions)); + } + + /** + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
+ * out = sqrt(sum_i x[i]^2)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray norm2(INDArray x, int... dimensions) { + NDValidation.validateNumerical("norm2", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, false, dimensions)); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray normmax(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, keepDims, dimensions)); + } + + /** + * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
+ * specified dimensions:
+ * out = max(abs(x[i]))
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray normmax(INDArray x, int... dimensions) { + NDValidation.validateNumerical("normmax", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, false, dimensions)); + } + + /** + * Convert the array to a one-hot array with walues and for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with {out[i, ..., j, in[i,...,j]] with other values being set to
+ * + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off + * @param dataType Output data type + * @return output Output variable (NUMERIC type) + */ + public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off, + DataType dataType) { + NDValidation.validateNumerical("oneHot", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType))[0]; + } + + /** + * Convert the array to a one-hot array with walues and for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with {out[i, ..., j, in[i,...,j]] with other values being set to
+ * + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off + * @return output Output variable (NUMERIC type) + */ + public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off) { + NDValidation.validateNumerical("oneHot", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, DataType.FLOAT))[0]; + } + + /** + * Convert the array to a one-hot array with walues 0 and 1 for each entry
+ * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
+ * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0
+ * see oneHot(SDVariable, int, int, double, double)
+ * + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @return output Output variable (NUMERIC type) + */ + public INDArray oneHot(INDArray indices, int depth) { + NDValidation.validateNumerical("oneHot", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth))[0]; + } + + /** + * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param input Input INDArray (NUMERIC type) + * @return output A new INDArray with the same (dynamic) shape as the input (NUMERIC type) + */ + public INDArray onesLike(INDArray input) { + NDValidation.validateNumerical("onesLike", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input))[0]; + } + + /** + * As per onesLike(String, SDVariable) but the output datatype may be specified
+ * + * @param input (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public INDArray onesLike(INDArray input, DataType dataType) { + NDValidation.validateNumerical("onesLike", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OnesLike(input, dataType))[0]; + } + + /** + * Array permutation operation: permute the dimensions according to the specified permutation indices.
+ * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output Output variable (permuted input) (NUMERIC type) + */ + public INDArray permute(INDArray x, int... dimensions) { + NDValidation.validateNumerical("permute", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0]; + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public INDArray prod(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(x, keepDims, dimensions)); + } + + /** + * Product array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public INDArray prod(INDArray x, int... dimensions) { + NDValidation.validateNumerical("prod", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(x, false, dimensions)); + } + + /** + * Create a new variable with a 1d array, where the values start at from and increment by step
+ * up to (but not including) limit.
+ * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
+ * + * @param from Initial/smallest value + * @param to Largest value (exclusive) + * @param step Step size + * @param dataType + * @return output INDArray with the specified values (NUMERIC type) + */ + public INDArray range(double from, double to, double step, DataType dataType) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.impl.Range(from, to, step, dataType))[0]; + } + + /** + * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param in Input variable (NUMERIC type) + * @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type) + */ + public INDArray rank(INDArray in) { + NDValidation.validateNumerical("rank", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Rank(in))[0]; + } + + /** + * Element-wise replace where condition:
+ * out[i] = from[i] if condition(update[i]) is satisfied, or
+ * out[i] = update[i] if condition(update[i]) is NOT satisfied
+ * + * @param update Source array (NUMERIC type) + * @param from Replacement values array (used conditionally). Must be same shape as 'update' array (NUMERIC type) + * @param condition Condition to check on update array elements + * @return output New array with values replaced where condition is satisfied (NUMERIC type) + */ + public INDArray replaceWhere(INDArray update, INDArray from, Condition condition) { + NDValidation.validateNumerical("replaceWhere", "update", update); + NDValidation.validateNumerical("replaceWhere", "from", from); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition)); + } + + /** + * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
+ * input, but with the specified shape.
+ * Note that prod(shape) must match length(input) == prod(input.shape)
+ * + * @param x Input variable (NUMERIC type) + * @param shape New shape for variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray reshape(INDArray x, INDArray shape) { + NDValidation.validateNumerical("reshape", "x", x); + NDValidation.validateNumerical("reshape", "shape", shape); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; + } + + /** + * Reverse the values of an array for the specified dimensions
+ * If input is:
+ * [ 1, 2, 3]
+ * [ 4, 5, 6]
+ * then
+ * reverse(in, 0):
+ * [3, 2, 1]
+ * [6, 5, 4]
+ * reverse(in, 1):
+ * [4, 5, 6]
+ * [1, 2 3]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Input variable (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray reverse(INDArray x, int... dimensions) { + NDValidation.validateNumerical("reverse", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(x, dimensions))[0]; + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @param seqDim Sequence dimension + * @param batchDim Batch dimension + * @return output Reversed sequences (NUMERIC type) + */ + public INDArray reverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) { + NDValidation.validateNumerical("reverseSequence", "x", x); + NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, batchDim))[0]; + } + + /** + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
+ * + * @param x Input variable (NUMERIC type) + * @param seq_lengths Length of the sequences (INT type) + * @return output Reversed sequences (NUMERIC type) + */ + public INDArray reverseSequence(INDArray x, INDArray seq_lengths) { + NDValidation.validateNumerical("reverseSequence", "x", x); + NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, 0))[0]; + } + + /** + * Element-wise scalar floor modulus operation: out = floorMod(in, value).
+ * i.e., returns the remainder after division by 'value'
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public INDArray scalarFloorMod(INDArray in, double value) { + NDValidation.validateNumerical("scalarFloorMod", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(in, value)); + } + + /** + * Element-wise scalar maximum operation: out = max(in, value)
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Scalar value to compare (NUMERIC type) + */ + public INDArray scalarMax(INDArray in, double value) { + NDValidation.validateNumerical("scalarMax", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarMax(in, value)); + } + + /** + * Element-wise scalar minimum operation: out = min(in, value)
+ * + * @param in Input variable (NUMERIC type) + * @param value Scalar value to compare + * @return output Output variable (NUMERIC type) + */ + public INDArray scalarMin(INDArray in, double value) { + NDValidation.validateNumerical("scalarMin", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarMin(in, value)); + } + + /** + * Return a variable with equal shape to the input, but all elements set to value 'set'
+ * + * @param in Input variable (NUMERIC type) + * @param set Value to set + * @return output Output variable (NUMERIC type) + */ + public INDArray scalarSet(INDArray in, double set) { + NDValidation.validateNumerical("scalarSet", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarSet(in, set)); + } + + /** + * Scatter addition operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterAdd(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterAdd", "ref", ref); + NDValidation.validateNumerical("scatterAdd", "indices", indices); + NDValidation.validateNumerical("scatterAdd", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd(ref, indices, updates))[0]; + } + + /** + * Scatter division operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterDiv(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterDiv", "ref", ref); + NDValidation.validateNumerical("scatterDiv", "indices", indices); + NDValidation.validateNumerical("scatterDiv", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv(ref, indices, updates))[0]; + } + + /** + * Scatter max operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterMax(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterMax", "ref", ref); + NDValidation.validateNumerical("scatterMax", "indices", indices); + NDValidation.validateNumerical("scatterMax", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMax(ref, indices, updates))[0]; + } + + /** + * Scatter min operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterMin(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterMin", "ref", ref); + NDValidation.validateNumerical("scatterMin", "indices", indices); + NDValidation.validateNumerical("scatterMin", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMin(ref, indices, updates))[0]; + } + + /** + * Scatter multiplication operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterMul(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterMul", "ref", ref); + NDValidation.validateNumerical("scatterMul", "indices", indices); + NDValidation.validateNumerical("scatterMul", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterMul(ref, indices, updates))[0]; + } + + /** + * Scatter subtraction operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterSub(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterSub", "ref", ref); + NDValidation.validateNumerical("scatterSub", "indices", indices); + NDValidation.validateNumerical("scatterSub", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterSub(ref, indices, updates))[0]; + } + + /** + * Scatter update operation.
+ * + * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
+ * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
+ * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
+ * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
+ * + * @param ref Initial/source variable (NUMERIC type) + * @param indices Indices array (NUMERIC type) + * @param updates Updates to add to the initial/source array (NUMERIC type) + * @return output The updated variable (NUMERIC type) + */ + public INDArray scatterUpdate(INDArray ref, INDArray indices, INDArray updates) { + NDValidation.validateNumerical("scatterUpdate", "ref", ref); + NDValidation.validateNumerical("scatterUpdate", "indices", indices); + NDValidation.validateNumerical("scatterUpdate", "updates", updates); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, indices, updates))[0]; + } + + /** + * Segment max operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentMax(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds))[0]; + } + + /** + * Segment mean operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentMean(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, segmentIds))[0]; + } + + /** + * Segment min operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentMin(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds))[0]; + } + + /** + * Segment product operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentProd(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, segmentIds))[0]; + } + + /** + * Segment sum operation.
+ * + * If data = [3, 6, 1, 4, 9, 2, 8]
+ * segmentIds = [0, 0, 1, 1, 1, 2, 2]
+ * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
+ * Note that the segment IDs must be sorted from smallest to largest segment.
+ * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
+ * for the same op without this sorted requirement
+ * + * @param data Data to perform segment max on (NDARRAY type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @return output Segment output (NUMERIC type) + */ + public INDArray segmentSum(INDArray data, INDArray segmentIds) { + NDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds))[0]; + } + + /** + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
+ * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length + * @param dataType + * @return output Output variable (NUMERIC type) + */ + public INDArray sequenceMask(INDArray lengths, int maxLen, DataType dataType) { + NDValidation.validateNumerical("sequenceMask", "lengths", lengths); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + } + + /** + * see sequenceMask(String, SDVariable, SDVariable, DataType)
+ * + * @param lengths (NUMERIC type) + * @param dataType + * @return output (NUMERIC type) + */ + public INDArray sequenceMask(INDArray lengths, DataType dataType) { + NDValidation.validateNumerical("sequenceMask", "lengths", lengths); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, dataType))[0]; + } + + /** + * Returns the shape of the specified INDArray as a 1D INDArray
+ * + * @param input Input variable (NUMERIC type) + * @return output 1D output variable with contents equal to the shape of the input (NUMERIC type) + */ + public INDArray shape(INDArray input) { + NDValidation.validateNumerical("shape", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Shape(input))[0]; + } + + /** + * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D scalar variable
+ * + * @param in Input variable (NUMERIC type) + * @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type) + */ + public INDArray size(INDArray in) { + NDValidation.validateNumerical("size", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Size(in))[0]; + } + + /** + * Returns a rank 0 (scalar) variable for the size of the specified dimension.
+ * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
+ * + * @param in Input variable (NUMERIC type) + * @param dimension Dimension to get size of + * @return output Scalar INDArray for size at specified variable (NUMERIC type) + */ + public INDArray sizeAt(INDArray in, int dimension) { + NDValidation.validateNumerical("sizeAt", "in", in); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SizeAt(in, dimension))[0]; + } + + /** + * Get a subset of the specified input, by specifying the first element and the size of the array.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * then slice(input, begin=[0,1], size=[2,1] will return:
+ * [b]
+ * [e]
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * + * @param input input Variable to get subset of (NUMERIC type) + * @param begin Beginning index. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @param size Size of the output array. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @return output Subset of the input (NUMERIC type) + */ + public INDArray slice(INDArray input, int[] begin, int... size) { + NDValidation.validateNumerical("slice", "input", input); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x (NUMERIC type) + * @param keepDims + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public INDArray squaredNorm(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, keepDims, dimensions)); + } + + /** + * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) + * @return output (NUMERIC type) + */ + public INDArray squaredNorm(INDArray x, int... dimensions) { + NDValidation.validateNumerical("squaredNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, false, dimensions)); + } + + /** + * Remove a single dimension of size 1.
+ * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c]
+ * + * @param x Input variable (NUMERIC type) + * @param axis Size 1 dimension to remove + * @return output Output variable (NUMERIC type) + */ + public INDArray squeeze(INDArray x, int axis) { + NDValidation.validateNumerical("squeeze", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Squeeze(x, axis))[0]; + } + + /** + * Stack a set of N INDArray of rank X into one rank X+1 variable.
+ * If inputs have shape [a,b,c] then output has shape:
+ * axis = 0: [N,a,b,c]
+ * axis = 1: [a,N,b,c]
+ * axis = 2: [a,b,N,c]
+ * axis = 3: [a,b,c,N]
+ * see unstack(String[], SDVariable, int, int)
+ * + * @param values Input variables to stack. Must have the same shape for all inputs (NDARRAY type) + * @param axis Axis to stack on + * @return output Output variable (NDARRAY type) + */ + public INDArray stack(INDArray values, int axis) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0]; + } + + /** + * Stardard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray standardDeviation(INDArray x, boolean biasCorrected, boolean keepDims, + int... dimensions) { + NDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, keepDims, dimensions)); + } + + /** + * Stardard deviation array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray standardDeviation(INDArray x, boolean biasCorrected, int... dimensions) { + NDValidation.validateNumerical("standardDeviation", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, false, dimensions)); + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, and a value of 0 is used instead for the beginning index for that dimension + * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, and a value of size(i)-1 is used instead for the end index for that dimension + * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other dimensions are inserted as required at the specified position + * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is inserted at this point + * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions + * @return output A subset of the input array (NUMERIC type) + */ + public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int[] strides, int beginMask, + int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { + NDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0]; + } + + /** + * Get a subset of the specified input, by specifying the first element, last element, and the strides.
+ * For example, if input is:
+ * [a, b, c]
+ * [d, e, f]
+ * [g, h, i]
+ * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
+ * [b, c]
+ * [h, i]
+ * + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @return output A subset of the input array (NUMERIC type) + */ + public INDArray stridedSlice(INDArray in, int[] begin, int[] end, int... strides) { + NDValidation.validateNumerical("stridedSlice", "in", in); + Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0))[0]; + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray sum(INDArray x, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, keepDims, dimensions)); + } + + /** + * Sum array reduction operation, optionally along specified dimensions.
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + */ + public INDArray sum(INDArray x, int... dimensions) { + NDValidation.validateNumerical("sum", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, false, dimensions)); + } + + /** + * //TODO: Ops must be documented.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output Output variable (NUMERIC type) + */ + public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, + boolean transposeX, boolean transposeY, boolean transposeZ) { + NDValidation.validateNumerical("tensorMmul", "x", x); + NDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ))[0]; + } + + /** + * //TODO: Ops must be documented.
+ * + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) + * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) + * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) + * @return output Output variable (NUMERIC type) + */ + public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int... dimensionsY) { + NDValidation.validateNumerical("tensorMmul", "x", x); + NDValidation.validateNumerical("tensorMmul", "y", y); + Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, false, false))[0]; + } + + /** + * Repeat (tile) the input tensor the specified number of times.
+ * For example, if input is
+ * [1, 2]
+ * [3, 4]
+ * and repeat is [2, 3]
+ * then output is
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * [1, 2, 1, 2, 1, 2]
+ * [3, 4, 3, 4, 3, 4]
+ * + * @param x Input variable (NDARRAY type) + * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array (INT type) + * @return output Output variable (NDARRAY type) + */ + public INDArray tile(INDArray x, INDArray repeat) { + NDValidation.validateInteger("tile", "repeat", repeat); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0]; + } + + /** + * see tile(String, SDVariable, int...)
+ * + * @param x (NDARRAY type) + * @param repeat (Size: AtLeast(min=1)) + * @return output (NDARRAY type) + */ + public INDArray tile(INDArray x, int... repeat) { + Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0]; + } + + /** + * Matrix transpose operation: If input has shape [a,b] output has shape [b,a]
+ * + * @param x Input variable (NDARRAY type) + * @return output transposed input (NDARRAY type) + */ + public INDArray transpose(INDArray x) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Transpose(x))[0]; + } + + /** + * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentMax", "data", data); + NDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentMean", "data", data); + NDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentMin", "data", data); + NDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentProd", "data", data); + NDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrtN(3), sqrtN(2)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); + NDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments))[0]; + } + + /** + * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but without
+ * the requirement for the indices to be sorted.
+ * If data = [1, 3, 2, 6, 4, 9, 8]
+ * segmentIds = [1, 0, 2, 0, 1, 1, 2]
+ * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
+ * + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param numSegments Number of segments + * @return output Unsorted segment output (NUMERIC type) + */ + public INDArray unsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) { + NDValidation.validateNumerical("unsortedSegmentSum", "data", data); + NDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0]; + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray variance(INDArray x, boolean biasCorrected, boolean keepDims, int... dimensions) { + NDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, keepDims, dimensions)); + } + + /** + * Variance array reduction operation, optionally along specified dimensions
+ * + * Note that if keepDims = true, the output variable has the same rank as the input variable,
+ * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
+ * the mean along a dimension).
+ * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
+ * keepDims = true: [a,1,c]
+ * keepDims = false: [a,c]
+ * + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) + */ + public INDArray variance(INDArray x, boolean biasCorrected, int... dimensions) { + NDValidation.validateNumerical("variance", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, false, dimensions)); + } + + /** + * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
+ * if the input shape changes in later execution, the returned variable's shape will also be updated
+ * + * @param input Input (NUMERIC type) + * @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type) + */ + public INDArray zerosLike(INDArray input) { + NDValidation.validateNumerical("zerosLike", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.ZerosLike(input))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java index f77d5c823..d874b5bbf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBitwise.java @@ -1,5 +1,5 @@ -/* ****************************************************************************** - * Copyright (c) 2019 Konduit K.K. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java old mode 100755 new mode 100644 index 7bee44ace..cb00a28c2 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java @@ -32,7 +32,7 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.enums.DataFormat; +import org.nd4j.enums.DataFormat; public class NDCNN { public NDCNN() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java new file mode 100644 index 000000000..cb80c8092 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java @@ -0,0 +1,274 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.linalg.factory.ops; + +import static org.nd4j.linalg.factory.NDValidation.isSameType; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.NDValidation; +import org.nd4j.linalg.factory.Nd4j; + +public class NDLinalg { + public NDLinalg() { + } + + /** + * Computes the Cholesky decomposition of one or more square matrices.
+ * + * @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type) + * @return output Transformed tensor (NUMERIC type) + */ + public INDArray cholesky(INDArray input) { + NDValidation.validateNumerical("Cholesky", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(input))[0]; + } + + /** + * Solver for linear squares problems.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param l2_reguralizer regularizer + * @param fast fast mode, defaults to True + * @return output Transformed tensor (FLOATING_POINT type) + */ + public INDArray lstsq(INDArray matrix, INDArray rhs, double l2_reguralizer, boolean fast) { + NDValidation.validateNumerical("Lstsq", "matrix", matrix); + NDValidation.validateNumerical("Lstsq", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Lstsq(matrix, rhs, l2_reguralizer, fast))[0]; + } + + /** + * Solver for linear squares problems.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param l2_reguralizer regularizer + * @return output Transformed tensor (FLOATING_POINT type) + */ + public INDArray lstsq(INDArray matrix, INDArray rhs, double l2_reguralizer) { + NDValidation.validateNumerical("Lstsq", "matrix", matrix); + NDValidation.validateNumerical("Lstsq", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Lstsq(matrix, rhs, l2_reguralizer, true))[0]; + } + + /** + * Computes LU decomposition.
+ * + * @param input input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray lu(INDArray input) { + NDValidation.validateNumerical("Lu", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Lu(input))[0]; + } + + /** + * Performs matrix mutiplication on input tensors.
+ * + * @param a input tensor (NUMERIC type) + * @param b input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray matmul(INDArray a, INDArray b) { + NDValidation.validateNumerical("Matmul", "a", a); + NDValidation.validateNumerical("Matmul", "b", b); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(a, b))[0]; + } + + /** + * Copy a tensor setting outside a central band in each innermost matrix.
+ * + * @param input input tensor (NUMERIC type) + * @param minLower lower diagonal count + * @param maxUpper upper diagonal count + */ + public INDArray[] matrixBandPart(INDArray input, int minLower, int maxUpper) { + NDValidation.validateNumerical("MatrixBandPart", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.MatrixBandPart(input, minLower, maxUpper)); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param input input tensor (NUMERIC type) + * @param full full matrices mode + */ + public INDArray[] qr(INDArray input, boolean full) { + NDValidation.validateNumerical("Qr", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(input, full)); + } + + /** + * Computes the QR decompositions of input matrix.
+ * + * @param input input tensor (NUMERIC type) + */ + public INDArray[] qr(INDArray input) { + NDValidation.validateNumerical("Qr", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(input, false)); + } + + /** + * Solver for systems of linear equations.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param adjoint adjoint mode, defaults to False + * @return output Output tensor (FLOATING_POINT type) + */ + public INDArray solve(INDArray matrix, INDArray rhs, boolean adjoint) { + NDValidation.validateNumerical("Solve", "matrix", matrix); + NDValidation.validateNumerical("Solve", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.LinearSolve(matrix, rhs, adjoint))[0]; + } + + /** + * Solver for systems of linear equations.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @return output Output tensor (FLOATING_POINT type) + */ + public INDArray solve(INDArray matrix, INDArray rhs) { + NDValidation.validateNumerical("Solve", "matrix", matrix); + NDValidation.validateNumerical("Solve", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.LinearSolve(matrix, rhs, false))[0]; + } + + /** + * Solver for systems of linear questions.
+ * + * @param matrix input tensor (NUMERIC type) + * @param rhs input tensor (NUMERIC type) + * @param lower defines whether innermost matrices in matrix are lower or upper triangular + * @param adjoint adjoint mode + * @return output (FLOATING_POINT type) + */ + public INDArray triangularSolve(INDArray matrix, INDArray rhs, boolean lower, boolean adjoint) { + NDValidation.validateNumerical("TriangularSolve", "matrix", matrix); + NDValidation.validateNumerical("TriangularSolve", "rhs", rhs); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.TriangularSolve(matrix, rhs, lower, adjoint))[0]; + } + + /** + * Computes pairwise cross product.
+ * + * @param a (NUMERIC type) + * @param b (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray cross(INDArray a, INDArray b) { + NDValidation.validateNumerical("cross", "a", a); + NDValidation.validateNumerical("cross", "b", b); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Cross(a, b))[0]; + } + + /** + * Calculates diagonal tensor.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray diag(INDArray input) { + NDValidation.validateNumerical("diag", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Diag(input))[0]; + } + + /** + * Calculates diagonal tensor.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray diag_part(INDArray input) { + NDValidation.validateNumerical("diag_part", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.DiagPart(input))[0]; + } + + /** + * Calculates log of determinant.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray logdet(INDArray input) { + NDValidation.validateNumerical("logdet", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Logdet(input))[0]; + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array + * @return output (NUMERIC type) + */ + public INDArray mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, + boolean transposeZ) { + NDValidation.validateNumerical("mmul", "x", x); + NDValidation.validateNumerical("mmul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0]; + } + + /** + * Matrix multiplication: out = mmul(x,y)
+ * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
+ * + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) + * @return output (NUMERIC type) + */ + public INDArray mmul(INDArray x, INDArray y) { + NDValidation.validateNumerical("mmul", "x", x); + NDValidation.validateNumerical("mmul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, false, false, false))[0]; + } + + /** + * Calculates singular value decomposition.
+ * + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @param switchNum + * @return output (FLOATING_POINT type) + */ + public INDArray svd(INDArray input, boolean fullUV, boolean computeUV, int switchNum) { + NDValidation.validateNumerical("svd", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(input, fullUV, computeUV, switchNum))[0]; + } + + /** + * Calculates singular value decomposition.
+ * + * @param input (NUMERIC type) + * @param fullUV + * @param computeUV + * @return output (FLOATING_POINT type) + */ + public INDArray svd(INDArray input, boolean fullUV, boolean computeUV) { + NDValidation.validateNumerical("svd", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(input, fullUV, computeUV, 16))[0]; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java index 4c1234514..cdee59ea1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLoss.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2019 Konduit K.K. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index 66f8071e2..eddbe3db7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -1,5 +1,5 @@ -/* ****************************************************************************** - * Copyright (c) 2019 Konduit K.K. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,6 +18,8 @@ package org.nd4j.linalg.factory.ops; +import static org.nd4j.linalg.factory.NDValidation.isSameType; + import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -66,12 +68,12 @@ public class NDMath { * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray amax(INDArray in, int... dimensions) { NDValidation.validateNumerical("amax", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(in, dimensions)); } @@ -79,12 +81,12 @@ public class NDMath { * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray amean(INDArray in, int... dimensions) { NDValidation.validateNumerical("amean", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(in, dimensions)); } @@ -92,12 +94,12 @@ public class NDMath { * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray amin(INDArray in, int... dimensions) { NDValidation.validateNumerical("amin", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(in, dimensions)); } @@ -143,12 +145,12 @@ public class NDMath { * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray asum(INDArray in, int... dimensions) { NDValidation.validateNumerical("asum", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(in, dimensions)); } @@ -375,12 +377,12 @@ public class NDMath { * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray countNonZero(INDArray in, int... dimensions) { NDValidation.validateNumerical("countNonZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(in, dimensions)); } @@ -388,12 +390,12 @@ public class NDMath { * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray countZero(INDArray in, int... dimensions) { NDValidation.validateNumerical("countZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(in, dimensions)); } @@ -461,12 +463,12 @@ public class NDMath { * Entropy reduction: -sum(x * log(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray entropy(INDArray in, int... dimensions) { NDValidation.validateNumerical("entropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(in, dimensions)); } @@ -566,10 +568,12 @@ public class NDMath { * @param rows Number of rows * @param cols Number of columns * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) * @return output Identity matrix (NUMERIC type) */ - public INDArray eye(int rows, int cols, DataType dataType) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols, dataType))[0]; + public INDArray eye(int rows, int cols, DataType dataType, int... dimensions) { + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Eye(rows, cols, dataType, dimensions))[0]; } /** @@ -615,7 +619,7 @@ public class NDMath { public INDArray firstIndex(INDArray in, Condition condition, int... dimensions) { NDValidation.validateNumerical("firstIndex", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, condition, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, false, condition, dimensions)); } /** @@ -639,7 +643,7 @@ public class NDMath { int... dimensions) { NDValidation.validateNumerical("firstIndex", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, condition, keepDims, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(in, keepDims, condition, dimensions)); } /** @@ -682,7 +686,7 @@ public class NDMath { public INDArray iamax(INDArray in, int... dimensions) { NDValidation.validateNumerical("iamax", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, false, dimensions)); } /** @@ -711,7 +715,7 @@ public class NDMath { public INDArray iamin(INDArray in, int... dimensions) { NDValidation.validateNumerical("iamin", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, false, dimensions)); } /** @@ -842,7 +846,7 @@ public class NDMath { public INDArray lastIndex(INDArray in, Condition condition, int... dimensions) { NDValidation.validateNumerical("lastIndex", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, condition, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, false, condition, dimensions)); } /** @@ -865,7 +869,7 @@ public class NDMath { public INDArray lastIndex(INDArray in, Condition condition, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("lastIndex", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, condition, keepDims, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(in, keepDims, condition, dimensions)); } /** @@ -883,13 +887,12 @@ public class NDMath { * Element-wise logarithm function (with specified base): out = log_{base}(x)
* * @param x Input variable (NUMERIC type) - * @param base Logarithm base (NUMERIC type) + * @param base Logarithm base * @return output Output variable (NUMERIC type) */ - public INDArray log(INDArray x, INDArray base) { + public INDArray log(INDArray x, double base) { NDValidation.validateNumerical("log", "x", x); - NDValidation.validateNumerical("log", "base", base); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(x, base)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.LogX(x, base)); } /** @@ -907,12 +910,12 @@ public class NDMath { * Log entropy reduction: log(-sum(x * log(x)))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray logEntropy(INDArray in, int... dimensions) { NDValidation.validateNumerical("logEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(in, dimensions)); } @@ -1017,12 +1020,11 @@ public class NDMath { * * @param input Input to calculate moments for (NUMERIC type) * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) - * @return output Mean and variance variables (NUMERIC type) */ - public INDArray moments(INDArray input, int... axes) { + public INDArray[] moments(INDArray input, int... axes) { NDValidation.validateNumerical("moments", "input", input); Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Moments(input, axes))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Moments(input, axes)); } /** @@ -1043,14 +1045,13 @@ public class NDMath { * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) * @param variances Variaance sufficient statistics: this is the squared sum of all data values (NUMERIC type) * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) - * @return output Output variables: mean and population variance (NUMERIC type) */ - public INDArray normalizeMoments(INDArray counts, INDArray means, INDArray variances, + public INDArray[] normalizeMoments(INDArray counts, INDArray means, INDArray variances, double shift) { NDValidation.validateNumerical("normalizeMoments", "counts", counts); NDValidation.validateNumerical("normalizeMoments", "means", means); NDValidation.validateNumerical("normalizeMoments", "variances", variances); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(counts, means, variances, shift))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(counts, means, variances, shift)); } /** @@ -1153,12 +1154,12 @@ public class NDMath { * Shannon Entropy reduction: -sum(x * log2(x))
* * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray shannonEntropy(INDArray in, int... dimensions) { NDValidation.validateNumerical("shannonEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(in, dimensions)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 815f22e5b..04a713ecf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -1,5 +1,5 @@ -/* ****************************************************************************** - * Copyright (c) 2019 Konduit K.K. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -107,7 +107,7 @@ public class NDNN { NDValidation.validateNumerical("dotProductAttention", "keys", keys); NDValidation.validateNumerical("dotProductAttention", "values", values); NDValidation.validateNumerical("dotProductAttention", "mask", mask); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(queries, keys, values, mask, scaled, false))[0]; } /** @@ -227,7 +227,7 @@ public class NDNN { NDValidation.validateNumerical("layerNorm", "input", input); NDValidation.validateNumerical("layerNorm", "gain", gain); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, channelsFirst, dimensions))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(input, gain, null, channelsFirst, dimensions))[0]; } /** @@ -343,7 +343,7 @@ public class NDNN { NDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); NDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); NDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java old mode 100755 new mode 100644 diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java index 1dfcd60ae..dc5e472e8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRandom.java @@ -1,5 +1,5 @@ -/* ****************************************************************************** - * Copyright (c) 2019 Konduit K.K. +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,6 +18,8 @@ package org.nd4j.linalg.factory.ops; +import static org.nd4j.linalg.factory.NDValidation.isSameType; + import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -67,11 +69,12 @@ public class NDRandom { * @param lambda lambda parameter * @param datatype Data type of the output variable * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0)) + * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type) */ - public INDArray[] exponential(double lambda, DataType datatype, long... shape) { + public INDArray exponential(double lambda, DataType datatype, long... shape) { Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); Preconditions.checkArgument(lambda > 0, "Must be positive"); - return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.RandomExponential(lambda, datatype, shape)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.RandomExponential(lambda, datatype, shape))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 9abe0a483..eab974821 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -257,7 +257,7 @@ public class LayerOpValidation extends BaseOpValidation { msg = "7 - upsampling2d, NCHW, 2x2 - " + Arrays.toString(inSizeNCHW); inSize = inSizeNCHW; in = sd.var("in", inSize); - out = sd.cnn().upsampling2d(in, true, 2, 2); + out = sd.cnn().upsampling2d(in, 2, 2, true); break; default: throw new RuntimeException(); @@ -588,8 +588,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().sconv2d(vars, c); - out = sd.nn().tanh("out", out); + SDVariable out = sd.cnn().separableConv2d(in, dW, b, c); + out = sd.f().tanh(out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 @@ -623,7 +623,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable pW = sd.var("pW", pointWeightArr); SDVariable b = sd.var("b", bArr); - SDVariable[] vars = new SDVariable[]{in, dW, pW, b}; + //SDVariable[] vars = new SDVariable[]{in, dW, pW, b}; Conv2DConfig c = Conv2DConfig.builder() .kH(kH).kW(kW) @@ -634,8 +634,8 @@ public class LayerOpValidation extends BaseOpValidation { .dataFormat(Conv2DConfig.NCHW) .build(); - SDVariable out = sd.cnn().sconv2d(vars, c); - out = sd.nn().tanh("out", out); + SDVariable out = sd.cnn().separableConv2d(in, dW, pW, b, c); + out = sd.nn().tanh(out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7 @@ -685,8 +685,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().deconv2d(vars, deconv); - out = sd.nn().tanh("out", out); + SDVariable out = sd.f().deconv2d(vars, deconv); + out = sd.f().tanh(out); INDArray outArr = out.eval(); //Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9 @@ -733,8 +733,8 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(false) .build(); - SDVariable out = sd.cnn().conv2d("conv", vars, c); - out = sd.nn().tanh("out", out); + SDVariable out = sd.f().conv2d(vars, c); + out = sd.f().tanh(out); INDArray outArr = out.eval(); //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27 @@ -767,7 +767,7 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(true) .build(); - SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig); + SDVariable[] results = sd.f().maxPoolWithArgmax(/*new String[]{"out","idx"},*/ in, pooling2DConfig); assertArrayEquals(inArr.shape(), results[0].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape()); } @@ -797,7 +797,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig); - SDVariable out = sd.nn().tanh("out", outPool); + SDVariable out = sd.f().tanh(/*"out",*/ outPool); INDArray outArr = out.eval(); val outShape = outArr.shape(); @@ -855,7 +855,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig); - SDVariable out = sd.nn().tanh("out", outPool); + SDVariable out = sd.f().tanh(/*"out",*/ outPool); INDArray outArr = out.eval(); val outShape = outArr.shape(); @@ -906,7 +906,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().avgPooling3d(in, pooling3DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + out = sd.f().tanh(/*"loss", */out).shape().rename("out"); // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; INDArray outArr = Nd4j.createFromArray(mb, nIn, 4, 4, 4L); @@ -942,7 +942,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().maxPooling3d(in, pooling3DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + out = sd.math().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -976,8 +976,8 @@ public class LayerOpValidation extends BaseOpValidation { .paddingMode(PaddingMode.VALID) .build(); - SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig); + out = sd.math().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -1018,7 +1018,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().conv1d(in, w, b, conv1DConfig); - SDVariable loss = sd.nn().tanh(out).std(true).rename("loss"); + SDVariable loss = sd.f().tanh(out).std(true).rename("loss"); sd.setLossVariables("loss"); @@ -1057,7 +1057,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); SDVariable w = sd.var("w", wArr); - SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); + SDVariable res = sd.cnn.conv1d(in, w, null, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); INDArray expected = Nd4j.createFromArray( new double[][][]{ @@ -1113,7 +1113,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().conv3d(in, w, b, conv3DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + out = sd.math().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -1156,7 +1156,7 @@ public class LayerOpValidation extends BaseOpValidation { .build(); SDVariable out = sd.cnn().deconv3d(in, w, conv3DConfig); - out = sd.nn().tanh("loss", out).shape().rename("out"); + out = sd.math().tanh("loss", out).shape().rename("out"); sd.setLossVariables("loss"); @@ -1335,7 +1335,7 @@ public class LayerOpValidation extends BaseOpValidation { .paddingMode(PaddingMode.VALID) .build(); - SDVariable out = sd.cnn().conv1d(in, w, conv1DConfig); + SDVariable out = sd.cnn().conv1d(in, w, null, conv1DConfig); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java index 7f8da282e..ca3f10d04 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java @@ -178,7 +178,7 @@ public class LossOpValidation extends BaseOpValidation { predictionsArr = Transforms.log(Transforms.abs(predictionsArr)); labelsArr = Transforms.abs(labelsArr); expOut = Transforms.exp(predictionsArr).sub(labelsArr.mul(predictionsArr)); - loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction); + loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction,false); break; case "log_poisson_full": predictionsArr = Transforms.log(Transforms.abs(predictionsArr)); @@ -188,7 +188,7 @@ public class LossOpValidation extends BaseOpValidation { .add(labelsArr.mul(Transforms.log(labelsArr))) .sub(labelsArr) .add(Transforms.log(labelsArr.mul(Math.PI * 2)).mul(0.5)); - loss = sd.loss().logPoissonFull("loss", labels, predictions, w, reduction); + loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction,true); break; case "mse": //To match TF, this is actually sum of squares - 1/numExamples (prediction-label)^2 @@ -251,7 +251,7 @@ public class LossOpValidation extends BaseOpValidation { expOut.muli(1/((n*(n-1)) / 2)); - loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions, w, reduction); + loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions,w, reduction); break; case "sparsesoftmax": labelsArr = Nd4j.create(DataType.DOUBLE, minibatch); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 1f23e12ec..06c64445b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -1289,7 +1289,7 @@ public class MiscOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable var = sd.var("in", Nd4j.create(new long[]{1}).assign(5)); - SDVariable merged = sd.math().mergeAvg("merged", var); + SDVariable merged = sd.math().mergeAvg("merged", new SDVariable[]{var}); SDVariable sum = sd.sum(merged); Map m = sd.output(Collections.emptyMap(), "merged"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 4585b4a15..053f3a70b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -71,7 +71,7 @@ public class RandomOpValidation extends BaseOpValidation { switch (i) { case 0: name = "randomUniform"; - rand = sd.random().uniform(1, 2, shapeVar); + rand = sd.random().uniform(1, 2, DataType.DOUBLE, shape); checkFn = in -> { double min = in.minNumber().doubleValue(); double max = in.maxNumber().doubleValue(); @@ -83,7 +83,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 1: name = "randomNormal"; - rand = sd.random().normal(1, 1, shapeVar); + rand = sd.random().normal(1, 1, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double stdev = in.std(true).getDouble(0); @@ -94,7 +94,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 2: name = "randomBernoulli"; - rand = sd.random().bernoulli(0.5, shapeVar); + rand = sd.random().bernoulli(0.5, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double min = in.minNumber().doubleValue(); @@ -110,7 +110,7 @@ public class RandomOpValidation extends BaseOpValidation { case 3: name = "randomExponential"; final double lambda = 2; - rand = sd.random().exponential(lambda, shapeVar); + rand = sd.random().exponential(lambda, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double min = in.minNumber().doubleValue(); @@ -168,7 +168,7 @@ public class RandomOpValidation extends BaseOpValidation { switch (i) { case 0: name = "randomBernoulli"; - rand = sd.random().bernoulli(0.5, shape); + rand = sd.random().bernoulli(0.5, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double min = in.minNumber().doubleValue(); @@ -183,7 +183,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 1: name = "normal"; - rand = sd.random().normal(1, 2, shape); + rand = sd.random().normal(1, 2, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double stdev = in.std(true).getDouble(0); @@ -194,7 +194,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 2: name = "randomBinomial"; - rand = sd.random().binomial(4, 0.5, shape); + rand = sd.random().binomial(4, 0.5, DataType.DOUBLE, shape); checkFn = in -> { NdIndexIterator iter = new NdIndexIterator(in.shape()); while(iter.hasNext()){ @@ -209,7 +209,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 3: name = "randomUniform"; - rand = sd.random().uniform(1, 2, shape); + rand = sd.random().uniform(1, 2, DataType.DOUBLE, shape); checkFn = in -> { double min = in.minNumber().doubleValue(); double max = in.maxNumber().doubleValue(); @@ -225,7 +225,7 @@ public class RandomOpValidation extends BaseOpValidation { continue; } name = "truncatednormal"; - rand = sd.random().normalTruncated(1, 2, shape); + rand = sd.random().normalTruncated(1, 2, DataType.DOUBLE, shape); checkFn = in -> { double mean = in.meanNumber().doubleValue(); double stdev = in.std(true).getDouble(0); @@ -236,7 +236,7 @@ public class RandomOpValidation extends BaseOpValidation { break; case 5: name = "lognormal"; - rand = sd.random().logNormal(1, 2, shape); + rand = sd.random().logNormal(1, 2, DataType.DOUBLE, shape); //Note: lognormal parameters are mean and stdev of LOGARITHM of values checkFn = in -> { INDArray log = Transforms.log(in, true); @@ -389,15 +389,25 @@ public class RandomOpValidation extends BaseOpValidation { for(DataType t : new DataType[]{DataType.FLOAT, DataType.DOUBLE, }){ SameDiff sd = SameDiff.create(); SDVariable shape = sd.constant("shape", Nd4j.createFromArray(1, 100)); - SDVariable out = sd.random.uniform(0, 10, shape, t); + SDVariable out = sd.random.uniform(0, 10, t, 1, 100); INDArray arr = out.eval(); assertEquals(t, arr.dataType()); - double min = arr.minNumber().doubleValue(); - double max = arr.maxNumber().doubleValue(); - double mean = arr.meanNumber().doubleValue(); - assertEquals(0, min, 0.5); - assertEquals(10, max, 0.5); - assertEquals(5.5, mean, 1); + if (t.equals(DataType.DOUBLE)) { + double min = arr.minNumber().doubleValue(); + double max = arr.maxNumber().doubleValue(); + double mean = arr.meanNumber().doubleValue(); + assertEquals(0, min, 0.5); + assertEquals(10, max, 0.5); + assertEquals(5.5, mean, 1); + } + else if (t.equals(DataType.FLOAT)) { + float min = arr.minNumber().floatValue(); + float max = arr.maxNumber().floatValue(); + float mean = arr.meanNumber().floatValue(); + assertEquals(0, min, 0.5); + assertEquals(10, max, 0.5); + assertEquals(5.0, mean, 1); + } } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 3027138a1..bb2287e03 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -235,39 +235,39 @@ public class ReductionOpValidation extends BaseOpValidation { tc.expectedOutput("loss", inputArr.normmax()); break; case 10: - loss = sd.math().countNonZero("loss", input); + loss = sd.math().countNonZero("loss", input, 0,1); name = "countNonZero"; tc.expectedOutput("loss", Nd4j.scalar(inputArr.length())); gradCheck = false; //Long out, not floating point break; case 11: - loss = sd.math().countZero("loss", input); + loss = sd.math().countZero("loss", input, 0,1); name = "countZero"; tc.expectedOutput("loss", Nd4j.scalar(0L)); gradCheck = false; //Long out, not floating point break; case 12: - loss = sd.math().amax("loss", input); + loss = sd.math().amax("loss", input, 0,1); name = "amax"; tc.expectedOutput("loss", inputArr.amax()); break; case 13: - loss = sd.math().amin("loss", input); + loss = sd.math().amin("loss", input, 0,1); name = "amin"; tc.expectedOutput("loss", inputArr.amin()); break; case 14: - loss = sd.math().asum("loss", input); + loss = sd.math().asum("loss", input, 0,1); name = "asum"; tc.expectedOutput("loss", Nd4j.getExecutioner().exec(new ASum(inputArr.dup()))); break; case 15: - loss = sd.math().amean("loss", input); + loss = sd.math().amean("loss", input, 0,1); name = "amean"; tc.expectedOutput("loss", Nd4j.getExecutioner().exec(new AMean(inputArr.dup()))); break; case 16: - loss = sd.math().entropy("loss", input); + loss = sd.math().entropy("loss", input, 0,1); name = "entropy"; inputArr = Nd4j.linspace(0.01, 0.99, length, DataType.DOUBLE).reshape('c', minibatch, nOut); tc.expected("loss", inputArr.mul(Transforms.log(inputArr, true)).sum(Integer.MAX_VALUE).negi()); @@ -290,14 +290,14 @@ public class ReductionOpValidation extends BaseOpValidation { case 19: inputArr = Nd4j.rand(minibatch, nOut); name = "logEntropy"; - loss = sd.math().logEntropy("loss", input); + loss = sd.math().logEntropy("loss", input, 0,1); double logEntropy = inputArr.logEntropyNumber().doubleValue(); tc.expected(loss, Nd4j.scalar(logEntropy)); break; case 20: inputArr = Nd4j.rand(minibatch, nOut); name = "shannonEntropy"; - loss = sd.math().shannonEntropy("loss", input); + loss = sd.math().shannonEntropy("loss", input, 0); double shannonEntropy = inputArr.shannonEntropyNumber().doubleValue(); tc.expected(loss, Nd4j.scalar(shannonEntropy)); if (OpValidationSuite.IGNORE_FAILING) { @@ -836,11 +836,11 @@ public class ReductionOpValidation extends BaseOpValidation { @Test public void testIndexAccum() { List failed = new ArrayList<>(); - List dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1}, new int[0]); + List dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1} /*, new int[0]*/); INDArray in = Nd4j.rand(DataType.DOUBLE,3, 4); - for (int t = 0; t < 4; t++) { + for (int t = 0; t < 3; t++) { int[] d = dims.get(t); for (int i = 0; i < 7; i++) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 30d4baf5c..795cef3f1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -1406,14 +1406,13 @@ public class ShapeOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable[] arr = new SDVariable[rank]; - List names = new ArrayList<>(); + String[] names = new String[rank]; for( int i=0; i ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4)); List outputs = Arrays.asList("in", "z", "softmax"); @@ -3522,13 +3521,13 @@ public class SameDiffTests extends BaseNd4jTest { @Test public void testRngSanityCheck(){ Nd4j.getRandom().setSeed(12345); - for(DataType dt : DataType.values()) { + for(DataType dt : new DataType[]{DataType.FLOAT, DataType.DOUBLE,DataType.BFLOAT16}) { if (!dt.isNumerical()) continue; SameDiff sameDiff = SameDiff.create(); INDArray indaShape = Nd4j.createFromArray(3, 10); SDVariable sdShape = sameDiff.constant(indaShape); - SDVariable random = sameDiff.random().uniform("data", 0.0, 10.0, sdShape, dt); + SDVariable random = sameDiff.random().uniform("data", 0.0, 10.0, dt, 3, 10); INDArray out = random.eval(); String s = out.toString(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java index 4db765c5e..13853f246 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java @@ -80,7 +80,7 @@ public class SameDiffTrainingTest extends BaseNd4jTest { SDVariable z0 = in.mmul(w0).add(b0); SDVariable a0 = sd.math().tanh(z0); SDVariable z1 = a0.mmul(w1).add("prediction", b1); - SDVariable a1 = sd.nn().softmax(z1); + SDVariable a1 = sd.nn().softmax(z1,-1); SDVariable diff = sd.f().squaredDifference(a1, label); SDVariable lossMse = diff.mul(diff).mean(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java index cb6c70a89..d9f942793 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ExecDebuggingListenerTest.java @@ -2,6 +2,7 @@ package org.nd4j.autodiff.samediff.listeners; import org.junit.Test; import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener; +import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java index b2c33f386..4f105aecc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java @@ -85,7 +85,7 @@ public class ListenerTest extends BaseNd4jTest { SDVariable z1 = a0.mmul(w1).add(b1); SDVariable predictions = sd.nn().softmax("predictions", z1, 1); - SDVariable loss = sd.loss.softmaxCrossEntropy("loss", label, predictions); + SDVariable loss = sd.loss.softmaxCrossEntropy("loss", label, predictions, null); sd.setLossVariables("loss"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 2d178a210..01dc83ee4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -34,11 +34,14 @@ import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; import org.nd4j.linalg.api.ops.impl.image.ResizeArea; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.shape.OnesLike; import org.nd4j.linalg.api.ops.impl.shape.SequenceMask; +import org.nd4j.linalg.api.ops.impl.transforms.Cholesky; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.Qr; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal; @@ -1775,4 +1778,41 @@ public class CustomOpsTests extends BaseNd4jTest { INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.INT32)); assertEquals(expected, ret[0]); } + + @Test + public void testCholesky() { + INDArray x = Nd4j.createFromArray(new double[] {4,12,-16, 12 ,37,-43, -16, -43, 98}).reshape(3,3); + INDArray exp = Nd4j.createFromArray(new double[] {2., 0., 0., 6., 1., 0., -8., 5., 3.}).reshape(3,3); + + INDArray[] res = Nd4j.exec(new Cholesky(x)); + assertEquals(res[0], exp); + } + + @Test + public void testQr() { + INDArray in = Nd4j.createFromArray(new double[]{ + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. + }).reshape(5,3); + Qr op = new Qr(in); + INDArray[] ret = Nd4j.exec(op); + INDArray res = Nd4j.createUninitialized(in.shape()); + DynamicCustomOp matmul = DynamicCustomOp.builder("matmul") + .addInputs(ret[0], ret[1]) + .build(); + ret = Nd4j.exec(matmul); + assertEquals(ret[0], in); + } + + @Test + public void testLogdet() { + INDArray x = Nd4j.createFromArray(new double[]{ + 4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8 + }).reshape(2,3,3); + + INDArray expected = Nd4j.createFromArray(new double[]{3.5835189, 4.159008}); + INDArray[] ret = Nd4j.exec(new Logdet(x)); + assertEquals(ret[0], expected); + + } + } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java index 40d32121d..6d0ec8f54 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDLossTest.java @@ -61,7 +61,7 @@ public class NDLossTest extends BaseNd4jTest { SDVariable loss = sd.loss().absoluteDifference("loss", labels, predictions, w, reduction); - SDVariable loss2 = sd.loss().absoluteDifference("loss2", labels, predictions, null, reduction); + SDVariable loss2 = sd.loss().absoluteDifference("loss2", labels, predictions,null, reduction); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -251,8 +251,8 @@ public class NDLossTest extends BaseNd4jTest { INDArray predictionsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); - SDVariable loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction); - SDVariable loss2 = sd.loss().logPoisson("loss2", labels, predictions, null, reduction); + SDVariable loss = sd.loss().logPoisson("loss", labels, predictions, w, reduction, false); + SDVariable loss2 = sd.loss().logPoisson("loss2", labels, predictions, null, reduction, false); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -285,7 +285,8 @@ public class NDLossTest extends BaseNd4jTest { INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); SDVariable loss = sd.loss().meanPairwiseSquaredError("loss", labels, predictions, w, reduction); - SDVariable loss2 = sd.loss().meanPairwiseSquaredError("loss2", labels, predictions, null, reduction); + SDVariable loss2 = sd.loss().meanPairwiseSquaredError("loss2", labels, predictions, + null, reduction); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -318,7 +319,8 @@ public class NDLossTest extends BaseNd4jTest { INDArray labelsArr = Nd4j.randn(DataType.DOUBLE, minibatch, nOut); SDVariable loss = sd.loss().meanSquaredError("loss", labels, predictions, w, reduction); - SDVariable loss2 = sd.loss().meanSquaredError("loss2", labels, predictions, null, reduction); + SDVariable loss2 = sd.loss().meanSquaredError("loss2", labels, predictions, + null, reduction); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -352,7 +354,8 @@ public class NDLossTest extends BaseNd4jTest { double labelSmoothing = 0.01; SDVariable loss = sd.loss().sigmoidCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing); - SDVariable loss2 = sd.loss().sigmoidCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing); + SDVariable loss2 = sd.loss().sigmoidCrossEntropy("loss2", labels, predictions, + null, reduction, labelSmoothing); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); @@ -388,7 +391,7 @@ public class NDLossTest extends BaseNd4jTest { double labelSmoothing = 0.0; - SDVariable loss = sd.loss().softmaxCrossEntropy("loss", labels, predictions, w, reduction, labelSmoothing); + SDVariable loss = sd.loss().softmaxCrossEntropy("loss", labels, predictions, null, reduction, labelSmoothing); SDVariable loss2 = sd.loss().softmaxCrossEntropy("loss2", labels, predictions, null, reduction, labelSmoothing); sd.associateArrayWithVariable(predictionsArr, predictions); sd.associateArrayWithVariable(labelsArr, labels); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java new file mode 100644 index 000000000..fbce0db6b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/generated/SDLinalgTest.java @@ -0,0 +1,285 @@ +/***************************************************************************** + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.generated; + +import org.junit.Before; +import org.junit.Test; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class SDLinalgTest extends BaseNd4jTest { + public SDLinalgTest(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering(){ + return 'c'; + } + + private SameDiff sameDiff; + + @Before + public void setup() { + sameDiff = SameDiff.create(); + } + + @Test + public void testCholesky() { + INDArray input = Nd4j.createFromArray( + new float[]{ + 10.f, 14.f, + 14.f, 20.f, + 74.f, 86.f, + 86.f, 100.f + } + ).reshape(2,2,2); + + INDArray expected = Nd4j.createFromArray( + new float[]{ + 3.1622777f, 0.f, 4.427189f, 0.6324552f, + 8.602325f, 0.f, 9.997296f, 0.23252854f + } + ).reshape(2,2,2); + + SDVariable sdinput = sameDiff.var(input); + SDVariable out = sameDiff.linalg().cholesky(sdinput); + assertEquals(expected, out.eval()); + } + + @Test + public void testLstsq() { + INDArray a = Nd4j.createFromArray(new float[]{ + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f + }).reshape(2,2,2); + + INDArray b = Nd4j.createFromArray(new float[]{ + 3.f, 7.f, 11.f, 15.f + }).reshape(2,2,1); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 0.831169367f, 1.090908766f, 0.920544624f, 1.063016534f + }).reshape(2,2,1); + + SDVariable sda = sameDiff.var(a); + SDVariable sdb = sameDiff.var(b); + + SDVariable res = sameDiff.linalg().lstsq(sda,sdb,0.5,true); + assertEquals(expected, res.eval()); + } + + @Test + public void testLu() { + SDVariable sdInput = sameDiff.var(Nd4j.createFromArray(new double[]{ + 1., 2., 3., 0., 2., 3., 0., 0., 7. + }).reshape(3,3)); + + INDArray expected = Nd4j.createFromArray(new double[]{ + 1., 2., 3., 0., 2., 3., 0., 0., 7 + }).reshape(3,3); + + SDVariable out = sameDiff.linalg().lu("lu", sdInput); + assertEquals(expected, out.eval()); + } + + @Test + public void testMatrixBandPart() { + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); + INDArray expected = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); + + SDVariable sdx = sameDiff.var(x); + SDVariable[] res = sameDiff.linalg().matrixBandPart(sdx, 1, 1); + assertArrayEquals(x.shape(), res[0].eval().shape()); + } + + @Test + public void testQr() { + INDArray input = Nd4j.createFromArray(new double[]{ + 12., -51., 4., + 6., 167., -68., + -4., 24., -41., + -1., 1., 0., + 2., 0., 3. + }).reshape(5,3); + + INDArray expectedQ = Nd4j.createFromArray(new double[]{ + 0.8464147390303179, -0.3912908119746455, 0.34312406418022884, + 0.42320736951515897, 0.9040872694197354, -0.02927016186366648, + -0.2821382463434393, 0.17042054976392634, 0.9328559865183932, + -0.07053456158585983, 0.01404065236547358, -0.00109937201747271, + 0.14106912317171966, -0.01665551070074392, -0.10577161246232346 + }).reshape(5,3); + + INDArray expectedR = Nd4j.createFromArray(new double[]{ + 14.177446878757824, 20.666626544656932, -13.401566701313369, + -0.0000000000000006, 175.04253925050244, -70.0803066408638, + 0.00000000000000017, -0.00000000000000881, -35.20154302119086 + }).reshape(3,3); + + SDVariable sdInput = sameDiff.var(input); + SDVariable[] res = sameDiff.linalg().qr(sdInput); + + SDVariable mmulResult = sameDiff.mmul(res[0], res[1]); + + assertEquals(input, mmulResult.eval()); + } + + @Test + public void testSolve() { + INDArray a = Nd4j.createFromArray(new float[] { + 2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f + }).reshape(3,3); + + INDArray b = Nd4j.createFromArray(new float[] { + 2.f, 4.f, 3.f + }).reshape(3,1); + + INDArray expected = Nd4j.createFromArray(new float[] { + 7.625f, 3.25f, 5.f + }).reshape(3,1); + + SDVariable sda = sameDiff.var(a); + SDVariable sdb = sameDiff.var(b); + + SDVariable res = sameDiff.linalg().solve(sda, sdb); + assertEquals(expected, res.eval()); + } + + @Test + public void testTriangularSolve() { + INDArray a = Nd4j.createFromArray(new float[] { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }).reshape(3,3); + + INDArray b = Nd4j.createFromArray(new float[] { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }).reshape(3,3); + + INDArray expected = Nd4j.createFromArray(new float[] { + 0.99088347f, 1.1917052f, 1.2642528f, + 0.35071516f, 0.50630623f, 0.42935497f, + -0.30013534f, -0.53690606f, -0.47959247f + }).reshape(3,3); + + SDVariable sda = sameDiff.var(a); + SDVariable sdb = sameDiff.var(b); + + SDVariable res = sameDiff.linalg().triangularSolve(sda, sdb, true, false); + assertEquals(expected, res.eval()); + } + + @Test + public void testCross() { + INDArray a = Nd4j.createFromArray(new double[]{1, 2, 3}); + INDArray b = Nd4j.createFromArray(new double[]{6, 7, 8}); + INDArray expected = Nd4j.createFromArray(new double[]{-5, 10, -5}); + + SDVariable sda = sameDiff.var(a); + SDVariable sdb = sameDiff.var(b); + + SDVariable res = sameDiff.linalg().cross(sda, sdb); + assertEquals(expected, res.eval()); + } + + @Test + public void testDiag() { + INDArray x = Nd4j.createFromArray(new double[]{1,2}); + INDArray expected = Nd4j.createFromArray(new double[]{1,0,0,2}).reshape(2,2); + + SDVariable sdx = sameDiff.var(x); + + SDVariable res = sameDiff.linalg().diag(sdx); + assertEquals(expected, res.eval()); + } + + @Test + public void testDiagPart() { + INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4).reshape(2,2); + INDArray expected = Nd4j.createFromArray(new double[]{1,4}); + + SDVariable sdx = sameDiff.var(x); + + SDVariable res = sameDiff.linalg().diag_part(sdx); + assertEquals(expected, res.eval()); + } + + @Test + public void testLogdet() { + INDArray x = Nd4j.createFromArray(new double[]{ + 4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8 + }).reshape(2,3,3); + INDArray expected = Nd4j.createFromArray(new double[]{3.5835189, 4.159008}); + + SDVariable sdx = sameDiff.var(x); + + SDVariable res = sameDiff.linalg().logdet(sdx); + assertEquals(expected, res.eval()); + } + + @Test + public void testSvd() { + INDArray x = Nd4j.createFromArray(new double[]{ + 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f + }).reshape(3,3); + INDArray expected = Nd4j.createFromArray(new double[]{1.8967269987492157, 0.3709665595850617, 0.05524869852188223}); + + SDVariable sdx = sameDiff.var(x); + SDVariable res = sameDiff.linalg().svd(sdx, false, false); + assertEquals(expected, res.eval()); + } + + @Test + public void testLogdetName() { + INDArray x = Nd4j.createFromArray(new double[]{ + 4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8 + }).reshape(2,3,3); + + SDVariable sdx = sameDiff.var(x); + + SDVariable res = sameDiff.linalg().logdet("logdet", sdx); + assertEquals("logdet", res.name()); + } + + @Test + public void testQrNames() { + INDArray input = Nd4j.createFromArray(new double[]{ + 12., -51., 4., + 6., 167., -68., + -4., 24., -41., + -1., 1., 0., + 2., 0., 3. + }).reshape(5,3); + + SDVariable sdInput = sameDiff.var(input); + SDVariable[] res = sameDiff.linalg().qr(new String[]{"ret0", "ret1"}, sdInput); + + assertEquals("ret0", res[0].name()); + assertEquals("ret1", res[1].name()); + } +}