Fix incompatibilities with generated code (#303)
* Cholesky fixed * Constructors added * MatMul wrapper * Constructor added * Missing wrappers added * Generate Linalg namespace added * Output data types * Unit tests * Added mmul * Code generation * Code generated * Build fixed * Fixing signatures * Tests fixed * Tests fixed * Added enum * Fix tests * Some fixes * Eye test fixed * SameDiff: small fix for renameVariable - also replace variable name in lossVariable list if necessary Signed-off-by: Alex Black <blacka101@gmail.com> * Some fixes * Tests fixed * Revert wrong fix * Some fixes * Some fixes * Extending base test class * Added pad * Fixed for generated signatures * Fixes due to nd4j codegen * Backwards compatibility fixes * Fixed errors in tests, reverted wrong changes * Test fixed * Added missing operations used for nd4s operators * Compilation fixed * Added meshgrid * Fixed constructors * fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fix bad commit (incorrectly reverted change from master) Signed-off-by: Alex Black <blacka101@gmail.com> * Fixed test Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
1d004b542a
commit
0a27e9f41d
|
@ -130,14 +130,6 @@ public class SameDiffConv extends SameDiffLayer {
|
|||
|
||||
SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
|
||||
|
||||
SDVariable[] vars;
|
||||
if(hasBias){
|
||||
SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
|
||||
vars = new SDVariable[]{layerInput, w, b};
|
||||
} else {
|
||||
vars = new SDVariable[]{layerInput, w};
|
||||
}
|
||||
|
||||
Conv2DConfig c = Conv2DConfig.builder()
|
||||
.kH(kernel[0]).kW(kernel[1])
|
||||
.pH(padding[0]).pW(padding[1])
|
||||
|
@ -146,7 +138,13 @@ public class SameDiffConv extends SameDiffLayer {
|
|||
.isSameMode(this.cm == ConvolutionMode.Same)
|
||||
.build();
|
||||
|
||||
SDVariable conv = sameDiff.cnn().conv2d(vars, c); //TODO can't set name
|
||||
SDVariable conv = null;
|
||||
if(hasBias){
|
||||
SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
|
||||
conv = sameDiff.cnn().conv2d(layerInput, w, b, c);
|
||||
} else {
|
||||
conv = sameDiff.cnn().conv2d(layerInput, w, c);
|
||||
}
|
||||
|
||||
return activation.asSameDiff("out", sameDiff, conv);
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
@ -99,15 +100,15 @@ public class CapsuleLayer extends SameDiffLayer {
|
|||
}
|
||||
|
||||
@Override
|
||||
public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
public SDVariable defineLayer(SameDiff sd, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
|
||||
// input: [mb, inputCapsules, inputCapsuleDimensions]
|
||||
|
||||
// [mb, inputCapsules, 1, inputCapsuleDimensions, 1]
|
||||
SDVariable expanded = SD.expandDims(SD.expandDims(input, 2), 4);
|
||||
SDVariable expanded = sd.expandDims(sd.expandDims(input, 2), 4);
|
||||
|
||||
// [mb, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1]
|
||||
SDVariable tiled = SD.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1);
|
||||
SDVariable tiled = sd.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1);
|
||||
|
||||
// [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions]
|
||||
SDVariable weights = paramTable.get(WEIGHT_PARAM);
|
||||
|
@ -119,13 +120,13 @@ public class CapsuleLayer extends SameDiffLayer {
|
|||
|
||||
// b is the logits of the routing procedure
|
||||
// [mb, inputCapsules, capsules, 1, 1]
|
||||
SDVariable b = SD.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));
|
||||
SDVariable b = sd.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));
|
||||
|
||||
for(int i = 0 ; i < routings ; i++){
|
||||
|
||||
// c is the coupling coefficient, i.e. the edge weight between the 2 capsules
|
||||
// [mb, inputCapsules, capsules, 1, 1]
|
||||
SDVariable c = CapsuleUtils.softmax(SD, b, 2, 5);
|
||||
SDVariable c = sd.nn.softmax(b, 2);
|
||||
|
||||
// [mb, 1, capsules, capsuleDimensions, 1]
|
||||
SDVariable s = c.times(uHat).sum(true, 1);
|
||||
|
@ -135,14 +136,14 @@ public class CapsuleLayer extends SameDiffLayer {
|
|||
|
||||
// v is the per capsule activations. On the last routing iteration, this is output
|
||||
// [mb, 1, capsules, capsuleDimensions, 1]
|
||||
SDVariable v = CapsuleUtils.squash(SD, s, 3);
|
||||
SDVariable v = CapsuleUtils.squash(sd, s, 3);
|
||||
|
||||
if(i == routings - 1){
|
||||
return SD.squeeze(SD.squeeze(v, 1), 3);
|
||||
return sd.squeeze(sd.squeeze(v, 1), 3);
|
||||
}
|
||||
|
||||
// [mb, inputCapsules, capsules, capsuleDimensions, 1]
|
||||
SDVariable vTiled = SD.tile(v, 1, (int) inputCapsules, 1, 1, 1);
|
||||
SDVariable vTiled = sd.tile(v, 1, (int) inputCapsules, 1, 1, 1);
|
||||
|
||||
// [mb, inputCapsules, capsules, 1, 1]
|
||||
b = b.plus(uHat.times(vTiled).sum(true, 3));
|
||||
|
|
|
@ -178,9 +178,11 @@ public class LocallyConnected1D extends SameDiffLayer {
|
|||
//Note: for same mode, bottom/right padding can be 1 more than top/left padding
|
||||
//NCW format.
|
||||
if(cm == ConvolutionMode.Same) {
|
||||
layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, paddingR}}, 0);
|
||||
layerInput = sameDiff.nn().pad(layerInput,
|
||||
sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), 0);
|
||||
} else {
|
||||
layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, padding}}, 0);
|
||||
layerInput = sameDiff.nn().pad(layerInput,
|
||||
sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -184,9 +184,11 @@ public class LocallyConnected2D extends SameDiffLayer {
|
|||
//Note: for same mode, bottom/right padding can be 1 more than top/left padding
|
||||
//NCHW format
|
||||
if(cm == ConvolutionMode.Same){
|
||||
layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}}, 0);
|
||||
layerInput = sameDiff.nn().pad(layerInput,
|
||||
sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), 0.0);
|
||||
} else {
|
||||
layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}}, 0);
|
||||
layerInput = sameDiff.nn().pad(layerInput,
|
||||
sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -45,15 +45,4 @@ public class CapsuleUtils {
|
|||
return x.times(squaredNorm).div(squaredNorm.plus(1.0).times(scale));
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute softmax along a given dimension
|
||||
*/
|
||||
public static SDVariable softmax(SameDiff SD, SDVariable x, int dimension, int rank){
|
||||
int[] permutation = ArrayUtil.range(0, rank);
|
||||
permutation[0] = dimension;
|
||||
permutation[dimension] = 0;
|
||||
|
||||
return SD.nn.softmax(x.permute(permutation)).permute(ArrayUtil.invertPermutation(permutation));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -495,7 +495,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
|
|||
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28);
|
||||
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10));
|
||||
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10));
|
||||
SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
|
||||
SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b), -1);
|
||||
|
||||
val server = new JsonModelServer.Builder<float[], Integer>(sd)
|
||||
.outputSerializer( new IntSerde())
|
||||
|
|
|
@ -58,7 +58,7 @@ public class TestSameDiffUI extends BaseDL4JTest {
|
|||
SDVariable b = sd.var("b", DataType.FLOAT, 1, 4);
|
||||
|
||||
SDVariable z = in.mmul(w).add(b);
|
||||
SDVariable a = sd.nn().tanh(z);
|
||||
SDVariable a = sd.math().tanh(z);
|
||||
|
||||
LogFileWriter lfw = new LogFileWriter(f);
|
||||
lfw.writeGraphStructure(sd);
|
||||
|
|
|
@ -20,6 +20,7 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
|||
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||||
import org.deeplearning4j.integration.ModelType;
|
||||
import org.deeplearning4j.integration.TestCase;
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.commons.lang3.ArrayUtils;
|
|||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.enums.DataFormat;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -1489,7 +1490,7 @@ public class DifferentialFunctionFactory {
|
|||
}
|
||||
|
||||
public SDVariable reciprocal(SDVariable a) {
|
||||
return new Reciprocal(sameDiff(), a, false).outputVariable();
|
||||
return new Reciprocal(sameDiff(), a).outputVariable();
|
||||
}
|
||||
|
||||
|
||||
|
@ -1990,13 +1991,13 @@ public class DifferentialFunctionFactory {
|
|||
.outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, String dataFormat) {
|
||||
public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) {
|
||||
validateDifferentialFunctionsameDiff(differentialFunction);
|
||||
return new DepthToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat)
|
||||
.outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, String dataFormat) {
|
||||
public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) {
|
||||
validateDifferentialFunctionsameDiff(differentialFunction);
|
||||
return new SpaceToDepth(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat)
|
||||
.outputVariable();
|
||||
|
@ -2635,7 +2636,7 @@ public class DifferentialFunctionFactory {
|
|||
return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable[] maxPoolWithArgmaxs(SDVariable x, Pooling2DConfig pooling2DConfig) {
|
||||
public SDVariable[] maxPoolWithArgmax(SDVariable x, Pooling2DConfig pooling2DConfig) {
|
||||
return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables();
|
||||
}
|
||||
|
||||
|
|
|
@ -181,6 +181,11 @@ public class SameDiff extends SDBaseOps {
|
|||
*/
|
||||
public final SDBitwise bitwise = new SDBitwise(this);
|
||||
|
||||
/**
|
||||
* Op creator object for linalg operations
|
||||
*/
|
||||
public final SDLinalg linalg = new SDLinalg(this);
|
||||
|
||||
/**
|
||||
* Op creator object for math operations
|
||||
*/
|
||||
|
@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps {
|
|||
return bitwise;
|
||||
}
|
||||
|
||||
/**
|
||||
* Op creator object for linalg operations
|
||||
*/
|
||||
public SDLinalg linalg(){
|
||||
return linalg;
|
||||
}
|
||||
|
||||
private Map<String, SameDiff> sameDiffFunctionInstances;
|
||||
|
||||
private Table<String, String, String> fieldVariableResolutionMapping;
|
||||
|
@ -3448,6 +3460,12 @@ public class SameDiff extends SDBaseOps {
|
|||
sd.renameVariable(from, to);
|
||||
}
|
||||
}
|
||||
|
||||
//Check losses:
|
||||
if(lossVariables.contains(from)){
|
||||
int idx = lossVariables.indexOf(from);
|
||||
lossVariables.set(idx, to);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,217 +1,416 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import lombok.NonNull;
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
|
||||
public class SDBitwise extends SDOps {
|
||||
public SDBitwise(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
public class SDBitwise extends SDOps {
|
||||
public SDBitwise(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
/**
|
||||
* Bitwise AND operation. Supports broadcasting.<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be same types: isSameType(x, y)<br>
|
||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||
*
|
||||
* @param x First input array (INT type)
|
||||
* @param y Second input array (INT type)
|
||||
* @return output Bitwise AND array (INT type)
|
||||
*/
|
||||
public SDVariable and(SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("and", "x", x);
|
||||
SDValidation.validateInteger("and", "y", y);
|
||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #leftShift(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){
|
||||
return leftShift(null, x, y);
|
||||
}
|
||||
/**
|
||||
* Bitwise AND operation. Supports broadcasting.<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be same types: isSameType(x, y)<br>
|
||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x First input array (INT type)
|
||||
* @param y Second input array (INT type)
|
||||
* @return output Bitwise AND array (INT type)
|
||||
*/
|
||||
public SDVariable and(String name, SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("and", "x", x);
|
||||
SDValidation.validateInteger("and", "y", y);
|
||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise left shift operation. Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x Input to be bit shifted (must be an integer type)
|
||||
* @param y Amount to shift elements of x array (must be an integer type)
|
||||
* @return Bitwise shifted input x
|
||||
*/
|
||||
public SDVariable leftShift(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise left shift", x);
|
||||
validateInteger("bitwise left shift", y);
|
||||
/**
|
||||
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br>
|
||||
*
|
||||
* @param x Input 1 (INT type)
|
||||
* @param shift Number of bits to shift. (INT type)
|
||||
* @return output SDVariable with shifted bits (INT type)
|
||||
*/
|
||||
public SDVariable bitRotl(SDVariable x, SDVariable shift) {
|
||||
SDValidation.validateInteger("bitRotl", "x", x);
|
||||
SDValidation.validateInteger("bitRotl", "shift", shift);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable();
|
||||
}
|
||||
|
||||
SDVariable ret = f().shift(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input 1 (INT type)
|
||||
* @param shift Number of bits to shift. (INT type)
|
||||
* @return output SDVariable with shifted bits (INT type)
|
||||
*/
|
||||
public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) {
|
||||
SDValidation.validateInteger("bitRotl", "x", x);
|
||||
SDValidation.validateInteger("bitRotl", "shift", shift);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #rightShift(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable rightShift(SDVariable x, SDVariable y){
|
||||
return rightShift(null, x, y);
|
||||
}
|
||||
/**
|
||||
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br>
|
||||
*
|
||||
* @param x Input 1 (INT type)
|
||||
* @param shift Number of bits to shift. (INT type)
|
||||
* @return output SDVariable with shifted bits (INT type)
|
||||
*/
|
||||
public SDVariable bitRotr(SDVariable x, SDVariable shift) {
|
||||
SDValidation.validateInteger("bitRotr", "x", x);
|
||||
SDValidation.validateInteger("bitRotr", "shift", shift);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise right shift operation. Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x Input to be bit shifted (must be an integer type)
|
||||
* @param y Amount to shift elements of x array (must be an integer type)
|
||||
* @return Bitwise shifted input x
|
||||
*/
|
||||
public SDVariable rightShift(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise right shift", x);
|
||||
validateInteger("bitwise right shift", y);
|
||||
/**
|
||||
* Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input 1 (INT type)
|
||||
* @param shift Number of bits to shift. (INT type)
|
||||
* @return output SDVariable with shifted bits (INT type)
|
||||
*/
|
||||
public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) {
|
||||
SDValidation.validateInteger("bitRotr", "x", x);
|
||||
SDValidation.validateInteger("bitRotr", "shift", shift);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
SDVariable ret = f().rshift(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Shift integer bits to the left, i.e. var << 4<br>
|
||||
*
|
||||
* @param x Input 1 (INT type)
|
||||
* @param shift Number of bits to shift. (INT type)
|
||||
* @return output SDVariable with shifted bits (INT type)
|
||||
*/
|
||||
public SDVariable bitShift(SDVariable x, SDVariable shift) {
|
||||
SDValidation.validateInteger("bitShift", "x", x);
|
||||
SDValidation.validateInteger("bitShift", "shift", shift);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #leftShiftCyclic(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){
|
||||
return leftShiftCyclic(null, x, y);
|
||||
}
|
||||
/**
|
||||
* Shift integer bits to the left, i.e. var << 4<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input 1 (INT type)
|
||||
* @param shift Number of bits to shift. (INT type)
|
||||
* @return output SDVariable with shifted bits (INT type)
|
||||
*/
|
||||
public SDVariable bitShift(String name, SDVariable x, SDVariable shift) {
|
||||
SDValidation.validateInteger("bitShift", "x", x);
|
||||
SDValidation.validateInteger("bitShift", "shift", shift);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise left cyclical shift operation. Supports broadcasting.
|
||||
* Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around":
|
||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x Input to be bit shifted (must be an integer type)
|
||||
* @param y Amount to shift elements of x array (must be an integer type)
|
||||
* @return Bitwise cyclic shifted input x
|
||||
*/
|
||||
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise left shift (cyclic)", x);
|
||||
validateInteger("bitwise left shift (cyclic)", y);
|
||||
/**
|
||||
* Shift integer bits to the right, i.e. var >> 4<br>
|
||||
*
|
||||
* @param x Input 1 (INT type)
|
||||
* @param shift Number of bits to shift. (INT type)
|
||||
* @return output SDVariable with shifted bits (INT type)
|
||||
*/
|
||||
public SDVariable bitShiftRight(SDVariable x, SDVariable shift) {
|
||||
SDValidation.validateInteger("bitShiftRight", "x", x);
|
||||
SDValidation.validateInteger("bitShiftRight", "shift", shift);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable();
|
||||
}
|
||||
|
||||
SDVariable ret = f().rotl(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Shift integer bits to the right, i.e. var >> 4<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input 1 (INT type)
|
||||
* @param shift Number of bits to shift. (INT type)
|
||||
* @return output SDVariable with shifted bits (INT type)
|
||||
*/
|
||||
public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) {
|
||||
SDValidation.validateInteger("bitShiftRight", "x", x);
|
||||
SDValidation.validateInteger("bitShiftRight", "shift", shift);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #rightShiftCyclic(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){
|
||||
return rightShiftCyclic(null, x, y);
|
||||
}
|
||||
/**
|
||||
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
|
||||
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be same types: isSameType(x, y)<br>
|
||||
*
|
||||
* @param x First input array. (INT type)
|
||||
* @param y Second input array. (INT type)
|
||||
* @return output bitwise Hamming distance (INT type)
|
||||
*/
|
||||
public SDVariable bitsHammingDistance(SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("bitsHammingDistance", "x", x);
|
||||
SDValidation.validateInteger("bitsHammingDistance", "y", y);
|
||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise right cyclical shift operation. Supports broadcasting.
|
||||
* Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around":
|
||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x Input to be bit shifted (must be an integer type)
|
||||
* @param y Amount to shift elements of x array (must be an integer type)
|
||||
* @return Bitwise cyclic shifted input x
|
||||
*/
|
||||
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise right shift (cyclic)", x);
|
||||
validateInteger("bitwise right shift (cyclic)", y);
|
||||
/**
|
||||
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
|
||||
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be same types: isSameType(x, y)<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x First input array. (INT type)
|
||||
* @param y Second input array. (INT type)
|
||||
* @return output bitwise Hamming distance (INT type)
|
||||
*/
|
||||
public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("bitsHammingDistance", "x", x);
|
||||
SDValidation.validateInteger("bitsHammingDistance", "y", y);
|
||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
SDVariable ret = f().rotr(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Bitwise left shift operation. Supports broadcasting.<br>
|
||||
*
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
* @param y Amount to shift elements of x array (INT type)
|
||||
* @return output Bitwise shifted input x (INT type)
|
||||
*/
|
||||
public SDVariable leftShift(SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("leftShift", "x", x);
|
||||
SDValidation.validateInteger("leftShift", "y", y);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #bitsHammingDistance(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){
|
||||
return bitsHammingDistance(null, x, y);
|
||||
}
|
||||
/**
|
||||
* Bitwise left shift operation. Supports broadcasting.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
* @param y Amount to shift elements of x array (INT type)
|
||||
* @return output Bitwise shifted input x (INT type)
|
||||
*/
|
||||
public SDVariable leftShift(String name, SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("leftShift", "x", x);
|
||||
SDValidation.validateInteger("leftShift", "y", y);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise Hamming distance reduction over all elements of both input arrays.<br>
|
||||
* For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x First input array. Must be integer type.
|
||||
* @param y First input array. Must be integer type, same type as x
|
||||
* @return
|
||||
*/
|
||||
public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise hamming distance", x);
|
||||
validateInteger("bitwise hamming distance", y);
|
||||
/**
|
||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||
*
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
* @param y Amount to shift elements of x array (INT type)
|
||||
* @return output Bitwise cyclic shifted input x (INT type)
|
||||
*/
|
||||
public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("leftShiftCyclic", "x", x);
|
||||
SDValidation.validateInteger("leftShiftCyclic", "y", y);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable();
|
||||
}
|
||||
|
||||
SDVariable ret = f().bitwiseHammingDist(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
* @param y Amount to shift elements of x array (INT type)
|
||||
* @return output Bitwise cyclic shifted input x (INT type)
|
||||
*/
|
||||
public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("leftShiftCyclic", "x", x);
|
||||
SDValidation.validateInteger("leftShiftCyclic", "y", y);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #and(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable and(SDVariable x, SDVariable y){
|
||||
return and(null, x, y);
|
||||
}
|
||||
/**
|
||||
* Bitwise OR operation. Supports broadcasting.<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be same types: isSameType(x, y)<br>
|
||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||
*
|
||||
* @param x First input array (INT type)
|
||||
* @param y First input array (INT type)
|
||||
* @return output Bitwise OR array (INT type)
|
||||
*/
|
||||
public SDVariable or(SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("or", "x", x);
|
||||
SDValidation.validateInteger("or", "y", y);
|
||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise AND operation. Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x First input array. Must be integer type.
|
||||
* @param y First input array. Must be integer type, same type as x
|
||||
* @return Bitwise AND array
|
||||
*/
|
||||
public SDVariable and(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise AND", x);
|
||||
validateInteger("bitwise AND", y);
|
||||
/**
|
||||
* Bitwise OR operation. Supports broadcasting.<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be same types: isSameType(x, y)<br>
|
||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x First input array (INT type)
|
||||
* @param y First input array (INT type)
|
||||
* @return output Bitwise OR array (INT type)
|
||||
*/
|
||||
public SDVariable or(String name, SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("or", "x", x);
|
||||
SDValidation.validateInteger("or", "y", y);
|
||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
SDVariable ret = f().bitwiseAnd(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Bitwise right shift operation. Supports broadcasting. <br>
|
||||
*
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
* @param y Amount to shift elements of x array (INT type)
|
||||
* @return output Bitwise shifted input x (INT type)
|
||||
*/
|
||||
public SDVariable rightShift(SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("rightShift", "x", x);
|
||||
SDValidation.validateInteger("rightShift", "y", y);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #or(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable or(SDVariable x, SDVariable y){
|
||||
return or(null, x, y);
|
||||
}
|
||||
/**
|
||||
* Bitwise right shift operation. Supports broadcasting. <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
* @param y Amount to shift elements of x array (INT type)
|
||||
* @return output Bitwise shifted input x (INT type)
|
||||
*/
|
||||
public SDVariable rightShift(String name, SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("rightShift", "x", x);
|
||||
SDValidation.validateInteger("rightShift", "y", y);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise OR operation. Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x First input array. Must be integer type.
|
||||
* @param y First input array. Must be integer type, same type as x
|
||||
* @return Bitwise OR array
|
||||
*/
|
||||
public SDVariable or(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise OR", x);
|
||||
validateInteger("bitwise OR", y);
|
||||
/**
|
||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||
*
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
* @param y Amount to shift elements of x array (INT type)
|
||||
* @return output Bitwise cyclic shifted input x (INT type)
|
||||
*/
|
||||
public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("rightShiftCyclic", "x", x);
|
||||
SDValidation.validateInteger("rightShiftCyclic", "y", y);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable();
|
||||
}
|
||||
|
||||
SDVariable ret = f().bitwiseOr(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
* @param y Amount to shift elements of x array (INT type)
|
||||
* @return output Bitwise cyclic shifted input x (INT type)
|
||||
*/
|
||||
public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("rightShiftCyclic", "x", x);
|
||||
SDValidation.validateInteger("rightShiftCyclic", "y", y);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #xor(String, SDVariable, SDVariable)}
|
||||
*/
|
||||
public SDVariable xor(SDVariable x, SDVariable y){
|
||||
return xor(null, x, y);
|
||||
}
|
||||
/**
|
||||
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be same types: isSameType(x, y)<br>
|
||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||
*
|
||||
* @param x First input array (INT type)
|
||||
* @param y First input array (INT type)
|
||||
* @return output Bitwise XOR array (INT type)
|
||||
*/
|
||||
public SDVariable xor(SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("xor", "x", x);
|
||||
SDValidation.validateInteger("xor", "y", y);
|
||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Bitwise XOR operation (exclusive OR). Supports broadcasting.
|
||||
*
|
||||
* @param name Name of the output variable. May be null.
|
||||
* @param x First input array. Must be integer type.
|
||||
* @param y First input array. Must be integer type, same type as x
|
||||
* @return Bitwise XOR array
|
||||
*/
|
||||
public SDVariable xor(String name, SDVariable x, SDVariable y){
|
||||
validateInteger("bitwise XOR", x);
|
||||
validateInteger("bitwise XOR", y);
|
||||
|
||||
SDVariable ret = f().bitwiseXor(x, y);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Flip bits
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param x input array
|
||||
* @return array after flipping each input bit
|
||||
*/
|
||||
public SDVariable toggleBits(String name, SDVariable x) {
|
||||
SDVariable res = f().toggleBits(x);
|
||||
return updateVariableNameAndReference(res, name);
|
||||
}
|
||||
/**
|
||||
* Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be same types: isSameType(x, y)<br>
|
||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x First input array (INT type)
|
||||
* @param y First input array (INT type)
|
||||
* @return output Bitwise XOR array (INT type)
|
||||
*/
|
||||
public SDVariable xor(String name, SDVariable x, SDVariable y) {
|
||||
SDValidation.validateInteger("xor", "x", x);
|
||||
SDValidation.validateInteger("xor", "y", y);
|
||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,185 +1,440 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import lombok.NonNull;
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.custom.*;
|
||||
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
|
||||
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
||||
/**
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class SDImage extends SDOps {
|
||||
public SDImage(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
public SDImage(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.
|
||||
*
|
||||
* @param name May be null. Name for the output variable.
|
||||
* @param image Input image, with shape [batch, height, width, channels]
|
||||
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1
|
||||
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes]
|
||||
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth]
|
||||
* @param method Image resize method
|
||||
* @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default
|
||||
* @return Cropped and resized images
|
||||
*/
|
||||
public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, SDVariable cropOutSize,
|
||||
CropAndResize.Method method, double extrapolationValue) {
|
||||
SDVariable out = new CropAndResize(sd, image, cropBoxes, boxIndices, cropOutSize, method, extrapolationValue).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
|
||||
*
|
||||
* @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
|
||||
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
|
||||
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
|
||||
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
|
||||
* @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default
|
||||
* @return output Cropped and resized images (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cropAndResize(SDVariable image, SDVariable cropBoxes, SDVariable boxIndices,
|
||||
SDVariable cropOutSize, double extrapolationValue) {
|
||||
SDValidation.validateNumerical("CropAndResize", "image", image);
|
||||
SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
|
||||
SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
|
||||
SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
|
||||
return new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, extrapolationValue).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension.
|
||||
*
|
||||
* @param name Map be null. Name for the output variable
|
||||
* @param image Input image to extract image patches from - shape [batch, height, width, channels]
|
||||
* @param kSizes Kernel size - size of the image patches, [height, width]
|
||||
* @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width]
|
||||
* @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels
|
||||
* in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken
|
||||
* along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension
|
||||
* @param sameMode Padding algorithm. If true: use Same padding
|
||||
* @return The extracted image patches
|
||||
*/
|
||||
public SDVariable extractImagePatches(String name, SDVariable image, @NonNull int[] kSizes,
|
||||
@NonNull int[] strides, @NonNull int[] rates, boolean sameMode) {
|
||||
SDVariable out = new ExtractImagePatches(sd, image, kSizes, strides, rates, sameMode).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
|
||||
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
|
||||
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
|
||||
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
|
||||
* @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default
|
||||
* @return output Cropped and resized images (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes,
|
||||
SDVariable boxIndices, SDVariable cropOutSize, double extrapolationValue) {
|
||||
SDValidation.validateNumerical("CropAndResize", "image", image);
|
||||
SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
|
||||
SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
|
||||
SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, extrapolationValue).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Greedily selects a subset of bounding boxes in descending order of score
|
||||
* @param name Might be null. Name for the output variable
|
||||
* @param boxes 2D array of shape [num_boxes,4]
|
||||
* @param scores vector of shape [num_boxes]
|
||||
* @param maxOutSize scalar representing the maximum number of boxes to be selected
|
||||
* @param iouThreshold float - threshold for deciding whether boxes overlap too much with respect to IOU
|
||||
* @param scoreThreshold float - threshold for deciding when to remove boxes based on score
|
||||
* @return vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size
|
||||
*/
|
||||
public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
|
||||
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){
|
||||
SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
|
||||
*
|
||||
* @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
|
||||
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
|
||||
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
|
||||
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
|
||||
* @return output Cropped and resized images (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cropAndResize(SDVariable image, SDVariable cropBoxes, SDVariable boxIndices,
|
||||
SDVariable cropOutSize) {
|
||||
SDValidation.validateNumerical("CropAndResize", "image", image);
|
||||
SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
|
||||
SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
|
||||
SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
|
||||
return new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, 0.0).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjusts contrast of RGB or grayscale images.
|
||||
* @param name name for the output variable
|
||||
* @param in images to adjust. 3D shape or higher.
|
||||
* @param factor float multiplier for adjusting contrast.
|
||||
* @return Contrast-adjusted image
|
||||
*/
|
||||
public SDVariable adjustContrast(String name, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
||||
SDVariable out = new AdjustContrast(sd, in, factor).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
|
||||
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
|
||||
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
|
||||
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
|
||||
* @return output Cropped and resized images (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes,
|
||||
SDVariable boxIndices, SDVariable cropOutSize) {
|
||||
SDValidation.validateNumerical("CropAndResize", "image", image);
|
||||
SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
|
||||
SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
|
||||
SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, 0.0).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjust saturation of RGB images
|
||||
* @param name name for the output variable
|
||||
* @param in RGB image as 3D array
|
||||
* @param factor factor for saturation
|
||||
* @return adjusted image
|
||||
*/
|
||||
public SDVariable adjustSaturation(String name, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
||||
SDVariable out = new AdjustSaturation(sd, in, factor).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Adjusts contrast of RGB or grayscale images.<br>
|
||||
*
|
||||
* @param in images to adjust. 3D shape or higher (NUMERIC type)
|
||||
* @param factor multiplier for adjusting contrast
|
||||
* @return output Contrast-adjusted image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable adjustContrast(SDVariable in, double factor) {
|
||||
SDValidation.validateNumerical("adjustContrast", "in", in);
|
||||
return new org.nd4j.linalg.api.ops.custom.AdjustContrast(sd,in, factor).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Adjust hue of RGB image
|
||||
* @param name name for the output variable
|
||||
* @param in RGB image as 3D array
|
||||
* @param delta value to add to hue channel
|
||||
* @return adjusted image
|
||||
*/
|
||||
public SDVariable adjustHue(String name, @NonNull SDVariable in, @NonNull SDVariable delta) {
|
||||
SDVariable out = new AdjustHue(sd, in, delta).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Adjusts contrast of RGB or grayscale images.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param in images to adjust. 3D shape or higher (NUMERIC type)
|
||||
* @param factor multiplier for adjusting contrast
|
||||
* @return output Contrast-adjusted image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable adjustContrast(String name, SDVariable in, double factor) {
|
||||
SDValidation.validateNumerical("adjustContrast", "in", in);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustContrast(sd,in, factor).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Randomly crops image
|
||||
* @param name name for the output variable
|
||||
* @param input input array
|
||||
* @param shape shape for crop
|
||||
* @return cropped array
|
||||
*/
|
||||
public SDVariable randomCrop(String name, @NonNull SDVariable input, @NonNull SDVariable shape) {
|
||||
SDVariable out = new RandomCrop(sd, input, shape).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Adjust hue of RGB image <br>
|
||||
*
|
||||
* @param in image as 3D array (NUMERIC type)
|
||||
* @param delta value to add to hue channel
|
||||
* @return output adjusted image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable adjustHue(SDVariable in, double delta) {
|
||||
SDValidation.validateNumerical("adjustHue", "in", in);
|
||||
return new org.nd4j.linalg.api.ops.custom.AdjustHue(sd,in, delta).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from HSV to RGB format
|
||||
* @param name name
|
||||
* @param input 3D image
|
||||
* @return 3D image
|
||||
*/
|
||||
public SDVariable rgbToHsv(String name, @NonNull SDVariable input) {
|
||||
SDVariable out = new RgbToHsv(sd, input).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Adjust hue of RGB image <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param in image as 3D array (NUMERIC type)
|
||||
* @param delta value to add to hue channel
|
||||
* @return output adjusted image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable adjustHue(String name, SDVariable in, double delta) {
|
||||
SDValidation.validateNumerical("adjustHue", "in", in);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustHue(sd,in, delta).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting image from HSV to RGB format
|
||||
* @param name name
|
||||
* @param input 3D image
|
||||
* @return 3D image
|
||||
*/
|
||||
public SDVariable hsvToRgb(String name, @NonNull SDVariable input) {
|
||||
SDVariable out = new HsvToRgb(sd, input).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Adjust saturation of RGB images<br>
|
||||
*
|
||||
* @param in RGB image as 3D array (NUMERIC type)
|
||||
* @param factor factor for saturation
|
||||
* @return output adjusted image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable adjustSaturation(SDVariable in, double factor) {
|
||||
SDValidation.validateNumerical("adjustSaturation", "in", in);
|
||||
return new org.nd4j.linalg.api.ops.custom.AdjustSaturation(sd,in, factor).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from RGB to YIQ format
|
||||
* @param name name
|
||||
* @param input 3D image
|
||||
* @return 3D image
|
||||
*/
|
||||
public SDVariable rgbToYiq(String name, @NonNull SDVariable input) {
|
||||
SDVariable out = new RgbToYiq(sd, input).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Adjust saturation of RGB images<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param in RGB image as 3D array (NUMERIC type)
|
||||
* @param factor factor for saturation
|
||||
* @return output adjusted image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable adjustSaturation(String name, SDVariable in, double factor) {
|
||||
SDValidation.validateNumerical("adjustSaturation", "in", in);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustSaturation(sd,in, factor).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting image from YIQ to RGB format
|
||||
* @param name name
|
||||
* @param input 3D image
|
||||
* @return 3D image
|
||||
*/
|
||||
public SDVariable yiqToRgb(String name, @NonNull SDVariable input) {
|
||||
SDVariable out = new YiqToRgb(sd, input).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension. <br>
|
||||
*
|
||||
* @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type)
|
||||
* @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2))
|
||||
* @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2))
|
||||
* @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels
|
||||
* in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken
|
||||
* along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension (Size: AtLeast(min=0))
|
||||
* @param sameMode Padding algorithm. If true: use Same padding
|
||||
* @return output The extracted image patches (NUMERIC type)
|
||||
*/
|
||||
public SDVariable extractImagePatches(SDVariable image, int[] kSizes, int[] strides, int[] rates,
|
||||
boolean sameMode) {
|
||||
SDValidation.validateNumerical("extractImagePatches", "image", image);
|
||||
Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length);
|
||||
Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length);
|
||||
Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length);
|
||||
return new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,image, kSizes, strides, rates, sameMode).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from RGB to YUV format
|
||||
* @param name name
|
||||
* @param input 3D image
|
||||
* @return 3D image
|
||||
*/
|
||||
public SDVariable rgbToYuv(String name, @NonNull SDVariable input) {
|
||||
SDVariable out = new RgbToYuv(sd, input).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension. <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type)
|
||||
* @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2))
|
||||
* @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2))
|
||||
* @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels
|
||||
* in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken
|
||||
* along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension (Size: AtLeast(min=0))
|
||||
* @param sameMode Padding algorithm. If true: use Same padding
|
||||
* @return output The extracted image patches (NUMERIC type)
|
||||
*/
|
||||
public SDVariable extractImagePatches(String name, SDVariable image, int[] kSizes, int[] strides,
|
||||
int[] rates, boolean sameMode) {
|
||||
SDValidation.validateNumerical("extractImagePatches", "image", image);
|
||||
Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length);
|
||||
Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length);
|
||||
Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,image, kSizes, strides, rates, sameMode).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting image from YUV to RGB format
|
||||
* @param name name
|
||||
* @param input 3D image
|
||||
* @return 3D image
|
||||
*/
|
||||
public SDVariable yuvToRgb(String name, @NonNull SDVariable input) {
|
||||
SDVariable out = new YuvToRgb(sd, input).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
/**
|
||||
* Converting image from HSV to RGB format <br>
|
||||
*
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable hsvToRgb(SDVariable input) {
|
||||
SDValidation.validateNumerical("hsvToRgb", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.custom.HsvToRgb(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting image from HSV to RGB format <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable hsvToRgb(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("hsvToRgb", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.HsvToRgb(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Greedily selects a subset of bounding boxes in descending order of score<br>
|
||||
*
|
||||
* @param boxes Might be null. Name for the output variable (NUMERIC type)
|
||||
* @param scores vector of shape [num_boxes] (NUMERIC type)
|
||||
* @param maxOutSize scalar representing the maximum number of boxes to be selected
|
||||
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
|
||||
* @param scoreThreshold threshold for deciding when to remove boxes based on score
|
||||
* @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type)
|
||||
*/
|
||||
public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize,
|
||||
double iouThreshold, double scoreThreshold) {
|
||||
SDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes);
|
||||
SDValidation.validateNumerical("nonMaxSuppression", "scores", scores);
|
||||
return new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(sd,boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Greedily selects a subset of bounding boxes in descending order of score<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param boxes Might be null. Name for the output variable (NUMERIC type)
|
||||
* @param scores vector of shape [num_boxes] (NUMERIC type)
|
||||
* @param maxOutSize scalar representing the maximum number of boxes to be selected
|
||||
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
|
||||
* @param scoreThreshold threshold for deciding when to remove boxes based on score
|
||||
* @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type)
|
||||
*/
|
||||
public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores,
|
||||
int maxOutSize, double iouThreshold, double scoreThreshold) {
|
||||
SDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes);
|
||||
SDValidation.validateNumerical("nonMaxSuppression", "scores", scores);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(sd,boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Randomly crops image<br>
|
||||
*
|
||||
* @param input input array (NUMERIC type)
|
||||
* @param shape shape for crop (INT type)
|
||||
* @return output cropped array (NUMERIC type)
|
||||
*/
|
||||
public SDVariable randomCrop(SDVariable input, SDVariable shape) {
|
||||
SDValidation.validateNumerical("randomCrop", "input", input);
|
||||
SDValidation.validateInteger("randomCrop", "shape", shape);
|
||||
return new org.nd4j.linalg.api.ops.custom.RandomCrop(sd,input, shape).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Randomly crops image<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input input array (NUMERIC type)
|
||||
* @param shape shape for crop (INT type)
|
||||
* @return output cropped array (NUMERIC type)
|
||||
*/
|
||||
public SDVariable randomCrop(String name, SDVariable input, SDVariable shape) {
|
||||
SDValidation.validateNumerical("randomCrop", "input", input);
|
||||
SDValidation.validateInteger("randomCrop", "shape", shape);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.RandomCrop(sd,input, shape).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from HSV to RGB format<br>
|
||||
*
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable rgbToHsv(SDVariable input) {
|
||||
SDValidation.validateNumerical("rgbToHsv", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.custom.RgbToHsv(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from HSV to RGB format<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable rgbToHsv(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("rgbToHsv", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToHsv(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from RGB to YIQ format <br>
|
||||
*
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable rgbToYiq(SDVariable input) {
|
||||
SDValidation.validateNumerical("rgbToYiq", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.custom.RgbToYiq(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from RGB to YIQ format <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable rgbToYiq(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("rgbToYiq", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToYiq(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from RGB to YUV format <br>
|
||||
*
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable rgbToYuv(SDVariable input) {
|
||||
SDValidation.validateNumerical("rgbToYuv", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.custom.RgbToYuv(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from RGB to YUV format <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable rgbToYuv(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("rgbToYuv", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToYuv(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting image from YIQ to RGB format <br>
|
||||
*
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable yiqToRgb(SDVariable input) {
|
||||
SDValidation.validateNumerical("yiqToRgb", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.custom.YiqToRgb(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting image from YIQ to RGB format <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable yiqToRgb(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("yiqToRgb", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.YiqToRgb(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting image from YUV to RGB format <br>
|
||||
*
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable yuvToRgb(SDVariable input) {
|
||||
SDValidation.validateNumerical("yuvToRgb", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.custom.YuvToRgb(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting image from YUV to RGB format <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input 3D image (NUMERIC type)
|
||||
* @return output 3D image (NUMERIC type)
|
||||
*/
|
||||
public SDVariable yuvToRgb(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("yuvToRgb", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.YuvToRgb(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,561 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
||||
public class SDLinalg extends SDOps {
|
||||
public SDLinalg(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the Cholesky decomposition of one or more square matrices.<br>
|
||||
*
|
||||
* @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type)
|
||||
* @return output Transformed tensor (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cholesky(SDVariable input) {
|
||||
SDValidation.validateNumerical("Cholesky", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the Cholesky decomposition of one or more square matrices.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type)
|
||||
* @return output Transformed tensor (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cholesky(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("Cholesky", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for linear squares problems.<br>
|
||||
*
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @param l2_reguralizer regularizer
|
||||
* @param fast fast mode, defaults to True
|
||||
* @return output Transformed tensor (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable lstsq(SDVariable matrix, SDVariable rhs, double l2_reguralizer, boolean fast) {
|
||||
SDValidation.validateNumerical("Lstsq", "matrix", matrix);
|
||||
SDValidation.validateNumerical("Lstsq", "rhs", rhs);
|
||||
return new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, fast).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for linear squares problems.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @param l2_reguralizer regularizer
|
||||
* @param fast fast mode, defaults to True
|
||||
* @return output Transformed tensor (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable lstsq(String name, SDVariable matrix, SDVariable rhs, double l2_reguralizer,
|
||||
boolean fast) {
|
||||
SDValidation.validateNumerical("Lstsq", "matrix", matrix);
|
||||
SDValidation.validateNumerical("Lstsq", "rhs", rhs);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, fast).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for linear squares problems.<br>
|
||||
*
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @param l2_reguralizer regularizer
|
||||
* @return output Transformed tensor (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable lstsq(SDVariable matrix, SDVariable rhs, double l2_reguralizer) {
|
||||
SDValidation.validateNumerical("Lstsq", "matrix", matrix);
|
||||
SDValidation.validateNumerical("Lstsq", "rhs", rhs);
|
||||
return new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, true).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for linear squares problems.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @param l2_reguralizer regularizer
|
||||
* @return output Transformed tensor (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable lstsq(String name, SDVariable matrix, SDVariable rhs, double l2_reguralizer) {
|
||||
SDValidation.validateNumerical("Lstsq", "matrix", matrix);
|
||||
SDValidation.validateNumerical("Lstsq", "rhs", rhs);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, true).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes LU decomposition.<br>
|
||||
*
|
||||
* @param input input tensor (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable lu(SDVariable input) {
|
||||
SDValidation.validateNumerical("Lu", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.custom.Lu(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes LU decomposition.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input input tensor (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable lu(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("Lu", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.Lu(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs matrix mutiplication on input tensors.<br>
|
||||
*
|
||||
* @param a input tensor (NUMERIC type)
|
||||
* @param b input tensor (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable matmul(SDVariable a, SDVariable b) {
|
||||
SDValidation.validateNumerical("Matmul", "a", a);
|
||||
SDValidation.validateNumerical("Matmul", "b", b);
|
||||
return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,a, b).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs matrix mutiplication on input tensors.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param a input tensor (NUMERIC type)
|
||||
* @param b input tensor (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable matmul(String name, SDVariable a, SDVariable b) {
|
||||
SDValidation.validateNumerical("Matmul", "a", a);
|
||||
SDValidation.validateNumerical("Matmul", "b", b);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,a, b).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy a tensor setting outside a central band in each innermost matrix.<br>
|
||||
*
|
||||
* @param input input tensor (NUMERIC type)
|
||||
* @param minLower lower diagonal count
|
||||
* @param maxUpper upper diagonal count
|
||||
*/
|
||||
public SDVariable[] matrixBandPart(SDVariable input, int minLower, int maxUpper) {
|
||||
SDValidation.validateNumerical("MatrixBandPart", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.custom.MatrixBandPart(sd,input, minLower, maxUpper).outputVariables();
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy a tensor setting outside a central band in each innermost matrix.<br>
|
||||
*
|
||||
* @param names names May be null. Arrays of names for the output variables.
|
||||
* @param input input tensor (NUMERIC type)
|
||||
* @param minLower lower diagonal count
|
||||
* @param maxUpper upper diagonal count
|
||||
*/
|
||||
public SDVariable[] matrixBandPart(String[] names, SDVariable input, int minLower, int maxUpper) {
|
||||
SDValidation.validateNumerical("MatrixBandPart", "input", input);
|
||||
SDVariable[] out = new org.nd4j.linalg.api.ops.custom.MatrixBandPart(sd,input, minLower, maxUpper).outputVariables();
|
||||
return sd.updateVariableNamesAndReferences(out, names);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the QR decompositions of input matrix.<br>
|
||||
*
|
||||
* @param input input tensor (NUMERIC type)
|
||||
* @param full full matrices mode
|
||||
*/
|
||||
public SDVariable[] qr(SDVariable input, boolean full) {
|
||||
SDValidation.validateNumerical("Qr", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, full).outputVariables();
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the QR decompositions of input matrix.<br>
|
||||
*
|
||||
* @param names names May be null. Arrays of names for the output variables.
|
||||
* @param input input tensor (NUMERIC type)
|
||||
* @param full full matrices mode
|
||||
*/
|
||||
public SDVariable[] qr(String[] names, SDVariable input, boolean full) {
|
||||
SDValidation.validateNumerical("Qr", "input", input);
|
||||
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, full).outputVariables();
|
||||
return sd.updateVariableNamesAndReferences(out, names);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the QR decompositions of input matrix.<br>
|
||||
*
|
||||
* @param input input tensor (NUMERIC type)
|
||||
*/
|
||||
public SDVariable[] qr(SDVariable input) {
|
||||
SDValidation.validateNumerical("Qr", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, false).outputVariables();
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the QR decompositions of input matrix.<br>
|
||||
*
|
||||
* @param names names May be null. Arrays of names for the output variables.
|
||||
* @param input input tensor (NUMERIC type)
|
||||
*/
|
||||
public SDVariable[] qr(String[] names, SDVariable input) {
|
||||
SDValidation.validateNumerical("Qr", "input", input);
|
||||
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, false).outputVariables();
|
||||
return sd.updateVariableNamesAndReferences(out, names);
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for systems of linear equations.<br>
|
||||
*
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @param adjoint adjoint mode, defaults to False
|
||||
* @return output Output tensor (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable solve(SDVariable matrix, SDVariable rhs, boolean adjoint) {
|
||||
SDValidation.validateNumerical("Solve", "matrix", matrix);
|
||||
SDValidation.validateNumerical("Solve", "rhs", rhs);
|
||||
return new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, adjoint).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for systems of linear equations.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @param adjoint adjoint mode, defaults to False
|
||||
* @return output Output tensor (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable solve(String name, SDVariable matrix, SDVariable rhs, boolean adjoint) {
|
||||
SDValidation.validateNumerical("Solve", "matrix", matrix);
|
||||
SDValidation.validateNumerical("Solve", "rhs", rhs);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, adjoint).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for systems of linear equations.<br>
|
||||
*
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @return output Output tensor (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable solve(SDVariable matrix, SDVariable rhs) {
|
||||
SDValidation.validateNumerical("Solve", "matrix", matrix);
|
||||
SDValidation.validateNumerical("Solve", "rhs", rhs);
|
||||
return new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, false).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for systems of linear equations.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @return output Output tensor (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable solve(String name, SDVariable matrix, SDVariable rhs) {
|
||||
SDValidation.validateNumerical("Solve", "matrix", matrix);
|
||||
SDValidation.validateNumerical("Solve", "rhs", rhs);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, false).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for systems of linear questions.<br>
|
||||
*
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @param lower defines whether innermost matrices in matrix are lower or upper triangular
|
||||
* @param adjoint adjoint mode
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable triangularSolve(SDVariable matrix, SDVariable rhs, boolean lower,
|
||||
boolean adjoint) {
|
||||
SDValidation.validateNumerical("TriangularSolve", "matrix", matrix);
|
||||
SDValidation.validateNumerical("TriangularSolve", "rhs", rhs);
|
||||
return new org.nd4j.linalg.api.ops.custom.TriangularSolve(sd,matrix, rhs, lower, adjoint).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Solver for systems of linear questions.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param matrix input tensor (NUMERIC type)
|
||||
* @param rhs input tensor (NUMERIC type)
|
||||
* @param lower defines whether innermost matrices in matrix are lower or upper triangular
|
||||
* @param adjoint adjoint mode
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable triangularSolve(String name, SDVariable matrix, SDVariable rhs, boolean lower,
|
||||
boolean adjoint) {
|
||||
SDValidation.validateNumerical("TriangularSolve", "matrix", matrix);
|
||||
SDValidation.validateNumerical("TriangularSolve", "rhs", rhs);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.TriangularSolve(sd,matrix, rhs, lower, adjoint).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes pairwise cross product.<br>
|
||||
*
|
||||
* @param a (NUMERIC type)
|
||||
* @param b (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable cross(SDVariable a, SDVariable b) {
|
||||
SDValidation.validateNumerical("cross", "a", a);
|
||||
SDValidation.validateNumerical("cross", "b", b);
|
||||
return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes pairwise cross product.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param a (NUMERIC type)
|
||||
* @param b (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable cross(String name, SDVariable a, SDVariable b) {
|
||||
SDValidation.validateNumerical("cross", "a", a);
|
||||
SDValidation.validateNumerical("cross", "b", b);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates diagonal tensor.<br>
|
||||
*
|
||||
* @param input (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable diag(SDVariable input) {
|
||||
SDValidation.validateNumerical("diag", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates diagonal tensor.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable diag(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("diag", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates diagonal tensor.<br>
|
||||
*
|
||||
* @param input (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable diag_part(SDVariable input) {
|
||||
SDValidation.validateNumerical("diag_part", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates diagonal tensor.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable diag_part(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("diag_part", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates log of determinant.<br>
|
||||
*
|
||||
* @param input (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable logdet(SDVariable input) {
|
||||
SDValidation.validateNumerical("logdet", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.custom.Logdet(sd,input).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates log of determinant.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input (NUMERIC type)
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable logdet(String name, SDVariable input) {
|
||||
SDValidation.validateNumerical("logdet", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.custom.Logdet(sd,input).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix multiplication: out = mmul(x,y)<br>
|
||||
* Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.<br>
|
||||
*
|
||||
* @param x First input variable (NUMERIC type)
|
||||
* @param y Second input variable (NUMERIC type)
|
||||
* @param transposeX Transpose x (first argument)
|
||||
* @param transposeY Transpose y (second argument)
|
||||
* @param transposeZ Transpose result array
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable mmul(SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
|
||||
boolean transposeZ) {
|
||||
SDValidation.validateNumerical("mmul", "x", x);
|
||||
SDValidation.validateNumerical("mmul", "y", y);
|
||||
return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix multiplication: out = mmul(x,y)<br>
|
||||
* Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x First input variable (NUMERIC type)
|
||||
* @param y Second input variable (NUMERIC type)
|
||||
* @param transposeX Transpose x (first argument)
|
||||
* @param transposeY Transpose y (second argument)
|
||||
* @param transposeZ Transpose result array
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable mmul(String name, SDVariable x, SDVariable y, boolean transposeX,
|
||||
boolean transposeY, boolean transposeZ) {
|
||||
SDValidation.validateNumerical("mmul", "x", x);
|
||||
SDValidation.validateNumerical("mmul", "y", y);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix multiplication: out = mmul(x,y)<br>
|
||||
* Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.<br>
|
||||
*
|
||||
* @param x First input variable (NUMERIC type)
|
||||
* @param y Second input variable (NUMERIC type)
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable mmul(SDVariable x, SDVariable y) {
|
||||
SDValidation.validateNumerical("mmul", "x", x);
|
||||
SDValidation.validateNumerical("mmul", "y", y);
|
||||
return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrix multiplication: out = mmul(x,y)<br>
|
||||
* Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x First input variable (NUMERIC type)
|
||||
* @param y Second input variable (NUMERIC type)
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable mmul(String name, SDVariable x, SDVariable y) {
|
||||
SDValidation.validateNumerical("mmul", "x", x);
|
||||
SDValidation.validateNumerical("mmul", "y", y);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates singular value decomposition.<br>
|
||||
*
|
||||
* @param input (NUMERIC type)
|
||||
* @param fullUV
|
||||
* @param computeUV
|
||||
* @param switchNum
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable svd(SDVariable input, boolean fullUV, boolean computeUV, int switchNum) {
|
||||
SDValidation.validateNumerical("svd", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, switchNum).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates singular value decomposition.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input (NUMERIC type)
|
||||
* @param fullUV
|
||||
* @param computeUV
|
||||
* @param switchNum
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable svd(String name, SDVariable input, boolean fullUV, boolean computeUV,
|
||||
int switchNum) {
|
||||
SDValidation.validateNumerical("svd", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, switchNum).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates singular value decomposition.<br>
|
||||
*
|
||||
* @param input (NUMERIC type)
|
||||
* @param fullUV
|
||||
* @param computeUV
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable svd(SDVariable input, boolean fullUV, boolean computeUV) {
|
||||
SDValidation.validateNumerical("svd", "input", input);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates singular value decomposition.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input (NUMERIC type)
|
||||
* @param fullUV
|
||||
* @param computeUV
|
||||
* @return output (FLOATING_POINT type)
|
||||
*/
|
||||
public SDVariable svd(String name, SDVariable input, boolean fullUV, boolean computeUV) {
|
||||
SDValidation.validateNumerical("svd", "input", input);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -27,17 +27,21 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
*/
|
||||
public abstract class SDOps {
|
||||
|
||||
protected final SameDiff sd;
|
||||
protected final SameDiff sd;
|
||||
|
||||
public SDOps(SameDiff sameDiff) {
|
||||
this.sd = sameDiff;
|
||||
}
|
||||
public SDOps() {
|
||||
sd = null;
|
||||
}
|
||||
|
||||
protected DifferentialFunctionFactory f() {
|
||||
return sd.f();
|
||||
}
|
||||
public SDOps(SameDiff sameDiff) {
|
||||
this.sd = sameDiff;
|
||||
}
|
||||
|
||||
protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) {
|
||||
return sd.updateVariableNameAndReference(varToUpdate, newVarName);
|
||||
}
|
||||
protected DifferentialFunctionFactory f() {
|
||||
return sd.f();
|
||||
}
|
||||
|
||||
protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) {
|
||||
return sd.updateVariableNameAndReference(varToUpdate, newVarName);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -14,198 +14,232 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import java.lang.String;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.*;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.*;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMLayerOutputs;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRUCellOutputs;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.SRULayerOutputs;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
/**
|
||||
* SameDiff Recurrent Neural Network operations<br>
|
||||
* Accessible via {@link SameDiff#rnn()}<br>
|
||||
* See also {@link SDNN} (accessible via {@link SameDiff#nn()} for general neural network ops.<br>
|
||||
* See also {@link SDCNN} (accessible via {@link SameDiff#cnn()} for convolutional neural network ops.<br>
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class SDRNN extends SDOps {
|
||||
public SDRNN(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
public SDRNN(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
|
||||
/**
|
||||
* The GRU cell. Does a single time step operation<br>
|
||||
*
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param GRUWeights Configuration Object
|
||||
* @return output The cell's outputs. (NUMERIC type)
|
||||
*/
|
||||
public SDVariable gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
|
||||
SDValidation.validateNumerical("gru", "x", x);
|
||||
SDValidation.validateNumerical("gru", "hLast", hLast);
|
||||
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #gru(String, SDVariable, SDVariable, GRUWeights)}.
|
||||
*/
|
||||
public GRUCellOutputs gru(@NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
|
||||
GRUCell c = new GRUCell(sd, x, hLast, weights);
|
||||
return new GRUCellOutputs(c.outputVariables());
|
||||
}
|
||||
/**
|
||||
* The GRU cell. Does a single time step operation<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param GRUWeights Configuration Object
|
||||
* @return output The cell's outputs. (NUMERIC type)
|
||||
*/
|
||||
public GRUCellOutputs gru(String name, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
|
||||
SDValidation.validateNumerical("gru", "x", x);
|
||||
SDValidation.validateNumerical("gru", "hLast", hLast);
|
||||
GRUCell c = new GRUCell(sd,x, hLast, GRUWeights);
|
||||
return new GRUCellOutputs(c.outputVariables(name));
|
||||
}
|
||||
|
||||
/**
|
||||
* The GRU cell. Does a single time step operation.
|
||||
*
|
||||
* @param baseName The base name for the gru cell
|
||||
* @param x Input, with shape [batchSize, inSize]
|
||||
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits]
|
||||
* @param weights The cell's weights.
|
||||
* @return The cell's outputs.
|
||||
*/
|
||||
public GRUCellOutputs gru(String baseName, @NonNull SDVariable x, @NonNull SDVariable hLast, @NonNull GRUWeights weights) {
|
||||
GRUCell c = new GRUCell(sd, x, hLast, weights);
|
||||
return new GRUCellOutputs(c.outputVariables(baseName));
|
||||
}
|
||||
/**
|
||||
* The LSTM cell. Does a single time step operation.<br>
|
||||
*
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param LSTMWeights Configuration Object
|
||||
* @param LSTMConfiguration Configuration Object
|
||||
* @return output The cell's outputs (NUMERIC type)
|
||||
*/
|
||||
public LSTMCellOutputs lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast,
|
||||
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
|
||||
SDValidation.validateNumerical("lstmCell", "x", x);
|
||||
SDValidation.validateNumerical("lstmCell", "cLast", cLast);
|
||||
SDValidation.validateNumerical("lstmCell", "yLast", yLast);
|
||||
LSTMBlockCell c = new LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration);
|
||||
return new LSTMCellOutputs(c.outputVariables());
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #lstmCell(String, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}.
|
||||
*/
|
||||
public LSTMCellOutputs lstmCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||
LSTMWeights weights, LSTMConfiguration config){
|
||||
LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
|
||||
return new LSTMCellOutputs(c.outputVariables());
|
||||
}
|
||||
/**
|
||||
* The LSTM cell. Does a single time step operation.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param yLast revious cell output, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param LSTMWeights Configuration Object
|
||||
* @param LSTMConfiguration Configuration Object
|
||||
* @return output The cell's outputs (NUMERIC type)
|
||||
*/
|
||||
public LSTMCellOutputs lstmCell(String name, SDVariable x, SDVariable cLast, SDVariable yLast,
|
||||
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
|
||||
SDValidation.validateNumerical("lstmCell", "x", x);
|
||||
SDValidation.validateNumerical("lstmCell", "cLast", cLast);
|
||||
SDValidation.validateNumerical("lstmCell", "yLast", yLast);
|
||||
LSTMBlockCell c = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration);
|
||||
return new LSTMCellOutputs(c.outputVariables(name));
|
||||
}
|
||||
|
||||
/**
|
||||
* The LSTM cell. Does a single time step operation.
|
||||
*
|
||||
* @param baseName The base name for the lstm cell
|
||||
* @param x Input, with shape [batchSize, inSize]
|
||||
* @param cLast Previous cell state, with shape [batchSize, numUnits]
|
||||
* @param yLast Previous cell output, with shape [batchSize, numUnits]
|
||||
* @param weights The cell's weights.
|
||||
* @param config The cell's config.
|
||||
* @return The cell's outputs.
|
||||
*/
|
||||
public LSTMCellOutputs lstmCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||
LSTMBlockCell c = new LSTMBlockCell(sd, x, cLast, yLast, weights, config);
|
||||
return new LSTMCellOutputs(c.outputVariables(baseName));
|
||||
}
|
||||
/**
|
||||
* The LSTM layer. Does multiple time steps.<br>
|
||||
*
|
||||
* @param maxTSLength (NUMERIC type)
|
||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param LSTMWeights Configuration Object
|
||||
* @param LSTMConfiguration Configuration Object
|
||||
* @return output The layer's outputs. (NUMERIC type)
|
||||
*/
|
||||
public SDVariable lstmLayer(SDVariable maxTSLength, SDVariable x, SDVariable cLast,
|
||||
SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
|
||||
SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
|
||||
SDValidation.validateNumerical("lstmLayer", "x", x);
|
||||
SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
|
||||
SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
|
||||
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
|
||||
*/
|
||||
public LSTMLayerOutputs lstmLayer(@NonNull SDVariable maxTSLength,
|
||||
@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||
LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
|
||||
return new LSTMLayerOutputs(c.outputVariables(), config.getDataFormat());
|
||||
}
|
||||
/**
|
||||
* The LSTM layer. Does multiple time steps.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param maxTSLength (NUMERIC type)
|
||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
|
||||
* @param LSTMWeights Configuration Object
|
||||
* @param LSTMConfiguration Configuration Object
|
||||
* @return output The layer's outputs. (NUMERIC type)
|
||||
*/
|
||||
public SDVariable lstmLayer(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast,
|
||||
SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
|
||||
SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
|
||||
SDValidation.validateNumerical("lstmLayer", "x", x);
|
||||
SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
|
||||
SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
|
||||
*/
|
||||
public LSTMLayerOutputs lstmLayer(int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||
return lstmLayer(
|
||||
sd.scalar("lstm_max_ts_length", maxTSLength),
|
||||
x, cLast, yLast, weights, config);
|
||||
}
|
||||
/**
|
||||
* The SRU layer. Does a single time step operation.<br>
|
||||
*
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param SRUWeights Configuration Object
|
||||
* @return output The cell's outputs.. (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sru(SDVariable x, SDVariable initialC, SDVariable mask, SRUWeights SRUWeights) {
|
||||
SDValidation.validateNumerical("sru", "x", x);
|
||||
SDValidation.validateNumerical("sru", "initialC", initialC);
|
||||
SDValidation.validateNumerical("sru", "mask", mask);
|
||||
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, mask, SRUWeights).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #lstmLayer(String, SDVariable, SDVariable, SDVariable, SDVariable, LSTMWeights, LSTMConfiguration)}
|
||||
*/
|
||||
public LSTMLayerOutputs lstmLayer(String baseName, int maxTSLength, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||
if(baseName != null) {
|
||||
return lstmLayer(baseName,
|
||||
sd.scalar(sd.generateDistinctCustomVariableName(baseName + "_max_ts_length"), maxTSLength),
|
||||
x, cLast, yLast, weights, config);
|
||||
} else {
|
||||
return lstmLayer(maxTSLength, x, cLast, yLast, weights, config);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* The SRU layer. Does a single time step operation.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param SRUWeights Configuration Object
|
||||
* @return output The cell's outputs.. (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sru(String name, SDVariable x, SDVariable initialC, SDVariable mask,
|
||||
SRUWeights SRUWeights) {
|
||||
SDValidation.validateNumerical("sru", "x", x);
|
||||
SDValidation.validateNumerical("sru", "initialC", initialC);
|
||||
SDValidation.validateNumerical("sru", "mask", mask);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, mask, SRUWeights).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* The LSTM layer. Does multiple time steps.
|
||||
*
|
||||
* Input shape depends on data format (in config):<br>
|
||||
* TNS -> [timeSteps, batchSize, inSize]<br>
|
||||
* NST -> [batchSize, inSize, timeSteps]<br>
|
||||
* NTS -> [batchSize, timeSteps, inSize]<br>
|
||||
*
|
||||
* @param baseName The base name for the lstm layer
|
||||
* @param x Input, with shape dependent on the data format (in config).
|
||||
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits]
|
||||
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits]
|
||||
* @param weights The layer's weights.
|
||||
* @param config The layer's config.
|
||||
* @return The layer's outputs.
|
||||
*/
|
||||
public LSTMLayerOutputs lstmLayer(String baseName, @NonNull SDVariable maxTSLength,
|
||||
@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SDVariable yLast,
|
||||
@NonNull LSTMWeights weights, @NonNull LSTMConfiguration config){
|
||||
LSTMLayer c = new LSTMLayer(sd, maxTSLength, x, cLast, yLast, weights, config);
|
||||
return new LSTMLayerOutputs(c.outputVariables(baseName), config.getDataFormat());
|
||||
}
|
||||
/**
|
||||
* The SRU layer. Does a single time step operation.<br>
|
||||
*
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param SRUWeights Configuration Object
|
||||
* @return output The cell's outputs.. (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sru(SDVariable x, SDVariable initialC, SRUWeights SRUWeights) {
|
||||
SDValidation.validateNumerical("sru", "x", x);
|
||||
SDValidation.validateNumerical("sru", "initialC", initialC);
|
||||
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, null, SRUWeights).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #sruCell(String, SDVariable, SDVariable, SRUWeights)}.
|
||||
*/
|
||||
public SRUCellOutputs sruCell(@NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
|
||||
return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables());
|
||||
}
|
||||
/**
|
||||
* The SRU layer. Does a single time step operation.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param SRUWeights Configuration Object
|
||||
* @return output The cell's outputs.. (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sru(String name, SDVariable x, SDVariable initialC, SRUWeights SRUWeights) {
|
||||
SDValidation.validateNumerical("sru", "x", x);
|
||||
SDValidation.validateNumerical("sru", "initialC", initialC);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, null, SRUWeights).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* The SRU cell. Does a single time step operation.
|
||||
*
|
||||
* @param baseName The base name for the sru cell
|
||||
* @param x Input, with shape [batchSize, inSize]
|
||||
* @param cLast Previous cell state, with shape [batchSize, inSize]
|
||||
* @param weights The cell's weights.
|
||||
* @return The cell's outputs.
|
||||
*/
|
||||
public SRUCellOutputs sruCell(String baseName, @NonNull SDVariable x, @NonNull SDVariable cLast, @NonNull SRUWeights weights) {
|
||||
return new SRUCellOutputs(new SRUCell(sd, x, cLast, weights).outputVariables(baseName));
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
|
||||
*/
|
||||
public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
|
||||
return sru(x, initialC, null, weights);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
|
||||
*/
|
||||
public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, @NonNull SRUWeights weights) {
|
||||
return sru(baseName, x, initialC, null, weights);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #sru(String, SDVariable, SDVariable, SDVariable, SRUWeights)}
|
||||
*/
|
||||
public SRULayerOutputs sru(@NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
|
||||
return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables());
|
||||
}
|
||||
|
||||
/**
|
||||
* The SRU layer. Does a single time step operation.
|
||||
*
|
||||
* @param baseName The base name for the sru layer
|
||||
* @param x Input, with shape [batchSize, inSize, timeSeriesLength]
|
||||
* @param initialC Initial cell state, with shape [batchSize, inSize]
|
||||
* @param mask An optional dropout mask, with shape [batchSize, inSize]
|
||||
* @param weights The layer's weights.
|
||||
* @return The layer's outputs.
|
||||
*/
|
||||
public SRULayerOutputs sru(String baseName, @NonNull SDVariable x, @NonNull SDVariable initialC, SDVariable mask, @NonNull SRUWeights weights) {
|
||||
return new SRULayerOutputs(new SRU(sd, x, initialC, mask, weights).outputVariables(baseName));
|
||||
}
|
||||
/**
|
||||
* The SRU layer. Does a single time step operation.<br>
|
||||
*
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param SRUWeights Configuration Object
|
||||
* @return output The cell's outputs. (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sruCell(SDVariable x, SDVariable cLast, SRUWeights SRUWeights) {
|
||||
SDValidation.validateNumerical("sruCell", "x", x);
|
||||
SDValidation.validateNumerical("sruCell", "cLast", cLast);
|
||||
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(sd,x, cLast, SRUWeights).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* The SRU layer. Does a single time step operation.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type)
|
||||
* @param SRUWeights Configuration Object
|
||||
* @return output The cell's outputs. (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sruCell(String name, SDVariable x, SDVariable cLast, SRUWeights SRUWeights) {
|
||||
SDValidation.validateNumerical("sruCell", "x", x);
|
||||
SDValidation.validateNumerical("sruCell", "cLast", cLast);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(sd,x, cLast, SRUWeights).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -14,324 +14,253 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
|
||||
|
||||
/**
|
||||
* SameDiff random number generator operations<br>
|
||||
* Accessible via {@link SameDiff#random()}
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class SDRandom extends SDOps {
|
||||
public SDRandom(SameDiff sameDiff) {
|
||||
super(sameDiff);
|
||||
}
|
||||
|
||||
public SDRandom(SameDiff sd) {
|
||||
super(sd);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,<br>
|
||||
* with the specified probability. Array values will have value 1 with probability P and value 0 with probability<br>
|
||||
* 1-P.<br>
|
||||
*
|
||||
* @param p Probability of value 1
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable bernoulli(double p, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
return new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(sd,p, datatype, shape).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #bernoulli(String, double, SDVariable)
|
||||
*/
|
||||
public SDVariable bernoulli(double p, SDVariable shape) {
|
||||
return bernoulli(null, p, shape);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,<br>
|
||||
* with the specified probability. Array values will have value 1 with probability P and value 0 with probability<br>
|
||||
* 1-P.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param p Probability of value 1
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable bernoulli(String name, double p, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(sd,p, datatype, shape).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a Bernoulli distribution,
|
||||
* with the specified probability. Array values will have value 1 with probability P and value 0 with probability
|
||||
* 1-P.<br>
|
||||
* See {@link #bernoulli(String, double, long...)} for the equivalent function where the shape is
|
||||
* specified as a long[] instead
|
||||
*
|
||||
* @param name Name of the new SDVariable
|
||||
* @param p Probability of value 1
|
||||
* @param shape Shape of the new random SDVariable, as a 1D array
|
||||
* @return New SDVariable
|
||||
*/
|
||||
public SDVariable bernoulli(String name, double p, SDVariable shape) {
|
||||
validateInteger("bernoulli random", shape);
|
||||
SDVariable ret = f().randomBernoulli(p, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,<br>
|
||||
* with the specified number of trials and probability.<br>
|
||||
*
|
||||
* @param nTrials Number of trials parameter for the binomial distribution
|
||||
* @param p Probability of success for each trial
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable binomial(int nTrials, double p, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
return new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(sd,nTrials, p, datatype, shape).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #bernoulli(String, double, long...)
|
||||
*/
|
||||
public SDVariable bernoulli(double p, long... shape) {
|
||||
return bernoulli(null, p, shape);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,<br>
|
||||
* with the specified number of trials and probability.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param nTrials Number of trials parameter for the binomial distribution
|
||||
* @param p Probability of success for each trial
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable binomial(String name, int nTrials, double p, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(sd,nTrials, p, datatype, shape).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a Bernoulli distribution,
|
||||
* with the specified probability. Array values will have value 1 with probability P and value 0 with probability
|
||||
* 1-P.<br>
|
||||
* See {@link #bernoulli(String, double, SDVariable)} for the equivalent function where the shape is
|
||||
* specified as a SDVarible instead
|
||||
*
|
||||
* @param name Name of the new SDVariable
|
||||
* @param p Probability of value 1
|
||||
* @param shape Shape of the new random SDVariable, as a 1D array
|
||||
* @return New SDVariable
|
||||
*/
|
||||
public SDVariable bernoulli(String name, double p, long... shape) {
|
||||
SDVariable ret = f().randomBernoulli(p, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:<br>
|
||||
* P(x) = lambda * exp(-lambda * x)<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be positive: lambda > 0<br>
|
||||
*
|
||||
* @param lambda lambda parameter
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable exponential(double lambda, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
Preconditions.checkArgument(lambda > 0, "Must be positive");
|
||||
return new org.nd4j.linalg.api.ops.random.custom.RandomExponential(sd,lambda, datatype, shape).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution,
|
||||
* with the specified number of trials and probability.
|
||||
*
|
||||
* @param nTrials Number of trials parameter for the binomial distribution
|
||||
* @param p Probability of success for each trial
|
||||
* @param shape Shape of the new random SDVariable, as a 1D array
|
||||
* @return New SDVariable
|
||||
*/
|
||||
public SDVariable binomial(int nTrials, double p, long... shape) {
|
||||
return binomial(null, nTrials, p, shape);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:<br>
|
||||
* P(x) = lambda * exp(-lambda * x)<br>
|
||||
*
|
||||
* Inputs must satisfy the following constraints: <br>
|
||||
* Must be positive: lambda > 0<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param lambda lambda parameter
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable exponential(String name, double lambda, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
Preconditions.checkArgument(lambda > 0, "Must be positive");
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.random.custom.RandomExponential(sd,lambda, datatype, shape).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution,
|
||||
* with the specified number of trials and probability.
|
||||
*
|
||||
* @param name Name of the new SDVariable
|
||||
* @param nTrials Number of trials parameter for the binomial distribution
|
||||
* @param p Probability of success for each trial
|
||||
* @param shape Shape of the new random SDVariable, as a 1D array
|
||||
* @return New SDVariable
|
||||
*/
|
||||
public SDVariable binomial(String name, int nTrials, double p, long... shape) {
|
||||
SDVariable ret = f().randomBinomial(nTrials, p, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,<br>
|
||||
* i.e., {@code log(x) ~ N(mean, stdev)}<br>
|
||||
*
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable logNormal(double mean, double stddev, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
return new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution:
|
||||
* P(x) = lambda * exp(-lambda * x)
|
||||
*
|
||||
* @param lambda Must be > 0
|
||||
* @param shape Shape of the output
|
||||
* @return new SDVariable
|
||||
*/
|
||||
public SDVariable exponential(double lambda, SDVariable shape) {
|
||||
return exponential(null, lambda, shape);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,<br>
|
||||
* i.e., {@code log(x) ~ N(mean, stdev)}<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable logNormal(String name, double mean, double stddev, DataType datatype,
|
||||
long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution:
|
||||
* P(x) = lambda * exp(-lambda * x)
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param lambda Must be > 0
|
||||
* @param shape Shape of the new variable
|
||||
* @return new SDVaribale
|
||||
*/
|
||||
public SDVariable exponential(String name, double lambda, SDVariable shape) {
|
||||
validateInteger("exponential random", shape);
|
||||
SDVariable ret = f().randomExponential(lambda, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
|
||||
* N(mean, stdev)<br>
|
||||
*
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable normal(double mean, double stddev, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
return new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(sd,mean, stddev, datatype, shape).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #logNormal(String, double, double, long...)
|
||||
*/
|
||||
public SDVariable logNormal(double mean, double stddev, long... shape) {
|
||||
return logNormal(null, mean, stddev, shape);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
|
||||
* N(mean, stdev)<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable normal(String name, double mean, double stddev, DataType datatype,
|
||||
long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(sd,mean, stddev, datatype, shape).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a Log Normal distribution,
|
||||
* i.e., {@code log(x) ~ N(mean, stdev)}<br>
|
||||
*
|
||||
* @param name Name of the new SDVariable
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param shape Shape of the new random SDVariable
|
||||
* @return New SDVariable
|
||||
*/
|
||||
public SDVariable logNormal(String name, double mean, double stddev, long... shape) {
|
||||
SDVariable ret = f().randomLogNormal(mean, stddev, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
|
||||
* N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled<br>
|
||||
*
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable normalTruncated(double mean, double stddev, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
return new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #normal(String, double, double, SDVariable)
|
||||
*/
|
||||
public SDVariable normal(double mean, double stddev, SDVariable shape) {
|
||||
return normal(null, mean, stddev, shape);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
|
||||
* N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable normalTruncated(String name, double mean, double stddev, DataType datatype,
|
||||
long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution,
|
||||
* N(mean, stdev)<br>
|
||||
* See {@link #normal(String, double, double, long...)} for the equivalent function where the shape is
|
||||
* specified as a long[] instead
|
||||
*
|
||||
* @param name Name of the new SDVariable
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param shape Shape of the new random SDVariable, as a 1D array
|
||||
* @return New SDVariable
|
||||
*/
|
||||
public SDVariable normal(String name, double mean, double stddev, SDVariable shape) {
|
||||
validateInteger("normal (Gaussian) random", shape);
|
||||
SDVariable ret = f().randomNormal(mean, stddev, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #normal(String, double, double, long...)
|
||||
*/
|
||||
public SDVariable normal(double mean, double stddev, long... shape) {
|
||||
return normal(null, mean, stddev, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution,
|
||||
* N(mean, stdev)<br>
|
||||
* See {@link #normal(String, double, double, SDVariable)} for the equivalent function where the shape is
|
||||
* specified as a long[] instead
|
||||
*
|
||||
* @param name Name of the new SDVariable
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param shape Shape of the new random SDVariable
|
||||
* @return New SDVariable
|
||||
*/
|
||||
public SDVariable normal(String name, double mean, double stddev, long... shape) {
|
||||
SDVariable ret = f().randomNormal(mean, stddev, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #normalTruncated(String, double, double, long...)
|
||||
*/
|
||||
public SDVariable normalTruncated(double mean, double stddev, long... shape) {
|
||||
return normalTruncated(null, mean, stddev, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution,
|
||||
* N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled<br>
|
||||
*
|
||||
* @param name Name of the new SDVariable
|
||||
* @param mean Mean value for the random array
|
||||
* @param stddev Standard deviation for the random array
|
||||
* @param shape Shape of the new random SDVariable
|
||||
* @return New SDVariable
|
||||
*/
|
||||
public SDVariable normalTruncated(String name, double mean, double stddev, long... shape) {
|
||||
SDVariable ret = f().randomNormalTruncated(mean, stddev, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #uniform(String, double, double, SDVariable)
|
||||
*/
|
||||
public SDVariable uniform(double min, double max, SDVariable shape) {
|
||||
return uniform(null, min, max, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #uniform(String, double, double, SDVariable)
|
||||
*/
|
||||
public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) {
|
||||
return uniform(null, min, max, shape, dataType);
|
||||
}
|
||||
|
||||
/**
|
||||
* As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output
|
||||
*/
|
||||
public SDVariable uniform(String name, double min, double max, SDVariable shape) {
|
||||
return uniform(name, min, max, shape, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
|
||||
* U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned<br>
|
||||
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
|
||||
* specified as a long[] instead
|
||||
*
|
||||
* @param name Name of the new SDVariable
|
||||
* @param min Minimum value
|
||||
* @param max Maximum value. Must satisfy max >= min
|
||||
* @param shape Shape of the new random SDVariable, as a 1D array
|
||||
* @param dataType Data type of the output array (if null: Float32 output is returned)
|
||||
* @return New SDVariable, of the specified data type
|
||||
*/
|
||||
public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) {
|
||||
validateInteger("uniform random", shape);
|
||||
SDVariable ret = f().randomUniform(min, max, shape, dataType);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* @see #uniform(String, double, double, long...)
|
||||
*/
|
||||
public SDVariable uniform(double min, double max, long... shape) {
|
||||
return uniform(null, min, max, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
|
||||
* U(min,max)<br>
|
||||
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
|
||||
* specified as a SDVariable instead
|
||||
*
|
||||
* @param name Name of the new SDVariable
|
||||
* @param min Minimum value
|
||||
* @param max Maximum value. Must satisfy max >= min
|
||||
* @param shape Shape of the new random SDVariable
|
||||
* @return New SDVariable
|
||||
*/
|
||||
public SDVariable uniform(String name, double min, double max, long... shape) {
|
||||
SDVariable ret = f().randomUniform(min, max, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable with Gamma distribution
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param alpha distribution parameter
|
||||
* @param beta distribution parameter
|
||||
* @param shape Shape of the new variable
|
||||
* @return new SDVariable
|
||||
*/
|
||||
public SDVariable gamma(String name, SDVariable shape, SDVariable alpha, SDVariable beta) {
|
||||
SDVariable ret = f().randomGamma(alpha, beta, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable with Poission distribution
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param lambda rate distribution parameter
|
||||
* @param shape Shape of the new variable
|
||||
* @return new SDVariable
|
||||
*/
|
||||
public SDVariable poisson(String name, SDVariable lambda, SDVariable shape, int... seeds) {
|
||||
SDVariable ret = f().randomPoisson(shape, lambda, seeds);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable by random shuffle
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param value array to shuffle
|
||||
* @return new SDVariable
|
||||
*/
|
||||
public SDVariable shuffle(String name, SDVariable value, int... seeds) {
|
||||
SDVariable ret = f().randomShuffle(value, seeds);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,<br>
|
||||
* U(min,max)<br>
|
||||
*
|
||||
* @param min Minimum value
|
||||
* @param max Maximum value.
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable uniform(double min, double max, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
return new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(sd,min, max, datatype, shape).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,<br>
|
||||
* U(min,max)<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param min Minimum value
|
||||
* @param max Maximum value.
|
||||
* @param datatype Data type of the output variable
|
||||
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
|
||||
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
|
||||
*/
|
||||
public SDVariable uniform(String name, double min, double max, DataType datatype, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(sd,min, max, datatype, shape).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,6 +55,15 @@ public class SDValidation {
|
|||
v.name() + "\" with non-integer data type " + v.dataType());
|
||||
}
|
||||
|
||||
protected static void validateNumerical(String opName, String inputName, SDVariable[] vars) {
|
||||
for (SDVariable v : vars) {
|
||||
if (v == null) continue;
|
||||
if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
|
||||
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an numerical type type; got variable \"" +
|
||||
v.name() + "\" with non-integer data type " + v.dataType());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate that the operation is being applied on numerical SDVariables (not boolean or utf8).
|
||||
* Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays
|
||||
|
@ -97,6 +106,16 @@ public class SDValidation {
|
|||
v.name() + "\" with non-integer data type " + v.dataType());
|
||||
}
|
||||
|
||||
protected static void validateInteger(String opName, String inputName, SDVariable[] vars) {
|
||||
for (SDVariable v : vars) {
|
||||
if (v == null)
|
||||
return;
|
||||
if (!v.dataType().isIntType())
|
||||
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" +
|
||||
v.name() + "\" with non-integer data type " + v.dataType());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate that the operation is being applied on an floating point type SDVariable
|
||||
*
|
||||
|
@ -200,4 +219,18 @@ public class SDValidation {
|
|||
}
|
||||
}
|
||||
|
||||
public static boolean isSameType(SDVariable x, SDVariable y) {
|
||||
return x.dataType() == y.dataType();
|
||||
}
|
||||
|
||||
public static boolean isSameType(SDVariable[] x) {
|
||||
DataType firstDataType = x[0].dataType();
|
||||
if (x.length > 1) {
|
||||
for (int i = 1; i < x.length; ++i) {
|
||||
if (firstDataType != x[i].dataType())
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.enums;
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* Data format: "NCHW" or "NHWC" */
|
|
@ -633,7 +633,9 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.custom.Lu.class,
|
||||
org.nd4j.linalg.api.ops.custom.TriangularSolve.class,
|
||||
org.nd4j.linalg.api.ops.custom.LinearSolve.class,
|
||||
org.nd4j.linalg.api.ops.custom.Lstsq.class
|
||||
org.nd4j.linalg.api.ops.custom.Lstsq.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.Qr.class,
|
||||
org.nd4j.linalg.api.ops.custom.Logdet.class
|
||||
);
|
||||
|
||||
static {
|
||||
|
|
|
@ -85,6 +85,12 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
|
|||
this(x, null, dimensions);
|
||||
}
|
||||
|
||||
public BaseIndexAccumulation(INDArray x, boolean keepDims, int[] dimensions) {
|
||||
this(x, null, dimensions);
|
||||
this.keepDims = keepDims;
|
||||
defineDimensions(dimensions);
|
||||
}
|
||||
|
||||
public BaseIndexAccumulation(INDArray x, INDArray z, int[] dimensions) {
|
||||
super(x, z);
|
||||
defineDimensions(dimensions);
|
||||
|
|
|
@ -29,12 +29,17 @@ public class AdjustContrast extends BaseAdjustContrast {
|
|||
super(in, factor, out);
|
||||
}
|
||||
|
||||
public AdjustContrast(@NonNull INDArray in, double factor) {
|
||||
this(in, factor, null);
|
||||
}
|
||||
|
||||
public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
|
||||
super(sameDiff,new SDVariable[]{in,factor});
|
||||
}
|
||||
|
||||
public AdjustContrast(@NonNull INDArray in, double factor) {
|
||||
this(in, factor, null);
|
||||
public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) {
|
||||
super(sameDiff,new SDVariable[]{in});
|
||||
addTArgument(factor);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -50,6 +50,11 @@ public class AdjustHue extends DynamicCustomOp {
|
|||
super(sameDiff,new SDVariable[]{in,factor});
|
||||
}
|
||||
|
||||
public AdjustHue(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) {
|
||||
super(sameDiff,new SDVariable[]{in});
|
||||
addTArgument(factor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "adjust_hue";
|
||||
|
|
|
@ -49,6 +49,11 @@ public class AdjustSaturation extends DynamicCustomOp {
|
|||
super(sameDiff, new SDVariable[]{in, factor});
|
||||
}
|
||||
|
||||
public AdjustSaturation(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) {
|
||||
super(sameDiff, new SDVariable[]{in});
|
||||
addTArgument(factor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "adjust_saturation";
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class Logdet extends DynamicCustomOp {
|
||||
|
||||
public Logdet(INDArray input) {
|
||||
addInputArgument(input);
|
||||
}
|
||||
|
||||
public Logdet(SameDiff sameDiff, SDVariable input) {
|
||||
super(sameDiff, new SDVariable[]{input});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "logdet";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -17,9 +17,17 @@ package org.nd4j.linalg.api.ops.custom;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class Lstsq extends DynamicCustomOp {
|
||||
|
||||
|
@ -33,8 +41,21 @@ public class Lstsq extends DynamicCustomOp {
|
|||
this(matrix, rhs, 0.0, true);
|
||||
}
|
||||
|
||||
public Lstsq(@NonNull SameDiff sameDiff, @NonNull SDVariable matrix, @NonNull SDVariable rhs, double l2_regularizer, boolean fast) {
|
||||
super(sameDiff, new SDVariable[]{matrix,rhs});
|
||||
addTArgument(l2_regularizer);
|
||||
addBArgument(fast);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "lstsq";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -26,10 +27,9 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class MatrixBandPart extends DynamicCustomOp {
|
||||
|
||||
public MatrixBandPart() {}
|
||||
|
||||
public MatrixBandPart(@NonNull INDArray input, int minLower, int maxUpper) {
|
||||
Preconditions.checkArgument(input.rank() >= 2, "MatrixBandPart: Input rank should be 2 or higher");
|
||||
long N = input.size(-2);
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -37,7 +36,6 @@ import java.util.*;
|
|||
*/
|
||||
@NoArgsConstructor
|
||||
public class CropAndResize extends DynamicCustomOp {
|
||||
|
||||
public enum Method {BILINEAR, NEAREST};
|
||||
protected Method method = Method.BILINEAR;
|
||||
protected double extrapolationValue = 0.0;
|
||||
|
@ -50,6 +48,10 @@ public class CropAndResize extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public CropAndResize(@NonNull SameDiff sameDiff, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices,
|
||||
SDVariable cropOutSize, double extrapolationValue) {
|
||||
this(sameDiff, image, cropBoxes, boxIndices, cropOutSize, null, extrapolationValue);
|
||||
}
|
||||
|
||||
public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
|
||||
@NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue,
|
||||
|
@ -65,12 +67,10 @@ public class CropAndResize extends DynamicCustomOp {
|
|||
outputArguments.add(output);
|
||||
}
|
||||
|
||||
public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
|
||||
@NonNull INDArray cropOutSize, double extrapolationValue) {
|
||||
this(image, cropBoxes, boxIndices, cropOutSize, Method.BILINEAR, extrapolationValue, null);
|
||||
public CropAndResize(INDArray image, INDArray cropBoxes, INDArray boxIndices, INDArray cropOutSize, double extrapolationValue ) {
|
||||
this(image, cropBoxes, boxIndices, cropOutSize, null, extrapolationValue, null);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "crop_and_resize";
|
||||
|
|
|
@ -46,6 +46,12 @@ public class ExtractImagePatches extends DynamicCustomOp {
|
|||
|
||||
public ExtractImagePatches(){ }
|
||||
|
||||
public ExtractImagePatches(@NonNull SameDiff samediff, @NonNull SDVariable input,
|
||||
int kH, int kW, int sH, int sW, int rH, int rW,
|
||||
boolean sameMode) {
|
||||
this(samediff, input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH,rW}, sameMode);
|
||||
|
||||
}
|
||||
public ExtractImagePatches(@NonNull SameDiff samediff, @NonNull SDVariable input, @NonNull int[] kSizes,
|
||||
@NonNull int[] strides, @NonNull int[] rates, boolean sameMode){
|
||||
super(samediff, input);
|
||||
|
@ -72,16 +78,8 @@ public class ExtractImagePatches extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public ExtractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
|
||||
super(new INDArray[]{input},null);
|
||||
int[] kSises = {kH,kW};
|
||||
int[] strides = {sH,sW};
|
||||
int[] rates = {rH, rW};
|
||||
this.kSizes = kSises;
|
||||
this.strides = strides;
|
||||
this.rates = rates;
|
||||
this.isSameMode = sameMode;
|
||||
addArgs();
|
||||
public ExtractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
|
||||
this(input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH, rW}, sameMode);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -42,6 +42,13 @@ public class NonMaxSuppression extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false);
|
||||
}
|
||||
|
||||
public NonMaxSuppression(SameDiff sameDiff, SDVariable boxes, SDVariable scores, int maxOutSize,
|
||||
double iouThreshold, double scoreThreshold) {
|
||||
super(null, sameDiff, new SDVariable[]{boxes, scores}, false);
|
||||
addIArgument(maxOutSize);
|
||||
addTArgument(iouThreshold, scoreThreshold);
|
||||
}
|
||||
|
||||
public NonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize, double iouThreshold, double scoreThreshold) {
|
||||
addInputArgument(boxes,scores);
|
||||
addIArgument(maxOutSize);
|
||||
|
|
|
@ -54,10 +54,18 @@ public class FirstIndex extends BaseIndexAccumulation {
|
|||
this.extraArgs = new Object[] {compare, eps, (double) mode};
|
||||
}
|
||||
|
||||
public FirstIndex(SameDiff sameDiff, SDVariable i_v, boolean keepDims, Condition condition, int... dimensions) {
|
||||
this(sameDiff, i_v, condition, keepDims, dimensions);
|
||||
}
|
||||
|
||||
public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) {
|
||||
this(x, condition, false, dimension);
|
||||
}
|
||||
|
||||
public FirstIndex(INDArray x, boolean keepDims, @NonNull Condition condition, int... dimension) {
|
||||
this(x,condition,keepDims,dimension);
|
||||
}
|
||||
|
||||
public FirstIndex(INDArray x, @NonNull Condition condition, boolean keepDims, int... dimension) {
|
||||
this(x, condition, Nd4j.EPS_THRESHOLD, dimension);
|
||||
this.keepDims = keepDims;
|
||||
|
@ -72,7 +80,6 @@ public class FirstIndex extends BaseIndexAccumulation {
|
|||
this.extraArgs = new Object[] {compare, eps, (double) mode};
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 4;
|
||||
|
|
|
@ -45,6 +45,11 @@ public class IMax extends BaseIndexAccumulation {
|
|||
super(x, z, dimensions);
|
||||
}
|
||||
|
||||
public IMax(INDArray x, boolean keepDims, int... dimensions) {
|
||||
super(x, keepDims, dimensions);
|
||||
|
||||
}
|
||||
|
||||
public IMax(INDArray x, int... dimensions) {
|
||||
super(x, null, dimensions);
|
||||
}
|
||||
|
|
|
@ -44,6 +44,10 @@ public class IMin extends BaseIndexAccumulation {
|
|||
super(x, dimensions);
|
||||
}
|
||||
|
||||
public IMin(INDArray x, boolean keepDims, int... dimensions) {
|
||||
super(x, keepDims, dimensions);
|
||||
}
|
||||
|
||||
public IMin(INDArray x, INDArray z, int... dimensions) {
|
||||
super(x, z, dimensions);
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -38,12 +39,16 @@ import java.util.Map;
|
|||
* @author raver119@gmail.com
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class LastIndex extends BaseIndexAccumulation {
|
||||
protected Condition condition;
|
||||
protected double compare;
|
||||
protected double eps;
|
||||
protected int mode;
|
||||
|
||||
public LastIndex(SameDiff sameDiff, SDVariable i_v, boolean keepDims, Condition condition, int... dimensions) {
|
||||
this(sameDiff, i_v, condition, keepDims, dimensions);
|
||||
}
|
||||
public LastIndex(SameDiff sameDiff, SDVariable i_v, Condition condition, boolean keepDims, int... dimensions) {
|
||||
super(sameDiff, i_v, keepDims, dimensions);
|
||||
this.condition = condition;
|
||||
|
@ -53,13 +58,19 @@ public class LastIndex extends BaseIndexAccumulation {
|
|||
this.extraArgs = new Object[] {compare, eps, (double) mode};
|
||||
}
|
||||
|
||||
public LastIndex() {}
|
||||
|
||||
public LastIndex(SameDiff sameDiff, SDVariable x, @NonNull Condition condition, int... dimensions) {
|
||||
super(sameDiff, x, false, dimensions);
|
||||
this.condition = condition;
|
||||
}
|
||||
|
||||
public LastIndex(INDArray x, @NonNull Condition condition, int... dimensions) {
|
||||
this(x, condition, Nd4j.EPS_THRESHOLD, dimensions);
|
||||
}
|
||||
|
||||
public LastIndex(INDArray in, boolean keepDim, Condition condition, int... dimensions) {
|
||||
this(in, condition, keepDim, dimensions);
|
||||
}
|
||||
|
||||
public LastIndex(INDArray x, @NonNull Condition condition, boolean keepDim, int... dimensions) {
|
||||
this(x, condition, Nd4j.EPS_THRESHOLD, dimensions);
|
||||
this.keepDims = keepDim;
|
||||
|
|
|
@ -47,10 +47,6 @@ public class AvgPooling3D extends Pooling3D {
|
|||
super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.AVG);
|
||||
}
|
||||
|
||||
public AvgPooling3D(SameDiff sameDiff,INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) {
|
||||
super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.AVG);
|
||||
}
|
||||
|
||||
public AvgPooling3D(@NonNull INDArray input, Pooling3DConfig pooling3DConfig) {
|
||||
super(null,null,new INDArray[]{input},null,false, pooling3DConfig, Pooling3DType.AVG);
|
||||
}
|
||||
|
|
|
@ -76,6 +76,19 @@ public class BatchNorm extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public BatchNorm(SameDiff sameDiff, SDVariable input, SDVariable mean, SDVariable variance,
|
||||
SDVariable gamma, SDVariable beta, double epsilon, int[] axis) {
|
||||
super(null,sameDiff, wrapFilterNull(input, mean, variance, gamma, beta), false);
|
||||
Preconditions.checkState(axis != null && axis.length > 0, "Invalid axis argument: axis must be specified" +
|
||||
"and length > 0. Got %s", axis);
|
||||
this.sameDiff = sameDiff;
|
||||
this.applyBeta = beta != null;
|
||||
this.applyGamma = gamma != null;
|
||||
this.epsilon = epsilon;
|
||||
this.jaxis = axis;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public BatchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, INDArray beta, double epsilon, int... axis){
|
||||
super(wrapFilterNull(input, mean, variance, gamma, beta), null);
|
||||
this.jaxis = axis;
|
||||
|
|
|
@ -46,6 +46,10 @@ public class Conv1D extends DynamicCustomOp {
|
|||
protected Conv1DConfig config;
|
||||
private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
|
||||
|
||||
public Conv1D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
|
||||
this(sameDiff, wrapFilterNull(input, weights, bias), conv1DConfig);
|
||||
}
|
||||
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public Conv1D(SameDiff sameDiff,
|
||||
SDVariable[] inputFunctions,
|
||||
|
@ -64,12 +68,8 @@ public class Conv1D extends DynamicCustomOp {
|
|||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
public Conv1D( @NonNull INDArray input, @NonNull INDArray weights, INDArray bias, Conv1DConfig conv1DConfig) {
|
||||
this(wrapFilterNull(input, weights, bias), null, conv1DConfig);
|
||||
}
|
||||
|
||||
public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, Conv1DConfig conv1DConfig) {
|
||||
this(new INDArray[]{input, weights}, null, conv1DConfig);
|
||||
public Conv1D(INDArray input, INDArray weights, INDArray bias, Conv1DConfig config) {
|
||||
this(input, weights, bias, null, config);
|
||||
}
|
||||
|
||||
private void initConfig(Conv1DConfig config){
|
||||
|
|
|
@ -56,6 +56,11 @@ public class Conv2D extends DynamicCustomOp {
|
|||
protected Conv2DConfig config;
|
||||
private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
|
||||
|
||||
public Conv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights,
|
||||
SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
|
||||
this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig);
|
||||
}
|
||||
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public Conv2D(SameDiff sameDiff,
|
||||
SDVariable[] inputFunctions,
|
||||
|
@ -75,12 +80,8 @@ public class Conv2D extends DynamicCustomOp {
|
|||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, @NonNull Conv2DConfig conv2DConfig) {
|
||||
this(new INDArray[]{layerInput, weights}, null, conv2DConfig);
|
||||
}
|
||||
|
||||
public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, INDArray bias, @NonNull Conv2DConfig conv2DConfig) {
|
||||
this(wrapFilterNull(layerInput, weights,bias), null, conv2DConfig);
|
||||
public Conv2D(INDArray layerInput, INDArray weights, INDArray bias, Conv2DConfig config) {
|
||||
this(layerInput, weights, bias, null, config);
|
||||
}
|
||||
|
||||
protected void initConfig(Conv2DConfig config){
|
||||
|
|
|
@ -55,6 +55,11 @@ public class Conv3D extends DynamicCustomOp {
|
|||
public Conv3D() {
|
||||
}
|
||||
|
||||
public Conv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights,
|
||||
SDVariable bias, @NonNull Conv3DConfig config) {
|
||||
this(sameDiff, wrapFilterNull(input, weights, bias), config);
|
||||
}
|
||||
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) {
|
||||
super(sameDiff, inputFunctions);
|
||||
|
@ -70,12 +75,12 @@ public class Conv3D extends DynamicCustomOp {
|
|||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
public Conv3D(@NonNull INDArray input,@NonNull INDArray weights, @NonNull Conv3DConfig conv3DConfig) {
|
||||
this(new INDArray[]{input, weights}, null, conv3DConfig);
|
||||
public Conv3D(INDArray input, INDArray weights, INDArray bias, Conv3DConfig config) {
|
||||
this(wrapFilterNull(input, weights, bias), null, config);
|
||||
}
|
||||
|
||||
public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull Conv3DConfig conv3DConfig) {
|
||||
this(wrapFilterNull(input, weights, bias) , null, conv3DConfig);
|
||||
public Conv3D(INDArray input, INDArray weights, Conv3DConfig config) {
|
||||
this(wrapFilterNull(input, weights), null, config);
|
||||
}
|
||||
|
||||
private void initConfig(Conv3DConfig config){
|
||||
|
|
|
@ -52,6 +52,11 @@ public class DeConv2D extends DynamicCustomOp {
|
|||
|
||||
protected DeConv2DConfig config;
|
||||
|
||||
public DeConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights,
|
||||
SDVariable bias, DeConv2DConfig config) {
|
||||
this(sameDiff, wrapFilterNull(input, weights, bias), config);
|
||||
}
|
||||
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public DeConv2D(SameDiff sameDiff,
|
||||
SDVariable[] inputs,
|
||||
|
@ -73,15 +78,10 @@ public class DeConv2D extends DynamicCustomOp {
|
|||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
public DeConv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, DeConv2DConfig deConv2DConfig) {
|
||||
this(wrapFilterNull(layerInput, weights), null, deConv2DConfig);
|
||||
public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig config) {
|
||||
this(layerInput, weights, bias, null, config);
|
||||
}
|
||||
|
||||
public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig deConv2DConfig) {
|
||||
this(wrapFilterNull(layerInput, weights, bias), null, deConv2DConfig);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public long[] iArgs() {
|
||||
if (iArguments.size() == 0)
|
||||
|
|
|
@ -48,12 +48,18 @@ public class DeConv3D extends DynamicCustomOp {
|
|||
|
||||
protected DeConv3DConfig config;
|
||||
|
||||
public DeConv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
|
||||
super(sameDiff, toArr(input, weights, bias));
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
|
||||
super(sameDiff, toArr(input, weights, null));
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){
|
||||
super(inputs, outputs);
|
||||
|
||||
|
@ -65,12 +71,8 @@ public class DeConv3D extends DynamicCustomOp {
|
|||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, @NonNull DeConv3DConfig deConv3DConfig) {
|
||||
this(new INDArray[]{input, weights}, null, deConv3DConfig);
|
||||
}
|
||||
|
||||
public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull DeConv3DConfig deConv3DConfig) {
|
||||
this(wrapFilterNull(input, weights, bias), null, deConv3DConfig);
|
||||
public DeConv3D(INDArray input, INDArray weights, INDArray bias, DeConv3DConfig config) {
|
||||
this(input, weights, bias, null, config);
|
||||
}
|
||||
|
||||
private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
|
||||
|
|
|
@ -16,16 +16,15 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.enums.DataFormat;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.enums.DataFormat;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
@ -46,45 +45,48 @@ import java.util.*;
|
|||
* @author raver119@gmail.com, Max Pumperla
|
||||
*/
|
||||
public class DepthToSpace extends DynamicCustomOp {
|
||||
private String dataFormat = "NHWC";
|
||||
private DataFormat dataFormat = DataFormat.NHWC;
|
||||
private int blockSize;
|
||||
|
||||
public DepthToSpace() {
|
||||
}
|
||||
|
||||
public DepthToSpace(SameDiff sameDiff, SDVariable[] args, int blockSize, String dataFormat) {
|
||||
public DepthToSpace(SameDiff sameDiff, SDVariable args, int blockSize, DataFormat dataFormat) {
|
||||
this(sameDiff, new SDVariable[]{args}, blockSize, dataFormat);
|
||||
}
|
||||
|
||||
public DepthToSpace(SameDiff sameDiff, SDVariable[] args, int blockSize, DataFormat dataFormat) {
|
||||
super(null, sameDiff, args, false);
|
||||
this.blockSize = blockSize;
|
||||
this.dataFormat = dataFormat;
|
||||
boolean isNHWC = dataFormat.equals("NHWC");
|
||||
boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
|
||||
addIArgument(blockSize, isNHWC ? 1 : 0);
|
||||
}
|
||||
|
||||
public DepthToSpace(@NonNull INDArray in, INDArray out, int blockSize, @NonNull String dataFormat) {
|
||||
public DepthToSpace(INDArray in, INDArray out, int blockSize, DataFormat dataFormat) {
|
||||
super(null, in, out, null, null);
|
||||
this.blockSize = blockSize;
|
||||
this.dataFormat = dataFormat;
|
||||
boolean isNHWC = dataFormat.equals("NHWC");
|
||||
boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
|
||||
addIArgument(blockSize, isNHWC ? 1 : 0);
|
||||
}
|
||||
|
||||
public DepthToSpace(@NonNull INDArray x, int blockSize, DataFormat dataFormat) {
|
||||
this(x, null, blockSize, dataFormat.toString());
|
||||
public DepthToSpace(INDArray in, int blockSize, DataFormat dataFormat) {
|
||||
this(in, null, blockSize, dataFormat);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
// Gradient to DepthToSpace is just SpaceToDepth of same block size and data format.
|
||||
SDVariable gradient = i_v.get(0);
|
||||
SDVariable ret = sameDiff.cnn().spaceToDepth(gradient, blockSize, dataFormat);
|
||||
SDVariable ret = new SpaceToDepth(sameDiff, new SDVariable[]{gradient}, blockSize, dataFormat).outputVariable();
|
||||
return Arrays.asList(ret);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
|
||||
boolean isNHWC = dataFormat.equals("NHWC");
|
||||
boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
|
||||
addIArgument(blockSize, isNHWC ? 1 : 0);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,8 +16,11 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -49,11 +52,15 @@ import java.util.*;
|
|||
*/
|
||||
@Slf4j
|
||||
@Getter
|
||||
@NoArgsConstructor
|
||||
public class DepthwiseConv2D extends DynamicCustomOp {
|
||||
|
||||
protected Conv2DConfig config;
|
||||
|
||||
public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input,
|
||||
@NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
|
||||
this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig);
|
||||
}
|
||||
|
||||
@Builder(builderMethodName = "sameDiffBuilder")
|
||||
public DepthwiseConv2D(SameDiff sameDiff,
|
||||
SDVariable[] inputFunctions,
|
||||
|
@ -75,16 +82,11 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
|||
this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
|
||||
}
|
||||
|
||||
public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, Conv2DConfig conv2DConfig) {
|
||||
this(wrapFilterNull(layerInput, depthWeights), null, conv2DConfig);
|
||||
public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig config) {
|
||||
this(layerInput, depthWeights, bias, null, config);
|
||||
}
|
||||
|
||||
public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig conv2DConfig) {
|
||||
this(wrapFilterNull(layerInput, depthWeights, bias), null, conv2DConfig);
|
||||
}
|
||||
|
||||
public DepthwiseConv2D(INDArray inputs, Conv2DConfig conv2DConfig) {
|
||||
this(wrapFilterNull(inputs), null, conv2DConfig);
|
||||
public DepthwiseConv2D() {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -58,6 +58,10 @@ public class LocalResponseNormalization extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public LocalResponseNormalization(SameDiff sameDiff, SDVariable input, LocalResponseNormalizationConfig config) {
|
||||
this(sameDiff, new SDVariable[]{input}, false, config);
|
||||
}
|
||||
|
||||
public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){
|
||||
super(new INDArray[]{input}, wrapOrNull(output));
|
||||
|
||||
|
|
|
@ -60,15 +60,16 @@ public class MaxPooling2D extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public MaxPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||
public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
|
||||
super(null, new INDArray[]{input}, wrapOrNull(output));
|
||||
config.setType(Pooling2D.Pooling2DType.MAX);
|
||||
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public MaxPooling2D(@NonNull INDArray input, @NonNull Pooling2DConfig pooling2DConfig) {
|
||||
this(input, null, pooling2DConfig);
|
||||
public MaxPooling2D(INDArray input, @NonNull Pooling2DConfig config){
|
||||
this(input, null, config);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -47,8 +47,12 @@ public class MaxPooling3D extends Pooling3D {
|
|||
super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.MAX);
|
||||
}
|
||||
|
||||
public MaxPooling3D(SameDiff sameDiff, INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) {
|
||||
super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.MAX);
|
||||
public MaxPooling3D(INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) {
|
||||
addInputArgument(arrayInput);
|
||||
if (arrayOutput != null)
|
||||
addOutputArgument(arrayOutput);
|
||||
this.config = config;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
public MaxPooling3D(INDArray input, Pooling3DConfig pooling3DConfig) {
|
||||
|
|
|
@ -44,18 +44,23 @@ public class SConv2D extends Conv2D {
|
|||
super(sameDiff, inputFunctions, conv2DConfig);
|
||||
}
|
||||
|
||||
public SConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights,
|
||||
@NonNull SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
|
||||
this(sameDiff, wrapFilterNull(layerInput, depthWeights, pointWeights, bias), conv2DConfig);
|
||||
}
|
||||
|
||||
public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
|
||||
super(inputs, outputs, config);
|
||||
}
|
||||
|
||||
public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, @NonNull Conv2DConfig Conv2DConfig){
|
||||
this(wrapFilterNull(layerInput, depthWeights, pointWeights, bias), null, Conv2DConfig);
|
||||
}
|
||||
|
||||
public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, @NonNull Conv2DConfig Conv2DConfig){
|
||||
this(wrapFilterNull(layerInput, depthWeights, pointWeights), null, Conv2DConfig);
|
||||
}
|
||||
|
||||
public SConv2D(INDArray layerInput, INDArray depthWeights, INDArray pointWeights, INDArray bias, Conv2DConfig config) {
|
||||
this(wrapFilterNull(layerInput, depthWeights, pointWeights, bias), null, config);
|
||||
}
|
||||
|
||||
public SConv2D() {}
|
||||
|
||||
|
||||
|
|
|
@ -16,16 +16,15 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.enums.DataFormat;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.enums.DataFormat;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
@ -45,47 +44,48 @@ import java.util.*;
|
|||
* @author raver119@gmail.com, Max Pumperla
|
||||
*/
|
||||
public class SpaceToDepth extends DynamicCustomOp {
|
||||
private String dataFormat;
|
||||
private DataFormat dataFormat;
|
||||
private int blockSize;
|
||||
|
||||
public SpaceToDepth() {
|
||||
}
|
||||
|
||||
public SpaceToDepth(SameDiff sameDiff, SDVariable[] args, int blockSize, String dataFormat) {
|
||||
public SpaceToDepth(SameDiff sameDiff, SDVariable[] args, int blockSize, DataFormat dataFormat) {
|
||||
super(null, sameDiff, args, false);
|
||||
this.blockSize = blockSize;
|
||||
this.dataFormat = dataFormat;
|
||||
boolean isNHWC = dataFormat.equals("NHWC");
|
||||
boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
|
||||
addIArgument(blockSize, isNHWC ? 1 : 0);
|
||||
}
|
||||
|
||||
public SpaceToDepth(INDArray in, INDArray out, int blockSize, String dataFormat){
|
||||
public SpaceToDepth(SameDiff sameDiff, SDVariable x, int blockSize, DataFormat dataFormat) {
|
||||
this(sameDiff, new SDVariable[]{x}, blockSize, dataFormat);
|
||||
}
|
||||
|
||||
public SpaceToDepth(INDArray in, INDArray out, int blockSize, DataFormat dataFormat){
|
||||
super(null, in, out, null, null);
|
||||
this.blockSize = blockSize;
|
||||
this.dataFormat = dataFormat;
|
||||
boolean isNHWC = dataFormat.equals("NHWC");
|
||||
boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
|
||||
addIArgument(blockSize, isNHWC ? 1 : 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
public SpaceToDepth(@NonNull INDArray x, int blockSize, @NonNull DataFormat dataFormat) {
|
||||
this(x, null, blockSize,dataFormat.toString());
|
||||
public SpaceToDepth(INDArray x, int blockSize, DataFormat dataFormat) {
|
||||
this(x, null, blockSize, dataFormat);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
// Gradient to SpaceToDepth is just DepthToSpace of same block size and data format.
|
||||
SDVariable gradient = i_v.get(0);
|
||||
SDVariable ret = sameDiff.cnn().depthToSpace(gradient, blockSize, dataFormat);
|
||||
SDVariable ret = new DepthToSpace(sameDiff, gradient, blockSize, dataFormat).outputVariable();
|
||||
return Arrays.asList(ret);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
|
||||
boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC");
|
||||
boolean isNHWC = dataFormat == null ? true : dataFormat.equals(DataFormat.NHWC);
|
||||
addIArgument(blockSize, isNHWC ? 1 : 0);
|
||||
}
|
||||
|
||||
|
|
|
@ -56,6 +56,14 @@ public class Upsampling2d extends DynamicCustomOp {
|
|||
addIArgument(nchw ? 1 : 0);
|
||||
}
|
||||
|
||||
public Upsampling2d(SameDiff sameDiff, SDVariable input, int scaleH, int scaleW, boolean nchw) {
|
||||
this(sameDiff, input, nchw, scaleH, scaleW);
|
||||
}
|
||||
|
||||
public Upsampling2d(SameDiff sameDiff, SDVariable input, int scale) {
|
||||
super(null,sameDiff, new SDVariable[]{input});
|
||||
addIArgument(scale);
|
||||
}
|
||||
|
||||
public Upsampling2d(INDArray input, int scale) {
|
||||
this(input, scale, scale, true);
|
||||
|
|
|
@ -38,6 +38,11 @@ public class AbsoluteDifferenceLoss extends BaseLoss {
|
|||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public AbsoluteDifferenceLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce) {
|
||||
this(sameDiff, lossReduce, predictions, weights, label);
|
||||
}
|
||||
|
||||
public AbsoluteDifferenceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
|
|
@ -33,9 +33,9 @@ public abstract class BaseLoss extends DynamicCustomOp {
|
|||
|
||||
protected LossReduce lossReduce;
|
||||
|
||||
public BaseLoss(@NonNull SameDiff sameDiff, @NonNull LossReduce lossReduce, @NonNull SDVariable predictions, @NonNull SDVariable weights,
|
||||
public BaseLoss(@NonNull SameDiff sameDiff, @NonNull LossReduce lossReduce, @NonNull SDVariable predictions, SDVariable weights,
|
||||
@NonNull SDVariable labels){
|
||||
super(null, sameDiff, new SDVariable[]{predictions, weights, labels});
|
||||
super(null, sameDiff, new SDVariable[]{predictions, getWeights(sameDiff, weights, predictions), labels});
|
||||
this.lossReduce = lossReduce;
|
||||
addArgs();
|
||||
}
|
||||
|
@ -50,6 +50,10 @@ public abstract class BaseLoss extends DynamicCustomOp {
|
|||
return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0);
|
||||
}
|
||||
|
||||
protected static SDVariable getWeights(SameDiff sd, SDVariable weights, SDVariable predictions){
|
||||
return weights != null ? weights : sd.constant(Nd4j.scalar(predictions.dataType(), 1.0));
|
||||
}
|
||||
|
||||
protected BaseLoss(){ }
|
||||
|
||||
protected void addArgs(){
|
||||
|
@ -62,7 +66,7 @@ public abstract class BaseLoss extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() >= 2, "Expected exactly 2 or more input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,6 +39,11 @@ public class CosineDistanceLoss extends BaseLoss {
|
|||
this.addIArgument(dimension);
|
||||
}
|
||||
|
||||
public CosineDistanceLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce, int dimension) {
|
||||
this(sameDiff, lossReduce, predictions, weights, labels, dimension);
|
||||
}
|
||||
|
||||
public CosineDistanceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, int dimension){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
this.dimension = dimension;
|
||||
|
|
|
@ -36,6 +36,11 @@ public class HingeLoss extends BaseLoss {
|
|||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public HingeLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce) {
|
||||
this(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
|
|
@ -41,6 +41,11 @@ public class HuberLoss extends BaseLoss {
|
|||
tArguments.add(delta);
|
||||
}
|
||||
|
||||
public HuberLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce, double delta) {
|
||||
this(sameDiff, lossReduce, predictions, weights, labels, delta);
|
||||
}
|
||||
|
||||
public HuberLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double delta){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
this.delta = delta;
|
||||
|
|
|
@ -41,6 +41,11 @@ public class LogLoss extends BaseLoss {
|
|||
addTArgument(epsilon);
|
||||
}
|
||||
|
||||
public LogLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce, double epsilon) {
|
||||
this(sameDiff, lossReduce, predictions, weights, labels, epsilon);
|
||||
}
|
||||
|
||||
public LogLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double epsilon){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
this.epsilon = epsilon;
|
||||
|
|
|
@ -38,6 +38,11 @@ public class LogPoissonLoss extends BaseLoss {
|
|||
this(sameDiff, lossReduce, predictions, weights, labels, false);
|
||||
}
|
||||
|
||||
public LogPoissonLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce, boolean full) {
|
||||
this(sameDiff, lossReduce, predictions, weights, labels, full);
|
||||
}
|
||||
|
||||
public LogPoissonLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels, boolean full){
|
||||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
this.full = full;
|
||||
|
|
|
@ -34,6 +34,11 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss {
|
|||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public MeanPairwiseSquaredErrorLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions,
|
||||
SDVariable weights, LossReduce lossReduce) {
|
||||
this(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public MeanPairwiseSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
|
|
@ -36,6 +36,11 @@ public class MeanSquaredErrorLoss extends BaseLoss {
|
|||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public MeanSquaredErrorLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce) {
|
||||
this(sameDiff, lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public MeanSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
|
|
@ -44,6 +44,11 @@ public class SigmoidCrossEntropyLoss extends BaseLoss {
|
|||
public static final double DEFAULT_LABEL_SMOOTHING = 0.0;
|
||||
private double labelSmoothing = 0.0;
|
||||
|
||||
public SigmoidCrossEntropyLoss(SameDiff sameDiff, SDVariable labels, SDVariable logits, SDVariable weights,
|
||||
LossReduce lossReduce, double labelSmoothing) {
|
||||
this(sameDiff, lossReduce, logits, weights, labels, labelSmoothing);
|
||||
}
|
||||
|
||||
public SigmoidCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights,
|
||||
SDVariable labels, double labelSmoothing) {
|
||||
super(sameDiff, lossReduce, logits, weights, labels);
|
||||
|
|
|
@ -45,6 +45,11 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss {
|
|||
|
||||
private double labelSmoothing = 0.0;
|
||||
|
||||
public SoftmaxCrossEntropyLoss(SameDiff sameDiff, SDVariable labels, SDVariable logits,
|
||||
SDVariable weights, LossReduce lossReduce, double labelSmoothing) {
|
||||
this(sameDiff, lossReduce, logits, weights, labels, labelSmoothing);
|
||||
}
|
||||
|
||||
public SoftmaxCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights, SDVariable labels,
|
||||
double labelSmoothing) {
|
||||
super(sameDiff, lossReduce, logits, weights, labels);
|
||||
|
|
|
@ -93,6 +93,24 @@ public class Mmul extends DynamicCustomOp {
|
|||
}
|
||||
}
|
||||
|
||||
public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
addInputArgument(x, y);
|
||||
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
||||
ArrayUtil.fromBoolean(transposeY),
|
||||
ArrayUtil.fromBoolean(transposeZ));
|
||||
}
|
||||
|
||||
public Mmul(INDArray x, INDArray y) {
|
||||
this(x,y,null,null);
|
||||
}
|
||||
|
||||
public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
|
||||
boolean transposeZ) {
|
||||
super(null,sameDiff,new SDVariable[]{x,y});
|
||||
addIArgument(ArrayUtil.fromBoolean(transposeX),
|
||||
ArrayUtil.fromBoolean(transposeY),
|
||||
ArrayUtil.fromBoolean(transposeZ));
|
||||
}
|
||||
|
||||
public Mmul() {}
|
||||
|
||||
|
|
|
@ -77,6 +77,18 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
addIArgument(dimensions[1]);
|
||||
}
|
||||
|
||||
public TensorMmul(SameDiff sameDiff, SDVariable x, SDVariable y, int[] dimensionsX,
|
||||
int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
super(null, sameDiff, new SDVariable[]{x,y});
|
||||
this.sameDiff = sameDiff;
|
||||
this.axes = new int[][]{dimensionsX, dimensionsY};
|
||||
addIArgument(dimensionsX.length);
|
||||
addIArgument(dimensionsX[0]);
|
||||
addIArgument(dimensionsY.length);
|
||||
addIArgument(dimensionsY[0]);
|
||||
addBArgument(transposeX, transposeY, transposeZ);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<LongShapeDescriptor> calculateOutputShape() {
|
||||
List<LongShapeDescriptor> ret = new ArrayList<>(1);
|
||||
|
@ -242,6 +254,13 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
this.axes = axes;
|
||||
}
|
||||
|
||||
public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY,
|
||||
boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
super(null,new INDArray[]{x, y},null);
|
||||
this.axes = new int[][]{dimensionsX, dimensionsY};
|
||||
addBArgument(transposeX, transposeY, transposeZ);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "tensordot";
|
||||
|
|
|
@ -41,6 +41,10 @@ public class Any extends BaseReduceBoolOp {
|
|||
super(x);
|
||||
}
|
||||
|
||||
public Any(INDArray x, int... dimensions) {
|
||||
super(x, dimensions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 0;
|
||||
|
|
|
@ -45,6 +45,10 @@ public class LogSumExp extends DynamicCustomOp {
|
|||
this.keepDims = keepDims;
|
||||
}
|
||||
|
||||
public LogSumExp(SameDiff sameDiff, SDVariable i_v, int[] dimensions) {
|
||||
this(sameDiff, i_v, false, dimensions);
|
||||
}
|
||||
|
||||
public LogSumExp() {}
|
||||
|
||||
public LogSumExp(INDArray x, int... dimensions) {
|
||||
|
|
|
@ -41,6 +41,10 @@ public class SquaredNorm extends BaseReduceFloatOp {
|
|||
super(input, output, keepDims, dimensions);
|
||||
}
|
||||
|
||||
public SquaredNorm(INDArray input, boolean keepDims, int... dimensions){
|
||||
this(input, null, keepDims, dimensions);
|
||||
}
|
||||
|
||||
public SquaredNorm(){}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -38,6 +38,10 @@ public class MatchCondition extends BaseReduceLongOp {
|
|||
private double eps;
|
||||
private int mode;
|
||||
|
||||
public MatchCondition(SameDiff sameDiff, SDVariable in, Condition condition) {
|
||||
this(sameDiff, in, condition, false, null);
|
||||
}
|
||||
|
||||
public MatchCondition(SameDiff sameDiff, SDVariable in, Condition condition, boolean keepDims, int... dimensions) {
|
||||
super(sameDiff, in, dimensions, keepDims);
|
||||
this.compare = condition.getValue();
|
||||
|
@ -64,6 +68,10 @@ public class MatchCondition extends BaseReduceLongOp {
|
|||
defineDimensions(dimensions);
|
||||
}
|
||||
|
||||
public MatchCondition(INDArray in, Condition condition, boolean keepDim, int... dimensions) {
|
||||
this(in, condition, dimensions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 2;
|
||||
|
|
|
@ -56,6 +56,10 @@ public class Sum extends BaseReduceSameOp {
|
|||
super(x, z, keepDims, dimensions);
|
||||
}
|
||||
|
||||
public Sum(INDArray x, boolean keepDims, int... dimensions) {
|
||||
this(x, null, keepDims, dimensions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 0;
|
||||
|
|
|
@ -50,6 +50,10 @@ public class LeakyReLU extends BaseScalarOp {
|
|||
|
||||
}
|
||||
|
||||
public LeakyReLU(SameDiff sameDiff, SDVariable i_v, double alpha) {
|
||||
this(sameDiff, i_v, false, alpha);
|
||||
}
|
||||
|
||||
public LeakyReLU(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double alpha) {
|
||||
super(sameDiff, i_v, alpha, extraArgs);
|
||||
this.alpha = alpha;
|
||||
|
|
|
@ -42,6 +42,10 @@ public class Pow extends BaseScalarOp {
|
|||
this.extraArgs = new Object[]{pow};
|
||||
}
|
||||
|
||||
public Pow(SameDiff sameDiff, SDVariable i_v, double pow) {
|
||||
this(sameDiff, i_v, false, pow);
|
||||
}
|
||||
|
||||
|
||||
public Pow(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double pow) {
|
||||
super(sameDiff, i_v, pow, extraArgs);
|
||||
|
|
|
@ -35,6 +35,10 @@ public class RectifiedLinear extends BaseScalarOp {
|
|||
super(sameDiff, i_v, cutoff, inPlace);
|
||||
}
|
||||
|
||||
public RectifiedLinear(SameDiff sameDiff, SDVariable i_v, double cutoff) {
|
||||
this(sameDiff, i_v, false, cutoff);
|
||||
}
|
||||
|
||||
public RectifiedLinear() {
|
||||
super();
|
||||
}
|
||||
|
|
|
@ -42,6 +42,10 @@ public class Relu6 extends BaseScalarOp {
|
|||
super(sameDiff, i_v, cutoff, inPlace);
|
||||
}
|
||||
|
||||
public Relu6(SameDiff sameDiff, SDVariable i_v, double cutoff) {
|
||||
this(sameDiff, i_v, false, cutoff);
|
||||
}
|
||||
|
||||
public Relu6() {
|
||||
//
|
||||
}
|
||||
|
|
|
@ -41,6 +41,10 @@ public class Step extends BaseScalarOp {
|
|||
this.extraArgs = new Object[] {cutoff};
|
||||
}
|
||||
|
||||
public Step(SameDiff sameDiff, SDVariable i_v, double cutoff) {
|
||||
this(sameDiff, i_v, false, cutoff);
|
||||
}
|
||||
|
||||
public Step() {
|
||||
cutoff = 0.0;
|
||||
this.extraArgs = new Object[] {cutoff};
|
||||
|
|
|
@ -46,6 +46,9 @@ public class ScalarLessThan extends BaseScalarBoolOp {
|
|||
super(sameDiff, i_v, scalar, inPlace);
|
||||
}
|
||||
|
||||
public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, double scalar) {
|
||||
super(sameDiff, i_v, scalar, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -43,6 +44,10 @@ public class ScatterAdd extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterAdd(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterAdd(){}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -43,6 +44,10 @@ public class ScatterDiv extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterDiv(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterDiv() {}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -41,6 +42,10 @@ public class ScatterMax extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterMax(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterMax() {}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -41,6 +42,10 @@ public class ScatterMin extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterMin(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterMin() {}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -43,6 +44,10 @@ public class ScatterMul extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterMul(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterMul() {}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -43,6 +44,10 @@ public class ScatterSub extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterSub(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterSub() {}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -53,6 +54,10 @@ public class ScatterUpdate extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
|
||||
}
|
||||
|
||||
public ScatterUpdate(INDArray ref, INDArray indices, INDArray updates) {
|
||||
addInputArgument(ref, indices, updates);
|
||||
}
|
||||
|
||||
public ScatterUpdate(){}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -49,6 +49,14 @@ public class Concat extends DynamicCustomOp {
|
|||
addIArgument(concatDimension);
|
||||
}
|
||||
|
||||
public Concat(INDArray[] arrays, int concatDimension) {
|
||||
this(concatDimension, arrays);
|
||||
}
|
||||
|
||||
public Concat(SameDiff sameDiff, SDVariable[] inputs, int concatDimension){
|
||||
this(sameDiff, concatDimension, inputs);
|
||||
}
|
||||
|
||||
public Concat(SameDiff sameDiff, int concatDimension, SDVariable... inputs){
|
||||
super(null, sameDiff, inputs);
|
||||
addIArgument(concatDimension);
|
||||
|
|
|
@ -68,6 +68,12 @@ public class ConfusionMatrix extends DynamicCustomOp {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, DataType dataType){
|
||||
this(sameDiff, labels, pred, weights);
|
||||
this.outputType = dataType;
|
||||
}
|
||||
|
||||
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){
|
||||
super(null, sameDiff, new SDVariable[]{labels, pred});
|
||||
this.outputType = dataType;
|
||||
|
@ -82,6 +88,11 @@ public class ConfusionMatrix extends DynamicCustomOp {
|
|||
addIArgument(numClasses);
|
||||
}
|
||||
|
||||
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, Integer numClasses){
|
||||
super(null, sameDiff, new SDVariable[]{labels, pred, weights});
|
||||
addIArgument(numClasses);
|
||||
}
|
||||
|
||||
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights){
|
||||
super(null, sameDiff, new SDVariable[]{labels, pred, weights});
|
||||
if(numClasses != null) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -39,15 +40,17 @@ import java.util.List;
|
|||
*
|
||||
* @author Max Pumperla
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class Cross extends DynamicCustomOp {
|
||||
|
||||
public Cross() {
|
||||
}
|
||||
|
||||
public Cross(SameDiff sameDiff, SDVariable[] args) {
|
||||
super(null, sameDiff, args, false);
|
||||
}
|
||||
|
||||
public Cross(SameDiff sameDiff, SDVariable a, SDVariable b) {
|
||||
this(sameDiff, new SDVariable[]{a,b});
|
||||
}
|
||||
|
||||
public Cross(INDArray a, INDArray b){
|
||||
this(a,b,null);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -39,11 +40,9 @@ import java.util.Map;
|
|||
*
|
||||
* @author Max Pumperla
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class Diag extends DynamicCustomOp {
|
||||
|
||||
public Diag() {
|
||||
}
|
||||
|
||||
public Diag(@NonNull INDArray input) {
|
||||
this(input, null);
|
||||
}
|
||||
|
@ -52,6 +51,10 @@ public class Diag extends DynamicCustomOp {
|
|||
super(null, new INDArray[]{input}, wrapOrNull(output));
|
||||
}
|
||||
|
||||
public Diag(SameDiff sameDiff, SDVariable input) {
|
||||
this(sameDiff, new SDVariable[]{input}, false);
|
||||
}
|
||||
|
||||
public Diag(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
||||
super(null, sameDiff, args, inPlace);
|
||||
|
||||
|
|
|
@ -50,6 +50,10 @@ public class DiagPart extends DynamicCustomOp {
|
|||
super(null, sameDiff, args, inPlace);
|
||||
}
|
||||
|
||||
public DiagPart(SameDiff sameDiff, SDVariable in) {
|
||||
this(sameDiff, new SDVariable[]{in}, false);
|
||||
}
|
||||
|
||||
public DiagPart(INDArray in){
|
||||
this(in, null);
|
||||
}
|
||||
|
|
|
@ -46,6 +46,10 @@ public class ExpandDims extends DynamicCustomOp {
|
|||
public ExpandDims() {
|
||||
}
|
||||
|
||||
public ExpandDims(SameDiff sameDiff, SDVariable args, int axis) {
|
||||
this(sameDiff, new SDVariable[]{args}, axis);
|
||||
}
|
||||
|
||||
public ExpandDims(SameDiff sameDiff, SDVariable[] args, int axis) {
|
||||
super(null, sameDiff, args);
|
||||
if (axis == Integer.MAX_VALUE) {
|
||||
|
@ -63,6 +67,11 @@ public class ExpandDims extends DynamicCustomOp {
|
|||
super(null, inputs, outputs);
|
||||
}
|
||||
|
||||
public ExpandDims(INDArray input, int axis) {
|
||||
addInputArgument(input);
|
||||
addIArgument(axis);
|
||||
}
|
||||
|
||||
public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
|
||||
super(null, sameDiff, args, inPlace);
|
||||
}
|
||||
|
|
|
@ -122,6 +122,13 @@ public class Eye extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public Eye(SameDiff sameDiff, SDVariable numRows, SDVariable numCols, DataType dataType, int[] batchDimension) {
|
||||
super(null, sameDiff, new SDVariable[] {numRows, numCols}, false);
|
||||
this.batchDimension = batchDimension;
|
||||
this.dataType = dataType;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
protected void addArgs() {
|
||||
iArguments.clear();
|
||||
tArguments.clear();
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -40,6 +41,13 @@ public class Gather extends DynamicCustomOp {
|
|||
protected int[] indices;
|
||||
protected int jaxis = 0;
|
||||
|
||||
public Gather(SameDiff sameDiff, SDVariable df, SDVariable indices, int axis) {
|
||||
this(sameDiff, df, indices, axis, false);
|
||||
}
|
||||
|
||||
public Gather(SameDiff sameDiff, SDVariable df, int[] indices, int axis) {
|
||||
this(sameDiff, df, indices, axis, false);
|
||||
}
|
||||
|
||||
public Gather(SameDiff sameDiff, SDVariable input, int[] indices, int axis, boolean inPlace) {
|
||||
super(null, sameDiff, new SDVariable[] {input}, inPlace);
|
||||
|
@ -56,6 +64,21 @@ public class Gather extends DynamicCustomOp {
|
|||
this.jaxis = axis;
|
||||
}
|
||||
|
||||
public Gather(INDArray df, int[] indexes, int axis) {
|
||||
addInputArgument(df);
|
||||
addIArgument(axis);
|
||||
addIArgument(indexes);
|
||||
this.jaxis = axis;
|
||||
this.indices = indices;
|
||||
}
|
||||
|
||||
public Gather(INDArray df, INDArray indexes, int axis) {
|
||||
addInputArgument(df, indexes);
|
||||
addIArgument(axis);
|
||||
this.jaxis = axis;
|
||||
this.indices = indices;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
return "Gather";
|
||||
|
|
|
@ -17,10 +17,13 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
@ -31,11 +34,19 @@ import java.util.List;
|
|||
@NoArgsConstructor
|
||||
public class GatherNd extends DynamicCustomOp {
|
||||
|
||||
public GatherNd(SameDiff sameDiff, SDVariable[] inputs, SDVariable[] indices) {
|
||||
super(null, sameDiff, ArrayUtils.addAll(inputs, indices), false);
|
||||
}
|
||||
|
||||
public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices, boolean inPlace) {
|
||||
super(null, sameDiff, new SDVariable[] {input, indices}, inPlace);
|
||||
}
|
||||
|
||||
public GatherNd(INDArray[] df, INDArray[] indices) {
|
||||
addInputArgument(df);
|
||||
addInputArgument(indices);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "gather_nd";
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
|
@ -39,11 +40,24 @@ public class Linspace extends DynamicCustomOp {
|
|||
|
||||
private DataType dataType;
|
||||
|
||||
public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) {
|
||||
super(sameDiff, new SDVariable[0]);
|
||||
addTArgument(start,stop);
|
||||
addIArgument(number);
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){
|
||||
super(sameDiff, new SDVariable[]{from, to, length});
|
||||
this.dataType = dataType;
|
||||
}
|
||||
|
||||
public Linspace(DataType dataType, double start, double stop, long number) {
|
||||
addDArgument(dataType);
|
||||
addTArgument(start, stop);
|
||||
addIArgument(number);
|
||||
}
|
||||
|
||||
public Linspace(){ }
|
||||
|
||||
@Override
|
||||
|
|
|
@ -37,6 +37,10 @@ public class MeshGrid extends DynamicCustomOp {
|
|||
addIArgument(cartesian ? 1 : 0);
|
||||
}
|
||||
|
||||
public MeshGrid(SameDiff sd, SDVariable[] inputs, boolean cartesian) {
|
||||
this(sd, cartesian, inputs);
|
||||
}
|
||||
|
||||
public MeshGrid(){ }
|
||||
|
||||
@Override
|
||||
|
|
|
@ -66,6 +66,11 @@ public class OneHot extends DynamicCustomOp {
|
|||
this(indices, output, depth, -1, 1, 0);
|
||||
}
|
||||
|
||||
public OneHot(INDArray indices, int depth) {
|
||||
addInputArgument(indices);
|
||||
addIArgument(depth);
|
||||
}
|
||||
|
||||
public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) {
|
||||
super(null, indices, output, null, null);
|
||||
this.depth = depth;
|
||||
|
@ -75,6 +80,12 @@ public class OneHot extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) {
|
||||
addInputArgument(indices);
|
||||
addIArgument(depth, axis);
|
||||
addTArgument(on, off);
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -48,10 +48,18 @@ public class OnesLike extends DynamicCustomOp {
|
|||
public OnesLike() {
|
||||
}
|
||||
|
||||
public OnesLike(SameDiff sameDiff, SDVariable input) {
|
||||
this(null, sameDiff, input);
|
||||
}
|
||||
|
||||
public OnesLike(String name, SameDiff sameDiff, SDVariable input) {
|
||||
this(name, sameDiff, input, input.dataType());
|
||||
}
|
||||
|
||||
public OnesLike(SameDiff sameDiff, SDVariable input, DataType dataType) {
|
||||
this(null, sameDiff, input, dataType);
|
||||
}
|
||||
|
||||
public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
|
||||
super(name, sameDiff, new SDVariable[]{input}, false);
|
||||
this.outputType = dataType;
|
||||
|
|
|
@ -55,6 +55,11 @@ public class Permute extends Transpose {
|
|||
addIArgument(permuteDims);
|
||||
}
|
||||
|
||||
public Permute(INDArray input, int... permuteDims){
|
||||
addInputArgument(input);
|
||||
addIArgument(permuteDims);
|
||||
}
|
||||
|
||||
public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){
|
||||
super(sd, input, permuteDims);
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.*;
|
||||
|
@ -39,10 +40,18 @@ public class Rank extends DynamicCustomOp {
|
|||
public Rank() {
|
||||
}
|
||||
|
||||
public Rank(SameDiff sameDiff, SDVariable input) {
|
||||
this(sameDiff, input, false);
|
||||
}
|
||||
|
||||
public Rank(SameDiff sameDiff, SDVariable input, boolean inPlace) {
|
||||
super(null, sameDiff, new SDVariable[] {input}, inPlace);
|
||||
}
|
||||
|
||||
public Rank(INDArray indArray) {
|
||||
addInputArgument(indArray);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
|
|
|
@ -59,6 +59,10 @@ public class Reshape extends DynamicCustomOp {
|
|||
super(null, new INDArray[]{in, shape}, new INDArray[]{out}, null, (List<Integer>)null);
|
||||
}
|
||||
|
||||
public Reshape(INDArray in, INDArray shape) {
|
||||
addInputArgument(in, shape);
|
||||
}
|
||||
|
||||
public Reshape() {
|
||||
}
|
||||
|
||||
|
|
|
@ -69,7 +69,13 @@ public class SequenceMask extends DynamicCustomOp {
|
|||
addIArgument(maxLen);
|
||||
this.dataType = dataType;
|
||||
addDArgument(dataType);
|
||||
}
|
||||
}
|
||||
|
||||
public SequenceMask(INDArray input, DataType dataType) {
|
||||
addInputArgument(input);
|
||||
this.dataType = dataType;
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue