Fix incompatibilities with generated code (#303)

* Cholesky fixed

* Constructors added

* MatMul wrapper

* Constructor added

* Missing wrappers added

* Generated Linalg namespace added

* Output data types

* Unit tests

* Added mmul

* Code generation

* Code generated

* Build fixed

* Fixing signatures

* Tests fixed

* Tests fixed

* Added enum

* Fix tests

* Some fixes

* Eye test fixed

* SameDiff: small fix for renameVariable - also replace variable name in lossVariable list if necessary

Signed-off-by: Alex Black <blacka101@gmail.com>

* Some fixes

* Tests fixed

* Revert wrong fix

* Some fixes

* Some fixes

* Extending base test class

* Added pad

* Fixed for generated signatures

* Fixes due to nd4j codegen

* Backwards compatibility fixes

* Fixed errors in tests, reverted wrong changes

* Test fixed

* Added missing operations used for nd4s operators

* Compilation fixed

* Added meshgrid

* Fixed constructors

* fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix bad commit (incorrectly reverted change from master)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixed test

Co-authored-by: Alex Black <blacka101@gmail.com>
Branch: master
Author: Alexander Stoyakin, 2020-04-01 04:00:38 +03:00 (committed by GitHub)
Parent: 1d004b542a
Commit: 0a27e9f41d
230 changed files with 12090 additions and 6072 deletions


@@ -130,14 +130,6 @@ public class SameDiffConv extends SameDiffLayer {
         SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
-        SDVariable[] vars;
-        if(hasBias){
-            SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
-            vars = new SDVariable[]{layerInput, w, b};
-        } else {
-            vars = new SDVariable[]{layerInput, w};
-        }
         Conv2DConfig c = Conv2DConfig.builder()
                 .kH(kernel[0]).kW(kernel[1])
                 .pH(padding[0]).pW(padding[1])
@@ -146,7 +138,13 @@ public class SameDiffConv extends SameDiffLayer {
                 .isSameMode(this.cm == ConvolutionMode.Same)
                 .build();
-        SDVariable conv = sameDiff.cnn().conv2d(vars, c); //TODO can't set name
+        SDVariable conv = null;
+        if(hasBias){
+            SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
+            conv = sameDiff.cnn().conv2d(layerInput, w, b, c);
+        } else {
+            conv = sameDiff.cnn().conv2d(layerInput, w, c);
+        }
         return activation.asSameDiff("out", sameDiff, conv);
     }
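With the generated API, conv2d takes the input, weights, and optional bias as explicit arguments rather than an SDVariable[]. A minimal sketch of the new overloads; the class name, shapes, and stride settings below are illustrative assumptions, not from the commit:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;

    public class Conv2dSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3, 32, 32); // NCHW, illustrative
            SDVariable w = sd.var("w", DataType.FLOAT, 3, 3, 3, 16);             // [kH, kW, nIn, nOut]
            SDVariable b = sd.var("b", DataType.FLOAT, 16);
            Conv2DConfig c = Conv2DConfig.builder().kH(3).kW(3).sH(1).sW(1).build();
            // New generated overloads: bias passed directly, no SDVariable[] array
            SDVariable withBias = sd.cnn().conv2d(in, w, b, c);
            SDVariable noBias = sd.cnn().conv2d(in, w, c);
        }
    }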


@@ -31,6 +31,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.util.ArrayUtil;

 import java.util.Map;
@@ -99,15 +100,15 @@ public class CapsuleLayer extends SameDiffLayer {
     }

     @Override
-    public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
+    public SDVariable defineLayer(SameDiff sd, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
         // input: [mb, inputCapsules, inputCapsuleDimensions]

         // [mb, inputCapsules, 1, inputCapsuleDimensions, 1]
-        SDVariable expanded = SD.expandDims(SD.expandDims(input, 2), 4);
+        SDVariable expanded = sd.expandDims(sd.expandDims(input, 2), 4);

         // [mb, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1]
-        SDVariable tiled = SD.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1);
+        SDVariable tiled = sd.tile(expanded, 1, 1, capsules * capsuleDimensions, 1, 1);

         // [1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions]
         SDVariable weights = paramTable.get(WEIGHT_PARAM);
@@ -119,13 +120,13 @@ public class CapsuleLayer extends SameDiffLayer {
         // b is the logits of the routing procedure
         // [mb, inputCapsules, capsules, 1, 1]
-        SDVariable b = SD.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));
+        SDVariable b = sd.zerosLike(uHat).get(SDIndex.all(), SDIndex.all(), SDIndex.all(), SDIndex.interval(0, 1), SDIndex.interval(0, 1));

         for(int i = 0 ; i < routings ; i++){
             // c is the coupling coefficient, i.e. the edge weight between the 2 capsules
             // [mb, inputCapsules, capsules, 1, 1]
-            SDVariable c = CapsuleUtils.softmax(SD, b, 2, 5);
+            SDVariable c = sd.nn.softmax(b, 2);

             // [mb, 1, capsules, capsuleDimensions, 1]
             SDVariable s = c.times(uHat).sum(true, 1);
@@ -135,14 +136,14 @@ public class CapsuleLayer extends SameDiffLayer {
             // v is the per capsule activations. On the last routing iteration, this is output
             // [mb, 1, capsules, capsuleDimensions, 1]
-            SDVariable v = CapsuleUtils.squash(SD, s, 3);
+            SDVariable v = CapsuleUtils.squash(sd, s, 3);

             if(i == routings - 1){
-                return SD.squeeze(SD.squeeze(v, 1), 3);
+                return sd.squeeze(sd.squeeze(v, 1), 3);
             }

             // [mb, inputCapsules, capsules, capsuleDimensions, 1]
-            SDVariable vTiled = SD.tile(v, 1, (int) inputCapsules, 1, 1, 1);
+            SDVariable vTiled = sd.tile(v, 1, (int) inputCapsules, 1, 1, 1);

             // [mb, inputCapsules, capsules, 1, 1]
             b = b.plus(uHat.times(vTiled).sum(true, 3));
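The permute-based softmax helper is no longer needed: the generated sd.nn.softmax accepts the reduction dimension directly. A minimal sketch (the shape and variable names are illustrative):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class SoftmaxDimSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // Rank-5 logits as in the capsule routing code; shape is illustrative
            SDVariable logits = sd.placeHolder("logits", DataType.FLOAT, -1, 8, 10, 1, 1);
            // Softmax along dimension 2 - no permute/inverse-permute round trip
            SDVariable coupling = sd.nn.softmax(logits, 2);
        }
    }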


@@ -178,9 +178,11 @@ public class LocallyConnected1D extends SameDiffLayer {
             //Note: for same mode, bottom/right padding can be 1 more than top/left padding
             //NCW format.
             if(cm == ConvolutionMode.Same) {
-                layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, paddingR}}, 0);
+                layerInput = sameDiff.nn().pad(layerInput,
+                        sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), 0);
             } else {
-                layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0, 0}, {0, 0}, {padding, padding}}, 0);
+                layerInput = sameDiff.nn().pad(layerInput,
+                        sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), 0);
             }
         }
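The generated pad signature takes the per-dimension padding as an SDVariable constant instead of an int[][]. A rough sketch of the new call, with illustrative values:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    public class PadSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4, 16); // NCW, illustrative
            // One {before, after} pair per dimension, wrapped in a constant variable
            SDVariable paddings = sd.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {1, 1}}));
            SDVariable padded = sd.nn().pad(in, paddings, 0.0);
        }
    }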


@@ -184,9 +184,11 @@ public class LocallyConnected2D extends SameDiffLayer {
             //Note: for same mode, bottom/right padding can be 1 more than top/left padding
             //NCHW format
             if(cm == ConvolutionMode.Same){
-                layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}}, 0);
+                layerInput = sameDiff.nn().pad(layerInput,
+                        sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), 0.0);
             } else {
-                layerInput = sameDiff.nn().pad(layerInput, new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}}, 0);
+                layerInput = sameDiff.nn().pad(layerInput,
+                        sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), 0.0);
             }
         }


@@ -45,15 +45,4 @@ public class CapsuleUtils {
         return x.times(squaredNorm).div(squaredNorm.plus(1.0).times(scale));
     }
-
-    /**
-     * Compute softmax along a given dimension
-     */
-    public static SDVariable softmax(SameDiff SD, SDVariable x, int dimension, int rank){
-        int[] permutation = ArrayUtil.range(0, rank);
-        permutation[0] = dimension;
-        permutation[dimension] = 0;
-        return SD.nn.softmax(x.permute(permutation)).permute(ArrayUtil.invertPermutation(permutation));
-    }
 }


@@ -495,7 +495,7 @@ public class JsonModelServerTest extends BaseDL4JTest {
         SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 28*28);
         SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 28*28, 10));
         SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 10));
-        SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
+        SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b), -1);

         val server = new JsonModelServer.Builder<float[], Integer>(sd)
                 .outputSerializer( new IntSerde())


@@ -58,7 +58,7 @@ public class TestSameDiffUI extends BaseDL4JTest {
         SDVariable b = sd.var("b", DataType.FLOAT, 1, 4);

         SDVariable z = in.mmul(w).add(b);
-        SDVariable a = sd.nn().tanh(z);
+        SDVariable a = sd.math().tanh(z);

         LogFileWriter lfw = new LogFileWriter(f);
         lfw.writeGraphStructure(sd);
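After codegen, tanh is exposed through the math namespace rather than nn, as the hunk above shows. A one-method sketch of the updated call (class name is illustrative):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class TanhNamespaceSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable z = sd.placeHolder("z", DataType.FLOAT, -1, 4);
            SDVariable a = sd.math().tanh(z); // previously sd.nn().tanh(z)
        }
    }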


@@ -20,6 +20,7 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
 import org.deeplearning4j.integration.ModelType;
 import org.deeplearning4j.integration.TestCase;
+import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.TrainingConfig;


@@ -28,6 +28,7 @@ import org.apache.commons.lang3.ArrayUtils;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.enums.DataFormat;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.blas.params.MMulTranspose;
 import org.nd4j.linalg.api.buffer.DataType;
@@ -1489,7 +1490,7 @@ public class DifferentialFunctionFactory {
     }

     public SDVariable reciprocal(SDVariable a) {
-        return new Reciprocal(sameDiff(), a, false).outputVariable();
+        return new Reciprocal(sameDiff(), a).outputVariable();
     }
@@ -1990,13 +1991,13 @@ public class DifferentialFunctionFactory {
                 .outputVariable();
     }

-    public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, String dataFormat) {
+    public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) {
         validateDifferentialFunctionsameDiff(differentialFunction);
         return new DepthToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat)
                 .outputVariable();
     }

-    public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, String dataFormat) {
+    public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) {
         validateDifferentialFunctionsameDiff(differentialFunction);
         return new SpaceToDepth(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat)
                 .outputVariable();
@@ -2635,7 +2636,7 @@ public class DifferentialFunctionFactory {
         return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable();
     }

-    public SDVariable[] maxPoolWithArgmaxs(SDVariable x, Pooling2DConfig pooling2DConfig) {
+    public SDVariable[] maxPoolWithArgmax(SDVariable x, Pooling2DConfig pooling2DConfig) {
         return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables();
     }
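The data format argument is now the org.nd4j.enums.DataFormat enum rather than a raw "NHWC"/"NCHW" string. A hedged sketch, assuming the generated cnn namespace mirrors the factory signature shown above:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.enums.DataFormat;
    import org.nd4j.linalg.api.buffer.DataType;

    public class DataFormatSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 4, 4, 12); // NHWC, illustrative
            // Typed enum instead of a string, so invalid formats fail at compile time
            SDVariable out = sd.cnn().depthToSpace(x, 2, DataFormat.NHWC);
        }
    }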


@@ -181,6 +181,11 @@ public class SameDiff extends SDBaseOps {
      */
     public final SDBitwise bitwise = new SDBitwise(this);

+    /**
+     * Op creator object for linalg operations
+     */
+    public final SDLinalg linalg = new SDLinalg(this);
+
     /**
      * Op creator object for math operations
      */
@@ -237,6 +242,13 @@ public class SameDiff extends SDBaseOps {
         return bitwise;
     }

+    /**
+     * Op creator object for linalg operations
+     */
+    public SDLinalg linalg(){
+        return linalg;
+    }
+
     private Map<String, SameDiff> sameDiffFunctionInstances;

     private Table<String, String, String> fieldVariableResolutionMapping;
@@ -3448,6 +3460,12 @@ public class SameDiff extends SDBaseOps {
                 sd.renameVariable(from, to);
             }
         }
+
+        //Check losses:
+        if(lossVariables.contains(from)){
+            int idx = lossVariables.indexOf(from);
+            lossVariables.set(idx, to);
+        }
     }
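The renameVariable fix means the loss-variable list follows a rename instead of keeping the stale name. A small sketch of the behavior; the loss definition itself is illustrative:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class RenameLossSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
            SDVariable loss = sd.sum("loss", in.mul(in));
            sd.setLossVariables("loss");
            sd.renameVariable("loss", "totalLoss");
            // Before the fix this still printed [loss]; now it prints [totalLoss]
            System.out.println(sd.getLossVariables());
        }
    }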


@@ -1,217 +1,416 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
+
 package org.nd4j.autodiff.samediff.ops;

-import lombok.NonNull;
+import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
+
+import java.lang.String;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
-
-import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
-
-/**
- *
- */
+import org.nd4j.base.Preconditions;
+
 public class SDBitwise extends SDOps {
-    public SDBitwise(SameDiff sameDiff) {
-        super(sameDiff);
-    }
-
-    /**
-     * See {@link #leftShift(String, SDVariable, SDVariable)}
-     */
-    public SDVariable leftShift(@NonNull SDVariable x, @NonNull SDVariable y){
-        return leftShift(null, x, y);
-    }
-
-    /**
-     * Bitwise left shift operation. Supports broadcasting.
-     *
-     * @param name Name of the output variable. May be null.
-     * @param x    Input to be bit shifted (must be an integer type)
-     * @param y    Amount to shift elements of x array (must be an integer type)
-     * @return Bitwise shifted input x
-     */
-    public SDVariable leftShift(String name, SDVariable x, SDVariable y){
-        validateInteger("bitwise left shift", x);
-        validateInteger("bitwise left shift", y);
-
-        SDVariable ret = f().shift(x, y);
-        return updateVariableNameAndReference(ret, name);
-    }
-
-    /**
-     * See {@link #rightShift(String, SDVariable, SDVariable)}
-     */
-    public SDVariable rightShift(SDVariable x, SDVariable y){
-        return rightShift(null, x, y);
-    }
-
-    /**
-     * Bitwise right shift operation. Supports broadcasting.
-     *
-     * @param name Name of the output variable. May be null.
-     * @param x    Input to be bit shifted (must be an integer type)
-     * @param y    Amount to shift elements of x array (must be an integer type)
-     * @return Bitwise shifted input x
-     */
-    public SDVariable rightShift(String name, SDVariable x, SDVariable y){
-        validateInteger("bitwise right shift", x);
-        validateInteger("bitwise right shift", y);
-
-        SDVariable ret = f().rshift(x, y);
-        return updateVariableNameAndReference(ret, name);
-    }
-
-    /**
-     * See {@link #leftShiftCyclic(String, SDVariable, SDVariable)}
-     */
-    public SDVariable leftShiftCyclic(SDVariable x, SDVariable y){
-        return leftShiftCyclic(null, x, y);
-    }
-
-    /**
-     * Bitwise left cyclical shift operation. Supports broadcasting.
-     * Unlike {@link #leftShift(String, SDVariable, SDVariable)} the bits will "wrap around":
-     * {@code leftShiftCyclic(01110000, 2) -> 11000001}
-     *
-     * @param name Name of the output variable. May be null.
-     * @param x    Input to be bit shifted (must be an integer type)
-     * @param y    Amount to shift elements of x array (must be an integer type)
-     * @return Bitwise cyclic shifted input x
-     */
-    public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y){
-        validateInteger("bitwise left shift (cyclic)", x);
-        validateInteger("bitwise left shift (cyclic)", y);
-
-        SDVariable ret = f().rotl(x, y);
-        return updateVariableNameAndReference(ret, name);
-    }
-
-    /**
-     * See {@link #rightShiftCyclic(String, SDVariable, SDVariable)}
-     */
-    public SDVariable rightShiftCyclic(SDVariable x, SDVariable y){
-        return rightShiftCyclic(null, x, y);
-    }
-
-    /**
-     * Bitwise right cyclical shift operation. Supports broadcasting.
-     * Unlike {@link #rightShift(String, SDVariable, SDVariable)} the bits will "wrap around":
-     * {@code rightShiftCyclic(00001110, 2) -> 10000011}
-     *
-     * @param name Name of the output variable. May be null.
-     * @param x    Input to be bit shifted (must be an integer type)
-     * @param y    Amount to shift elements of x array (must be an integer type)
-     * @return Bitwise cyclic shifted input x
-     */
-    public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y){
-        validateInteger("bitwise right shift (cyclic)", x);
-        validateInteger("bitwise right shift (cyclic)", y);
-
-        SDVariable ret = f().rotr(x, y);
-        return updateVariableNameAndReference(ret, name);
-    }
-
-    /**
-     * See {@link #bitsHammingDistance(String, SDVariable, SDVariable)}
-     */
-    public SDVariable bitsHammingDistance(SDVariable x, SDVariable y){
-        return bitsHammingDistance(null, x, y);
-    }
-
-    /**
-     * Bitwise Hamming distance reduction over all elements of both input arrays.<br>
-     * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
-     *
-     * @param name Name of the output variable. May be null.
-     * @param x    First input array. Must be integer type.
-     * @param y    First input array. Must be integer type, same type as x
-     * @return
-     */
-    public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y){
-        validateInteger("bitwise hamming distance", x);
-        validateInteger("bitwise hamming distance", y);
-
-        SDVariable ret = f().bitwiseHammingDist(x, y);
-        return updateVariableNameAndReference(ret, name);
-    }
-
-    /**
-     * See {@link #and(String, SDVariable, SDVariable)}
-     */
-    public SDVariable and(SDVariable x, SDVariable y){
-        return and(null, x, y);
-    }
-
-    /**
-     * Bitwise AND operation. Supports broadcasting.
-     *
-     * @param name Name of the output variable. May be null.
-     * @param x    First input array. Must be integer type.
-     * @param y    First input array. Must be integer type, same type as x
-     * @return Bitwise AND array
-     */
-    public SDVariable and(String name, SDVariable x, SDVariable y){
-        validateInteger("bitwise AND", x);
-        validateInteger("bitwise AND", y);
-
-        SDVariable ret = f().bitwiseAnd(x, y);
-        return updateVariableNameAndReference(ret, name);
-    }
-
-    /**
-     * See {@link #or(String, SDVariable, SDVariable)}
-     */
-    public SDVariable or(SDVariable x, SDVariable y){
-        return or(null, x, y);
-    }
-
-    /**
-     * Bitwise OR operation. Supports broadcasting.
-     *
-     * @param name Name of the output variable. May be null.
-     * @param x    First input array. Must be integer type.
-     * @param y    First input array. Must be integer type, same type as x
-     * @return Bitwise OR array
-     */
-    public SDVariable or(String name, SDVariable x, SDVariable y){
-        validateInteger("bitwise OR", x);
-        validateInteger("bitwise OR", y);
-
-        SDVariable ret = f().bitwiseOr(x, y);
-        return updateVariableNameAndReference(ret, name);
-    }
-
-    /**
-     * See {@link #xor(String, SDVariable, SDVariable)}
-     */
-    public SDVariable xor(SDVariable x, SDVariable y){
-        return xor(null, x, y);
-    }
-
-    /**
-     * Bitwise XOR operation (exclusive OR). Supports broadcasting.
-     *
-     * @param name Name of the output variable. May be null.
-     * @param x    First input array. Must be integer type.
-     * @param y    First input array. Must be integer type, same type as x
-     * @return Bitwise XOR array
-     */
-    public SDVariable xor(String name, SDVariable x, SDVariable y){
-        validateInteger("bitwise XOR", x);
-        validateInteger("bitwise XOR", y);
-
-        SDVariable ret = f().bitwiseXor(x, y);
-        return updateVariableNameAndReference(ret, name);
-    }
-
-    /**
-     * Flip bits
-     *
-     * @param name Name of the output variable
-     * @param x    input array
-     * @return array after flipping each input bit
-     */
-    public SDVariable toggleBits(String name, SDVariable x) {
-        SDVariable res = f().toggleBits(x);
-        return updateVariableNameAndReference(res, name);
-    }
+  public SDBitwise(SameDiff sameDiff) {
+    super(sameDiff);
+  }
+
+  /**
+   * Bitwise AND operation. Supports broadcasting.<br>
+   *
+   * Inputs must satisfy the following constraints: <br>
+   * Must be same types: isSameType(x, y)<br>
+   * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
+   *
+   * @param x First input array (INT type)
+   * @param y Second input array (INT type)
+   * @return output Bitwise AND array (INT type)
+   */
+  public SDVariable and(SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("and", "x", x);
+    SDValidation.validateInteger("and", "y", y);
+    Preconditions.checkArgument(isSameType(x, y), "Must be same types");
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable();
+  }
+
+  /**
+   * Bitwise AND operation. Supports broadcasting.<br>
+   *
+   * Inputs must satisfy the following constraints: <br>
+   * Must be same types: isSameType(x, y)<br>
+   * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x First input array (INT type)
+   * @param y Second input array (INT type)
+   * @return output Bitwise AND array (INT type)
+   */
+  public SDVariable and(String name, SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("and", "x", x);
+    SDValidation.validateInteger("and", "y", y);
+    Preconditions.checkArgument(isSameType(x, y), "Must be same types");
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br>
+   *
+   * @param x Input 1 (INT type)
+   * @param shift Number of bits to shift. (INT type)
+   * @return output SDVariable with shifted bits (INT type)
+   */
+  public SDVariable bitRotl(SDVariable x, SDVariable shift) {
+    SDValidation.validateInteger("bitRotl", "x", x);
+    SDValidation.validateInteger("bitRotl", "shift", shift);
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable();
+  }
+
+  /**
+   * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input 1 (INT type)
+   * @param shift Number of bits to shift. (INT type)
+   * @return output SDVariable with shifted bits (INT type)
+   */
+  public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) {
+    SDValidation.validateInteger("bitRotl", "x", x);
+    SDValidation.validateInteger("bitRotl", "shift", shift);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br>
+   *
+   * @param x Input 1 (INT type)
+   * @param shift Number of bits to shift. (INT type)
+   * @return output SDVariable with shifted bits (INT type)
+   */
+  public SDVariable bitRotr(SDVariable x, SDVariable shift) {
+    SDValidation.validateInteger("bitRotr", "x", x);
+    SDValidation.validateInteger("bitRotr", "shift", shift);
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable();
+  }
+
+  /**
+   * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input 1 (INT type)
+   * @param shift Number of bits to shift. (INT type)
+   * @return output SDVariable with shifted bits (INT type)
+   */
+  public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) {
+    SDValidation.validateInteger("bitRotr", "x", x);
+    SDValidation.validateInteger("bitRotr", "shift", shift);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Shift integer bits to the left, i.e. var << 4<br>
+   *
+   * @param x Input 1 (INT type)
+   * @param shift Number of bits to shift. (INT type)
+   * @return output SDVariable with shifted bits (INT type)
+   */
+  public SDVariable bitShift(SDVariable x, SDVariable shift) {
+    SDValidation.validateInteger("bitShift", "x", x);
+    SDValidation.validateInteger("bitShift", "shift", shift);
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable();
+  }
+
+  /**
+   * Shift integer bits to the left, i.e. var << 4<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input 1 (INT type)
+   * @param shift Number of bits to shift. (INT type)
+   * @return output SDVariable with shifted bits (INT type)
+   */
+  public SDVariable bitShift(String name, SDVariable x, SDVariable shift) {
+    SDValidation.validateInteger("bitShift", "x", x);
+    SDValidation.validateInteger("bitShift", "shift", shift);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Shift integer bits to the right, i.e. var >> 4<br>
+   *
+   * @param x Input 1 (INT type)
+   * @param shift Number of bits to shift. (INT type)
+   * @return output SDVariable with shifted bits (INT type)
+   */
+  public SDVariable bitShiftRight(SDVariable x, SDVariable shift) {
+    SDValidation.validateInteger("bitShiftRight", "x", x);
+    SDValidation.validateInteger("bitShiftRight", "shift", shift);
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable();
+  }
+
+  /**
+   * Shift integer bits to the right, i.e. var >> 4<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input 1 (INT type)
+   * @param shift Number of bits to shift. (INT type)
+   * @return output SDVariable with shifted bits (INT type)
+   */
+  public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) {
+    SDValidation.validateInteger("bitShiftRight", "x", x);
+    SDValidation.validateInteger("bitShiftRight", "shift", shift);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Bitwise Hamming distance reduction over all elements of both input arrays.<br>
+   * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br>
+   *
+   * Inputs must satisfy the following constraints: <br>
+   * Must be same types: isSameType(x, y)<br>
+   *
+   * @param x First input array. (INT type)
+   * @param y Second input array. (INT type)
+   * @return output bitwise Hamming distance (INT type)
+   */
+  public SDVariable bitsHammingDistance(SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("bitsHammingDistance", "x", x);
+    SDValidation.validateInteger("bitsHammingDistance", "y", y);
+    Preconditions.checkArgument(isSameType(x, y), "Must be same types");
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable();
+  }
+
+  /**
+   * Bitwise Hamming distance reduction over all elements of both input arrays.<br>
+   * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)<br>
+   *
+   * Inputs must satisfy the following constraints: <br>
+   * Must be same types: isSameType(x, y)<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x First input array. (INT type)
+   * @param y Second input array. (INT type)
+   * @return output bitwise Hamming distance (INT type)
+   */
+  public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("bitsHammingDistance", "x", x);
+    SDValidation.validateInteger("bitsHammingDistance", "y", y);
+    Preconditions.checkArgument(isSameType(x, y), "Must be same types");
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Bitwise left shift operation. Supports broadcasting.<br>
+   *
+   * @param x Input to be bit shifted (INT type)
+   * @param y Amount to shift elements of x array (INT type)
+   * @return output Bitwise shifted input x (INT type)
+   */
+  public SDVariable leftShift(SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("leftShift", "x", x);
+    SDValidation.validateInteger("leftShift", "y", y);
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable();
+  }
+
+  /**
+   * Bitwise left shift operation. Supports broadcasting.<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input to be bit shifted (INT type)
+   * @param y Amount to shift elements of x array (INT type)
+   * @return output Bitwise shifted input x (INT type)
+   */
+  public SDVariable leftShift(String name, SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("leftShift", "x", x);
+    SDValidation.validateInteger("leftShift", "y", y);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Bitwise left cyclical shift operation. Supports broadcasting.<br>
+   * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
+   * {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
+   *
+   * @param x Input to be bit shifted (INT type)
+   * @param y Amount to shift elements of x array (INT type)
+   * @return output Bitwise cyclic shifted input x (INT type)
+   */
+  public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("leftShiftCyclic", "x", x);
+    SDValidation.validateInteger("leftShiftCyclic", "y", y);
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable();
+  }
+
+  /**
+   * Bitwise left cyclical shift operation. Supports broadcasting.<br>
+   * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
+   * {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input to be bit shifted (INT type)
+   * @param y Amount to shift elements of x array (INT type)
+   * @return output Bitwise cyclic shifted input x (INT type)
+   */
+  public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("leftShiftCyclic", "x", x);
+    SDValidation.validateInteger("leftShiftCyclic", "y", y);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Bitwise OR operation. Supports broadcasting.<br>
+   *
+   * Inputs must satisfy the following constraints: <br>
+   * Must be same types: isSameType(x, y)<br>
+   * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
+   *
+   * @param x First input array (INT type)
+   * @param y First input array (INT type)
+   * @return output Bitwise OR array (INT type)
+   */
+  public SDVariable or(SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("or", "x", x);
+    SDValidation.validateInteger("or", "y", y);
+    Preconditions.checkArgument(isSameType(x, y), "Must be same types");
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable();
+  }
+
+  /**
+   * Bitwise OR operation. Supports broadcasting.<br>
+   *
+   * Inputs must satisfy the following constraints: <br>
+   * Must be same types: isSameType(x, y)<br>
+   * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x First input array (INT type)
+   * @param y First input array (INT type)
+   * @return output Bitwise OR array (INT type)
+   */
+  public SDVariable or(String name, SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("or", "x", x);
+    SDValidation.validateInteger("or", "y", y);
+    Preconditions.checkArgument(isSameType(x, y), "Must be same types");
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Bitwise right shift operation. Supports broadcasting. <br>
+   *
+   * @param x Input to be bit shifted (INT type)
+   * @param y Amount to shift elements of x array (INT type)
+   * @return output Bitwise shifted input x (INT type)
+   */
+  public SDVariable rightShift(SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("rightShift", "x", x);
+    SDValidation.validateInteger("rightShift", "y", y);
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable();
+  }
+
+  /**
+   * Bitwise right shift operation. Supports broadcasting. <br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input to be bit shifted (INT type)
+   * @param y Amount to shift elements of x array (INT type)
+   * @return output Bitwise shifted input x (INT type)
+   */
+  public SDVariable rightShift(String name, SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("rightShift", "x", x);
+    SDValidation.validateInteger("rightShift", "y", y);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Bitwise right cyclical shift operation. Supports broadcasting.<br>
+   * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
+   * {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
+   *
+   * @param x Input to be bit shifted (INT type)
+   * @param y Amount to shift elements of x array (INT type)
+   * @return output Bitwise cyclic shifted input x (INT type)
+   */
+  public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("rightShiftCyclic", "x", x);
+    SDValidation.validateInteger("rightShiftCyclic", "y", y);
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable();
+  }
+
+  /**
+   * Bitwise right cyclical shift operation. Supports broadcasting.<br>
+   * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
+   * {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x Input to be bit shifted (INT type)
+   * @param y Amount to shift elements of x array (INT type)
+   * @return output Bitwise cyclic shifted input x (INT type)
+   */
+  public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("rightShiftCyclic", "x", x);
+    SDValidation.validateInteger("rightShiftCyclic", "y", y);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
+
+  /**
+   * Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
+   *
+   * Inputs must satisfy the following constraints: <br>
+   * Must be same types: isSameType(x, y)<br>
+   * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
+   *
+   * @param x First input array (INT type)
+   * @param y First input array (INT type)
+   * @return output Bitwise XOR array (INT type)
+   */
+  public SDVariable xor(SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("xor", "x", x);
+    SDValidation.validateInteger("xor", "y", y);
+    Preconditions.checkArgument(isSameType(x, y), "Must be same types");
+    return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable();
+  }
+
+  /**
+   * Bitwise XOR operation (exclusive OR). Supports broadcasting.<br>
+   *
+   * Inputs must satisfy the following constraints: <br>
+   * Must be same types: isSameType(x, y)<br>
+   * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
+   *
+   * @param name name May be null. Name for the output variable
+   * @param x First input array (INT type)
+   * @param y First input array (INT type)
+   * @return output Bitwise XOR array (INT type)
+   */
+  public SDVariable xor(String name, SDVariable x, SDVariable y) {
+    SDValidation.validateInteger("xor", "x", x);
+    SDValidation.validateInteger("xor", "y", y);
+    Preconditions.checkArgument(isSameType(x, y), "Must be same types");
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable();
+    return sd.updateVariableNameAndReference(out, name);
+  }
 }
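The regenerated SDBitwise keeps the same public entry points but validates inputs and instantiates the op classes directly. A short usage sketch of the namespace; the variable values and class name are illustrative:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class BitwiseSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable x = sd.var("x", Nd4j.createFromArray(0b01110000));
            SDVariable y = sd.var("y", Nd4j.createFromArray(0b00001111));
            SDVariable anded = sd.bitwise().and("anded", x, y); // named variant
            SDVariable rolled = sd.bitwise().leftShiftCyclic(x, sd.var("n", Nd4j.createFromArray(2)));
        }
    }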


@ -1,185 +1,440 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops; package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull; import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
/**
* @author Alex Black
*/
public class SDImage extends SDOps { public class SDImage extends SDOps {
public SDImage(SameDiff sameDiff) { public SDImage(SameDiff sameDiff) {
super(sameDiff); super(sameDiff);
} }
/** /**
* Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size. * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
* *
* @param name May be null. Name for the output variable. * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
* @param image Input image, with shape [batch, height, width, channels] * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
* @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
* @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
* @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] * @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default
* @param method Image resize method * @return output Cropped and resized images (NUMERIC type)
* @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default */
* @return Cropped and resized images public SDVariable cropAndResize(SDVariable image, SDVariable cropBoxes, SDVariable boxIndices,
*/ SDVariable cropOutSize, double extrapolationValue) {
public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices, SDVariable cropOutSize, SDValidation.validateNumerical("CropAndResize", "image", image);
CropAndResize.Method method, double extrapolationValue) { SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
SDVariable out = new CropAndResize(sd, image, cropBoxes, boxIndices, cropOutSize, method, extrapolationValue).outputVariable(); SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
return updateVariableNameAndReference(out, name); SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
} return new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, extrapolationValue).outputVariable();
}
/** /**
* Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension. * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
* *
* @param name Map be null. Name for the output variable * @param name name May be null. Name for the output variable
* @param image Input image to extract image patches from - shape [batch, height, width, channels] * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
* @param kSizes Kernel size - size of the image patches, [height, width] * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
* @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
* @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
* in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken * @param extrapolationValue Used for extrapolation, when applicable. 0.0 should be used for the default
* along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension * @return output Cropped and resized images (NUMERIC type)
* @param sameMode Padding algorithm. If true: use Same padding */
* @return The extracted image patches public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes,
*/ SDVariable boxIndices, SDVariable cropOutSize, double extrapolationValue) {
public SDVariable extractImagePatches(String name, SDVariable image, @NonNull int[] kSizes, SDValidation.validateNumerical("CropAndResize", "image", image);
@NonNull int[] strides, @NonNull int[] rates, boolean sameMode) { SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
SDVariable out = new ExtractImagePatches(sd, image, kSizes, strides, rates, sameMode).outputVariable(); SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
return updateVariableNameAndReference(out, name); SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
} SDVariable out = new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, extrapolationValue).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/** /**
* Greedily selects a subset of bounding boxes in descending order of score * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
* @param name Might be null. Name for the output variable *
* @param boxes 2D array of shape [num_boxes,4] * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
* @param scores vector of shape [num_boxes] * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
* @param maxOutSize scalar representing the maximum number of boxes to be selected * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
* @param iouThreshold float - threshold for deciding whether boxes overlap too much with respect to IOU * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
* @param scoreThreshold float - threshold for deciding when to remove boxes based on score * @return output Cropped and resized images (NUMERIC type)
* @return vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size */
*/ public SDVariable cropAndResize(SDVariable image, SDVariable cropBoxes, SDVariable boxIndices,
public SDVariable nonMaxSuppression(String name, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize, SDVariable cropOutSize) {
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold){ SDValidation.validateNumerical("CropAndResize", "image", image);
SDVariable out = new NonMaxSuppression(sd, boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable(); SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
return updateVariableNameAndReference(out, name); SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
} SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
return new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, 0.0).outputVariable();
}
/** /**
* Adjusts contrast of RGB or grayscale images. * Given an input image and some crop boxes, extract out the image subsets and resize them to the specified size.<br>
* @param name name for the output variable *
* @param in images to adjust. 3D shape or higher. * @param name name May be null. Name for the output variable
* @param factor float multiplier for adjusting contrast. * @param image Input image, with shape [batch, height, width, channels] (NUMERIC type)
* @return Contrast-adjusted image * @param cropBoxes Float32 crop, shape [numBoxes, 4] with values in range 0 to 1 (NUMERIC type)
*/ * @param boxIndices Indices: which image (index to dimension 0) the cropBoxes belong to. Rank 1, shape [numBoxes] (NUMERIC type)
public SDVariable adjustContrast(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { * @param cropOutSize Output size for the images - int32, rank 1 with values [outHeight, outWidth] (INT type)
SDVariable out = new AdjustContrast(sd, in, factor).outputVariable(); * @return output Cropped and resized images (NUMERIC type)
return updateVariableNameAndReference(out, name); */
} public SDVariable cropAndResize(String name, SDVariable image, SDVariable cropBoxes,
SDVariable boxIndices, SDVariable cropOutSize) {
SDValidation.validateNumerical("CropAndResize", "image", image);
SDValidation.validateNumerical("CropAndResize", "cropBoxes", cropBoxes);
SDValidation.validateNumerical("CropAndResize", "boxIndices", boxIndices);
SDValidation.validateInteger("CropAndResize", "cropOutSize", cropOutSize);
SDVariable out = new org.nd4j.linalg.api.ops.impl.image.CropAndResize(sd,image, cropBoxes, boxIndices, cropOutSize, 0.0).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/** /**
* Adjust saturation of RGB images * Adjusts contrast of RGB or grayscale images.<br>
* @param name name for the output variable *
* @param in RGB image as 3D array * @param in images to adjust. 3D shape or higher (NUMERIC type)
* @param factor factor for saturation * @param factor multiplier for adjusting contrast
* @return adjusted image * @return output Contrast-adjusted image (NUMERIC type)
*/ */
public SDVariable adjustSaturation(String name, @NonNull SDVariable in, @NonNull SDVariable factor) { public SDVariable adjustContrast(SDVariable in, double factor) {
SDVariable out = new AdjustSaturation(sd, in, factor).outputVariable(); SDValidation.validateNumerical("adjustContrast", "in", in);
return updateVariableNameAndReference(out, name); return new org.nd4j.linalg.api.ops.custom.AdjustContrast(sd,in, factor).outputVariable();
} }
/** /**
* Adjust hue of RGB image * Adjusts contrast of RGB or grayscale images.<br>
* @param name name for the output variable *
* @param in RGB image as 3D array * @param name name May be null. Name for the output variable
* @param delta value to add to hue channel * @param in images to adjust. 3D shape or higher (NUMERIC type)
* @return adjusted image * @param factor multiplier for adjusting contrast
*/ * @return output Contrast-adjusted image (NUMERIC type)
public SDVariable adjustHue(String name, @NonNull SDVariable in, @NonNull SDVariable delta) { */
SDVariable out = new AdjustHue(sd, in, delta).outputVariable(); public SDVariable adjustContrast(String name, SDVariable in, double factor) {
return updateVariableNameAndReference(out, name); SDValidation.validateNumerical("adjustContrast", "in", in);
} SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustContrast(sd,in, factor).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/** /**
* Randomly crops image * Adjust hue of RGB image <br>
* @param name name for the output variable *
* @param input input array * @param in image as 3D array (NUMERIC type)
* @param shape shape for crop * @param delta value to add to hue channel
* @return cropped array * @return output adjusted image (NUMERIC type)
*/ */
public SDVariable randomCrop(String name, @NonNull SDVariable input, @NonNull SDVariable shape) { public SDVariable adjustHue(SDVariable in, double delta) {
SDVariable out = new RandomCrop(sd, input, shape).outputVariable(); SDValidation.validateNumerical("adjustHue", "in", in);
return updateVariableNameAndReference(out, name); return new org.nd4j.linalg.api.ops.custom.AdjustHue(sd,in, delta).outputVariable();
} }
/** /**
* Converting array from HSV to RGB format * Adjust hue of RGB image <br>
* @param name name *
* @param input 3D image * @param name name May be null. Name for the output variable
* @return 3D image * @param in image as 3D array (NUMERIC type)
*/ * @param delta value to add to hue channel
public SDVariable rgbToHsv(String name, @NonNull SDVariable input) { * @return output adjusted image (NUMERIC type)
SDVariable out = new RgbToHsv(sd, input).outputVariable(); */
return updateVariableNameAndReference(out, name); public SDVariable adjustHue(String name, SDVariable in, double delta) {
} SDValidation.validateNumerical("adjustHue", "in", in);
SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustHue(sd,in, delta).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/** /**
* Converting image from HSV to RGB format * Adjust saturation of RGB images<br>
* @param name name *
* @param input 3D image * @param in RGB image as 3D array (NUMERIC type)
* @return 3D image * @param factor factor for saturation
*/ * @return output adjusted image (NUMERIC type)
public SDVariable hsvToRgb(String name, @NonNull SDVariable input) { */
SDVariable out = new HsvToRgb(sd, input).outputVariable(); public SDVariable adjustSaturation(SDVariable in, double factor) {
return updateVariableNameAndReference(out, name); SDValidation.validateNumerical("adjustSaturation", "in", in);
} return new org.nd4j.linalg.api.ops.custom.AdjustSaturation(sd,in, factor).outputVariable();
}
/** /**
* Converting array from RGB to YIQ format * Adjust saturation of RGB images<br>
* @param name name *
* @param input 3D image * @param name name May be null. Name for the output variable
* @return 3D image * @param in RGB image as 3D array (NUMERIC type)
*/ * @param factor factor for saturation
public SDVariable rgbToYiq(String name, @NonNull SDVariable input) { * @return output adjusted image (NUMERIC type)
SDVariable out = new RgbToYiq(sd, input).outputVariable(); */
return updateVariableNameAndReference(out, name); public SDVariable adjustSaturation(String name, SDVariable in, double factor) {
} SDValidation.validateNumerical("adjustSaturation", "in", in);
SDVariable out = new org.nd4j.linalg.api.ops.custom.AdjustSaturation(sd,in, factor).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/** /**
* Converting image from YIQ to RGB format * Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension. <br>
* @param name name *
* @param input 3D image * @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type)
* @return 3D image * @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2))
*/ * @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2))
public SDVariable yiqToRgb(String name, @NonNull SDVariable input) { * @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels
SDVariable out = new YiqToRgb(sd, input).outputVariable(); * in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken
return updateVariableNameAndReference(out, name); * along the height/rows dimension, and every {@code b}th pixel is take along the width/columns dimension (Size: AtLeast(min=0))
} * @param sameMode Padding algorithm. If true: use Same padding
* @return output The extracted image patches (NUMERIC type)
*/
public SDVariable extractImagePatches(SDVariable image, int[] kSizes, int[] strides, int[] rates,
boolean sameMode) {
SDValidation.validateNumerical("extractImagePatches", "image", image);
Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length);
Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length);
Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length);
return new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,image, kSizes, strides, rates, sameMode).outputVariable();
}
/**
* Given an input image, extract out image patches (of size kSizes - h x w) and place them in the depth dimension. <br>
*
* @param name name May be null. Name for the output variable
* @param image Input image to extract image patches from - shape [batch, height, width, channels] (NUMERIC type)
* @param kSizes Kernel size - size of the image patches, [height, width] (Size: Exactly(count=2))
* @param strides Stride in the input dimension for extracting image patches, [stride_height, stride_width] (Size: Exactly(count=2))
* @param rates Usually [1,1]. Equivalent to dilation rate in dilated convolutions - how far apart the output pixels
* in the patches should be, in the input. A dilation of [a,b] means every {@code a}th pixel is taken
* along the height/rows dimension, and every {@code b}th pixel is taken along the width/columns dimension (Size: AtLeast(min=0))
* @param sameMode Padding algorithm. If true: use Same padding
* @return output The extracted image patches (NUMERIC type)
*/
public SDVariable extractImagePatches(String name, SDVariable image, int[] kSizes, int[] strides,
int[] rates, boolean sameMode) {
SDValidation.validateNumerical("extractImagePatches", "image", image);
Preconditions.checkArgument(kSizes.length == 2, "kSizes has incorrect size/length. Expected: kSizes.length == 2, got %s", kSizes.length);
Preconditions.checkArgument(strides.length == 2, "strides has incorrect size/length. Expected: strides.length == 2, got %s", strides.length);
Preconditions.checkArgument(rates.length >= 0, "rates has incorrect size/length. Expected: rates.length >= 0, got %s", rates.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches(sd,image, kSizes, strides, rates, sameMode).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
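// Illustrative usage sketch, assuming the sd.image() accessor and the
// placeholder names/shapes shown here (none of which are fixed by this API):
//   SameDiff sd = SameDiff.create();
//   SDVariable img = sd.placeHolder("img", DataType.FLOAT, 1, 28, 28, 3);
//   SDVariable patches = sd.image().extractImagePatches("patches", img,
//       new int[]{3, 3},  // 3x3 patches
//       new int[]{1, 1},  // stride 1 along height and width
//       new int[]{1, 1},  // rates [1,1]: no dilation
//       true);            // Same padding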
/**
* Converting image from HSV to RGB format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable hsvToRgb(SDVariable input) {
SDValidation.validateNumerical("hsvToRgb", "input", input);
return new org.nd4j.linalg.api.ops.custom.HsvToRgb(sd,input).outputVariable();
}
/**
* Converting image from HSV to RGB format <br>
*
* @param name name May be null. Name for the output variable
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable hsvToRgb(String name, SDVariable input) {
SDValidation.validateNumerical("hsvToRgb", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.custom.HsvToRgb(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Greedily selects a subset of bounding boxes in descending order of score<br>
*
* @param boxes Input boxes to select from, a 2D array of shape [num_boxes, 4] (NUMERIC type)
* @param scores vector of shape [num_boxes] (NUMERIC type)
* @param maxOutSize scalar representing the maximum number of boxes to be selected
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
* @param scoreThreshold threshold for deciding when to remove boxes based on score
* @return output vector of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type)
*/
public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize,
double iouThreshold, double scoreThreshold) {
SDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes);
SDValidation.validateNumerical("nonMaxSuppression", "scores", scores);
return new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(sd,boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
}
/**
* Greedily selects a subset of bounding boxes in descending order of score<br>
*
* @param name name May be null. Name for the output variable
* @param boxes Input boxes to select from, a 2D array of shape [num_boxes, 4] (NUMERIC type)
* @param scores vector of shape [num_boxes] (NUMERIC type)
* @param maxOutSize scalar representing the maximum number of boxes to be selected
* @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU
* @param scoreThreshold threshold for deciding when to remove boxes based on score
* @return output vector of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type)
*/
public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores,
int maxOutSize, double iouThreshold, double scoreThreshold) {
SDValidation.validateNumerical("nonMaxSuppression", "boxes", boxes);
SDValidation.validateNumerical("nonMaxSuppression", "scores", scores);
SDVariable out = new org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression(sd,boxes, scores, maxOutSize, iouThreshold, scoreThreshold).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
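// Illustrative usage sketch; the box/score values and variable names are
// assumptions:
//   SDVariable boxes = sd.var("boxes", Nd4j.rand(DataType.FLOAT, 10, 4));  // [num_boxes, 4]
//   SDVariable scores = sd.var("scores", Nd4j.rand(DataType.FLOAT, 10));   // [num_boxes]
//   SDVariable keep = sd.image().nonMaxSuppression("keep", boxes, scores,
//       5,     // select at most 5 boxes
//       0.5,   // IOU overlap threshold
//       0.0);  // score threshold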
/**
* Randomly crops image<br>
*
* @param input input array (NUMERIC type)
* @param shape shape for crop (INT type)
* @return output cropped array (NUMERIC type)
*/
public SDVariable randomCrop(SDVariable input, SDVariable shape) {
SDValidation.validateNumerical("randomCrop", "input", input);
SDValidation.validateInteger("randomCrop", "shape", shape);
return new org.nd4j.linalg.api.ops.custom.RandomCrop(sd,input, shape).outputVariable();
}
/**
* Randomly crops image<br>
*
* @param name name May be null. Name for the output variable
* @param input input array (NUMERIC type)
* @param shape shape for crop (INT type)
* @return output cropped array (NUMERIC type)
*/
public SDVariable randomCrop(String name, SDVariable input, SDVariable shape) {
SDValidation.validateNumerical("randomCrop", "input", input);
SDValidation.validateInteger("randomCrop", "shape", shape);
SDVariable out = new org.nd4j.linalg.api.ops.custom.RandomCrop(sd,input, shape).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Converting array from RGB to HSV format<br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable rgbToHsv(SDVariable input) {
SDValidation.validateNumerical("rgbToHsv", "input", input);
return new org.nd4j.linalg.api.ops.custom.RgbToHsv(sd,input).outputVariable();
}
/**
* Converting array from RGB to HSV format<br>
*
* @param name name May be null. Name for the output variable
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable rgbToHsv(String name, SDVariable input) {
SDValidation.validateNumerical("rgbToHsv", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToHsv(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Converting array from RGB to YIQ format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable rgbToYiq(SDVariable input) {
SDValidation.validateNumerical("rgbToYiq", "input", input);
return new org.nd4j.linalg.api.ops.custom.RgbToYiq(sd,input).outputVariable();
}
/**
* Converting array from RGB to YIQ format <br>
*
* @param name name May be null. Name for the output variable
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable rgbToYiq(String name, SDVariable input) {
SDValidation.validateNumerical("rgbToYiq", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToYiq(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Converting array from RGB to YUV format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable rgbToYuv(SDVariable input) {
SDValidation.validateNumerical("rgbToYuv", "input", input);
return new org.nd4j.linalg.api.ops.custom.RgbToYuv(sd,input).outputVariable();
}
/**
* Converting array from RGB to YUV format <br>
*
* @param name name May be null. Name for the output variable
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable rgbToYuv(String name, SDVariable input) {
SDValidation.validateNumerical("rgbToYuv", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.custom.RgbToYuv(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Converting image from YIQ to RGB format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable yiqToRgb(SDVariable input) {
SDValidation.validateNumerical("yiqToRgb", "input", input);
return new org.nd4j.linalg.api.ops.custom.YiqToRgb(sd,input).outputVariable();
}
/**
* Converting image from YIQ to RGB format <br>
*
* @param name name May be null. Name for the output variable
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable yiqToRgb(String name, SDVariable input) {
SDValidation.validateNumerical("yiqToRgb", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.custom.YiqToRgb(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Converting image from YUV to RGB format <br>
*
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable yuvToRgb(SDVariable input) {
SDValidation.validateNumerical("yuvToRgb", "input", input);
return new org.nd4j.linalg.api.ops.custom.YuvToRgb(sd,input).outputVariable();
}
/**
* Converting image from YUV to RGB format <br>
*
* @param name name May be null. Name for the output variable
* @param input 3D image (NUMERIC type)
* @return output 3D image (NUMERIC type)
*/
public SDVariable yuvToRgb(String name, SDVariable input) {
SDValidation.validateNumerical("yuvToRgb", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.custom.YuvToRgb(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
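// Illustrative usage sketch: the conversions above compose as inverse pairs.
// The [height, width, 3] input shape and names are assumptions:
//   SDVariable rgb = sd.var("rgb", Nd4j.rand(DataType.FLOAT, 8, 8, 3));
//   SDVariable hsv = sd.image().rgbToHsv("hsv", rgb);
//   SDVariable back = sd.image().hsvToRgb("back", hsv);  // approximately rgb again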
}
@@ -0,0 +1,561 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
public class SDLinalg extends SDOps {
public SDLinalg(SameDiff sameDiff) {
super(sameDiff);
}
/**
* Computes the Cholesky decomposition of one or more square matrices.<br>
*
* @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type)
* @return output Transformed tensor (NUMERIC type)
*/
public SDVariable cholesky(SDVariable input) {
SDValidation.validateNumerical("Cholesky", "input", input);
return new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(sd,input).outputVariable();
}
/**
* Computes the Cholesky decomposition of one or more square matrices.<br>
*
* @param name name May be null. Name for the output variable
* @param input Input tensor with inner-most 2 dimensions forming square matrices (NUMERIC type)
* @return output Transformed tensor (NUMERIC type)
*/
public SDVariable cholesky(String name, SDVariable input) {
SDValidation.validateNumerical("Cholesky", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Cholesky(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
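// Illustrative usage sketch, assuming the sd.linalg() accessor and a symmetric
// positive-definite input (both assumptions, not guaranteed by this snippet):
//   SameDiff sd = SameDiff.create();
//   SDVariable a = sd.var("a", Nd4j.create(new double[][]{{4, 2}, {2, 3}}));
//   SDVariable l = sd.linalg().cholesky("L", a);  // lower-triangular, a = l * l^T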
/**
* Solver for linear least squares problems.<br>
*
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @param l2_reguralizer regularizer
* @param fast fast mode, defaults to True
* @return output Transformed tensor (FLOATING_POINT type)
*/
public SDVariable lstsq(SDVariable matrix, SDVariable rhs, double l2_reguralizer, boolean fast) {
SDValidation.validateNumerical("Lstsq", "matrix", matrix);
SDValidation.validateNumerical("Lstsq", "rhs", rhs);
return new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, fast).outputVariable();
}
/**
* Solver for linear least squares problems.<br>
*
* @param name name May be null. Name for the output variable
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @param l2_reguralizer regularizer
* @param fast fast mode, defaults to True
* @return output Transformed tensor (FLOATING_POINT type)
*/
public SDVariable lstsq(String name, SDVariable matrix, SDVariable rhs, double l2_reguralizer,
boolean fast) {
SDValidation.validateNumerical("Lstsq", "matrix", matrix);
SDValidation.validateNumerical("Lstsq", "rhs", rhs);
SDVariable out = new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, fast).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Solver for linear least squares problems.<br>
*
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @param l2_reguralizer regularizer
* @return output Transformed tensor (FLOATING_POINT type)
*/
public SDVariable lstsq(SDVariable matrix, SDVariable rhs, double l2_reguralizer) {
SDValidation.validateNumerical("Lstsq", "matrix", matrix);
SDValidation.validateNumerical("Lstsq", "rhs", rhs);
return new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, true).outputVariable();
}
/**
* Solver for linear least squares problems.<br>
*
* @param name name May be null. Name for the output variable
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @param l2_reguralizer regularizer
* @return output Transformed tensor (FLOATING_POINT type)
*/
public SDVariable lstsq(String name, SDVariable matrix, SDVariable rhs, double l2_reguralizer) {
SDValidation.validateNumerical("Lstsq", "matrix", matrix);
SDValidation.validateNumerical("Lstsq", "rhs", rhs);
SDVariable out = new org.nd4j.linalg.api.ops.custom.Lstsq(sd,matrix, rhs, l2_reguralizer, true).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Computes LU decomposition.<br>
*
* @param input input tensor (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable lu(SDVariable input) {
SDValidation.validateNumerical("Lu", "input", input);
return new org.nd4j.linalg.api.ops.custom.Lu(sd,input).outputVariable();
}
/**
* Computes LU decomposition.<br>
*
* @param name name May be null. Name for the output variable
* @param input input tensor (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable lu(String name, SDVariable input) {
SDValidation.validateNumerical("Lu", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.custom.Lu(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Performs matrix multiplication on input tensors.<br>
*
* @param a input tensor (NUMERIC type)
* @param b input tensor (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable matmul(SDVariable a, SDVariable b) {
SDValidation.validateNumerical("Matmul", "a", a);
SDValidation.validateNumerical("Matmul", "b", b);
return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,a, b).outputVariable();
}
/**
* Performs matrix multiplication on input tensors.<br>
*
* @param name name May be null. Name for the output variable
* @param a input tensor (NUMERIC type)
* @param b input tensor (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable matmul(String name, SDVariable a, SDVariable b) {
SDValidation.validateNumerical("Matmul", "a", a);
SDValidation.validateNumerical("Matmul", "b", b);
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,a, b).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
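// Illustrative usage sketch; shapes and names are assumptions:
//   SDVariable a = sd.var("a", Nd4j.rand(DataType.FLOAT, 2, 3));
//   SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3, 4));
//   SDVariable c = sd.linalg().matmul("c", a, b);  // result shape [2, 4]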
/**
* Copy a tensor, setting everything outside a central band in each innermost matrix to zero.<br>
*
* @param input input tensor (NUMERIC type)
* @param minLower lower diagonal count
* @param maxUpper upper diagonal count
*/
public SDVariable[] matrixBandPart(SDVariable input, int minLower, int maxUpper) {
SDValidation.validateNumerical("MatrixBandPart", "input", input);
return new org.nd4j.linalg.api.ops.custom.MatrixBandPart(sd,input, minLower, maxUpper).outputVariables();
}
/**
* Copy a tensor, setting everything outside a central band in each innermost matrix to zero.<br>
*
* @param names names May be null. Arrays of names for the output variables.
* @param input input tensor (NUMERIC type)
* @param minLower lower diagonal count
* @param maxUpper upper diagonal count
*/
public SDVariable[] matrixBandPart(String[] names, SDVariable input, int minLower, int maxUpper) {
SDValidation.validateNumerical("MatrixBandPart", "input", input);
SDVariable[] out = new org.nd4j.linalg.api.ops.custom.MatrixBandPart(sd,input, minLower, maxUpper).outputVariables();
return sd.updateVariableNamesAndReferences(out, names);
}
/**
* Computes the QR decomposition of the input matrix.<br>
*
* @param input input tensor (NUMERIC type)
* @param full full matrices mode
*/
public SDVariable[] qr(SDVariable input, boolean full) {
SDValidation.validateNumerical("Qr", "input", input);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, full).outputVariables();
}
/**
* Computes the QR decomposition of the input matrix.<br>
*
* @param names names May be null. Arrays of names for the output variables.
* @param input input tensor (NUMERIC type)
* @param full full matrices mode
*/
public SDVariable[] qr(String[] names, SDVariable input, boolean full) {
SDValidation.validateNumerical("Qr", "input", input);
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, full).outputVariables();
return sd.updateVariableNamesAndReferences(out, names);
}
/**
* Computes the QR decomposition of the input matrix.<br>
*
* @param input input tensor (NUMERIC type)
*/
public SDVariable[] qr(SDVariable input) {
SDValidation.validateNumerical("Qr", "input", input);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, false).outputVariables();
}
/**
* Computes the QR decomposition of the input matrix.<br>
*
* @param names names May be null. Arrays of names for the output variables.
* @param input input tensor (NUMERIC type)
*/
public SDVariable[] qr(String[] names, SDVariable input) {
SDValidation.validateNumerical("Qr", "input", input);
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Qr(sd,input, false).outputVariables();
return sd.updateVariableNamesAndReferences(out, names);
}
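// Illustrative usage sketch: qr produces two outputs (Q and R), named via the
// String[] overload. The input shape and names are assumptions:
//   SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 4, 3));
//   SDVariable[] qr = sd.linalg().qr(new String[]{"q", "r"}, in);
//   // in ~= qr[0] * qr[1], with qr[1] upper-triangular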
/**
* Solver for systems of linear equations.<br>
*
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @param adjoint adjoint mode, defaults to False
* @return output Output tensor (FLOATING_POINT type)
*/
public SDVariable solve(SDVariable matrix, SDVariable rhs, boolean adjoint) {
SDValidation.validateNumerical("Solve", "matrix", matrix);
SDValidation.validateNumerical("Solve", "rhs", rhs);
return new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, adjoint).outputVariable();
}
/**
* Solver for systems of linear equations.<br>
*
* @param name name May be null. Name for the output variable
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @param adjoint adjoint mode, defaults to False
* @return output Output tensor (FLOATING_POINT type)
*/
public SDVariable solve(String name, SDVariable matrix, SDVariable rhs, boolean adjoint) {
SDValidation.validateNumerical("Solve", "matrix", matrix);
SDValidation.validateNumerical("Solve", "rhs", rhs);
SDVariable out = new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, adjoint).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Solver for systems of linear equations.<br>
*
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @return output Output tensor (FLOATING_POINT type)
*/
public SDVariable solve(SDVariable matrix, SDVariable rhs) {
SDValidation.validateNumerical("Solve", "matrix", matrix);
SDValidation.validateNumerical("Solve", "rhs", rhs);
return new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, false).outputVariable();
}
/**
* Solver for systems of linear equations.<br>
*
* @param name name May be null. Name for the output variable
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @return output Output tensor (FLOATING_POINT type)
*/
public SDVariable solve(String name, SDVariable matrix, SDVariable rhs) {
SDValidation.validateNumerical("Solve", "matrix", matrix);
SDValidation.validateNumerical("Solve", "rhs", rhs);
SDVariable out = new org.nd4j.linalg.api.ops.custom.LinearSolve(sd,matrix, rhs, false).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Solver for systems of linear equations.<br>
*
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @param lower defines whether innermost matrices in matrix are lower or upper triangular
* @param adjoint adjoint mode
* @return output (FLOATING_POINT type)
*/
public SDVariable triangularSolve(SDVariable matrix, SDVariable rhs, boolean lower,
boolean adjoint) {
SDValidation.validateNumerical("TriangularSolve", "matrix", matrix);
SDValidation.validateNumerical("TriangularSolve", "rhs", rhs);
return new org.nd4j.linalg.api.ops.custom.TriangularSolve(sd,matrix, rhs, lower, adjoint).outputVariable();
}
/**
* Solver for systems of linear equations.<br>
*
* @param name name May be null. Name for the output variable
* @param matrix input tensor (NUMERIC type)
* @param rhs input tensor (NUMERIC type)
* @param lower defines whether innermost matrices in matrix are lower or upper triangular
* @param adjoint adjoint mode
* @return output (FLOATING_POINT type)
*/
public SDVariable triangularSolve(String name, SDVariable matrix, SDVariable rhs, boolean lower,
boolean adjoint) {
SDValidation.validateNumerical("TriangularSolve", "matrix", matrix);
SDValidation.validateNumerical("TriangularSolve", "rhs", rhs);
SDVariable out = new org.nd4j.linalg.api.ops.custom.TriangularSolve(sd,matrix, rhs, lower, adjoint).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
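// Illustrative usage sketch for solve: finds x with matrix * x = rhs. The 2x2
// system below is an assumption chosen so that x ~= [[2], [3]]:
//   SDVariable a = sd.var("a", Nd4j.create(new double[][]{{3, 1}, {1, 2}}));
//   SDVariable b = sd.var("b", Nd4j.create(new double[][]{{9}, {8}}));
//   SDVariable x = sd.linalg().solve("x", a, b);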
/**
* Computes pairwise cross product.<br>
*
* @param a (NUMERIC type)
* @param b (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable cross(SDVariable a, SDVariable b) {
SDValidation.validateNumerical("cross", "a", a);
SDValidation.validateNumerical("cross", "b", b);
return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable();
}
/**
* Computes pairwise cross product.<br>
*
* @param name name May be null. Name for the output variable
* @param a (NUMERIC type)
* @param b (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable cross(String name, SDVariable a, SDVariable b) {
SDValidation.validateNumerical("cross", "a", a);
SDValidation.validateNumerical("cross", "b", b);
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Creates a diagonal tensor with the input values along the diagonal.<br>
*
* @param input (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable diag(SDVariable input) {
SDValidation.validateNumerical("diag", "input", input);
return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,input).outputVariable();
}
/**
* Creates a diagonal tensor with the input values along the diagonal.<br>
*
* @param name name May be null. Name for the output variable
* @param input (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable diag(String name, SDVariable input) {
SDValidation.validateNumerical("diag", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Extracts the diagonal part of the input tensor.<br>
*
* @param input (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable diag_part(SDVariable input) {
SDValidation.validateNumerical("diag_part", "input", input);
return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,input).outputVariable();
}
/**
* Extracts the diagonal part of the input tensor.<br>
*
* @param name name May be null. Name for the output variable
* @param input (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable diag_part(String name, SDVariable input) {
SDValidation.validateNumerical("diag_part", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Calculates log of determinant.<br>
*
* @param input (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable logdet(SDVariable input) {
SDValidation.validateNumerical("logdet", "input", input);
return new org.nd4j.linalg.api.ops.custom.Logdet(sd,input).outputVariable();
}
/**
* Calculates log of determinant.<br>
*
* @param name name May be null. Name for the output variable
* @param input (NUMERIC type)
* @return output (FLOATING_POINT type)
*/
public SDVariable logdet(String name, SDVariable input) {
SDValidation.validateNumerical("logdet", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.custom.Logdet(sd,input).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Matrix multiplication: out = mmul(x,y)<br>
* Supports specifying transpose arguments to perform operations such as mmul(a^T, b), etc.<br>
*
* @param x First input variable (NUMERIC type)
* @param y Second input variable (NUMERIC type)
* @param transposeX Transpose x (first argument)
* @param transposeY Transpose y (second argument)
* @param transposeZ Transpose result array
* @return output (NUMERIC type)
*/
public SDVariable mmul(SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
boolean transposeZ) {
SDValidation.validateNumerical("mmul", "x", x);
SDValidation.validateNumerical("mmul", "y", y);
return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable();
}
/**
* Matrix multiplication: out = mmul(x,y)<br>
* Supports specifying transpose arguments to perform operations such as mmul(a^T, b), etc.<br>
*
* @param name name May be null. Name for the output variable
* @param x First input variable (NUMERIC type)
* @param y Second input variable (NUMERIC type)
* @param transposeX Transpose x (first argument)
* @param transposeY Transpose y (second argument)
* @param transposeZ Transpose result array
* @return output (NUMERIC type)
*/
public SDVariable mmul(String name, SDVariable x, SDVariable y, boolean transposeX,
boolean transposeY, boolean transposeZ) {
SDValidation.validateNumerical("mmul", "x", x);
SDValidation.validateNumerical("mmul", "y", y);
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Matrix multiplication: out = mmul(x,y)<br>
* Supports specifying transpose arguments to perform operations such as mmul(a^T, b), etc.<br>
*
* @param x First input variable (NUMERIC type)
* @param y Second input variable (NUMERIC type)
* @return output (NUMERIC type)
*/
public SDVariable mmul(SDVariable x, SDVariable y) {
SDValidation.validateNumerical("mmul", "x", x);
SDValidation.validateNumerical("mmul", "y", y);
return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable();
}
/**
* Matrix multiplication: out = mmul(x,y)<br>
* Supports specifying transpose arguments to perform operations such as mmul(a^T, b), etc.<br>
*
* @param name name May be null. Name for the output variable
* @param x First input variable (NUMERIC type)
* @param y Second input variable (NUMERIC type)
* @return output (NUMERIC type)
*/
public SDVariable mmul(String name, SDVariable x, SDVariable y) {
SDValidation.validateNumerical("mmul", "x", x);
SDValidation.validateNumerical("mmul", "y", y);
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, false, false, false).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
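// Illustrative usage sketch: with transposeX=true, mmul computes x^T * y
// without materializing the transpose. Shapes and names are assumptions:
//   SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 2));
//   SDVariable y = sd.var("y", Nd4j.rand(DataType.FLOAT, 3, 4));
//   SDVariable z = sd.linalg().mmul("z", x, y, true, false, false);  // [2, 4]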
/**
* Calculates singular value decomposition.<br>
*
* @param input (NUMERIC type)
* @param fullUV If true, compute full-sized U and V matrices
* @param computeUV If true, compute the U and V matrices in addition to the singular values
* @param switchNum
* @return output (FLOATING_POINT type)
*/
public SDVariable svd(SDVariable input, boolean fullUV, boolean computeUV, int switchNum) {
SDValidation.validateNumerical("svd", "input", input);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, switchNum).outputVariable();
}
/**
* Calculates singular value decomposition.<br>
*
* @param name name May be null. Name for the output variable
* @param input (NUMERIC type)
* @param fullUV If true, compute full-sized U and V matrices
* @param computeUV If true, compute the U and V matrices in addition to the singular values
* @param switchNum
* @return output (FLOATING_POINT type)
*/
public SDVariable svd(String name, SDVariable input, boolean fullUV, boolean computeUV,
int switchNum) {
SDValidation.validateNumerical("svd", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, switchNum).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Calculates singular value decomposition.<br>
*
* @param input (NUMERIC type)
* @param fullUV If true, compute full-sized U and V matrices
* @param computeUV If true, compute the U and V matrices in addition to the singular values
* @return output (FLOATING_POINT type)
*/
public SDVariable svd(SDVariable input, boolean fullUV, boolean computeUV) {
SDValidation.validateNumerical("svd", "input", input);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable();
}
/**
* Calculates singular value decomposition.<br>
*
* @param name name May be null. Name for the output variable
* @param input (NUMERIC type)
* @param fullUV If true, compute full-sized U and V matrices
* @param computeUV If true, compute the U and V matrices in addition to the singular values
* @return output (FLOATING_POINT type)
*/
public SDVariable svd(String name, SDVariable input, boolean fullUV, boolean computeUV) {
SDValidation.validateNumerical("svd", "input", input);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
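// Illustrative usage sketch: this signature returns a single output variable,
// and the two-boolean overload above uses switchNum = 16. The input shape and
// names are assumptions:
//   SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 4, 3));
//   SDVariable s = sd.linalg().svd("s", in, false, false);  // singular values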
}
@@ -27,17 +27,21 @@ import org.nd4j.autodiff.samediff.SameDiff;
*/
public abstract class SDOps {
protected final SameDiff sd;

public SDOps() {
sd = null;
}

public SDOps(SameDiff sameDiff) {
this.sd = sameDiff;
}

protected DifferentialFunctionFactory f() {
return sd.f();
}

protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) {
return sd.updateVariableNameAndReference(varToUpdate, newVarName);
}
}
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at

@@ -14,198 +14,232 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;

import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.GRUCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs.LSTMCellOutputs;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;

public class SDRNN extends SDOps {
public SDRNN(SameDiff sameDiff) {
super(sameDiff);
}
/**
* The GRU cell. Does a single time step operation<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
* @param GRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public SDVariable gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
SDValidation.validateNumerical("gru", "x", x);
SDValidation.validateNumerical("gru", "hLast", hLast);
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariable();
}
/**
* The GRU cell. Does a single time step operation<br>
*
* @param name name May be null. Name for the output variable
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type)
* @param GRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public GRUCellOutputs gru(String name, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) {
SDValidation.validateNumerical("gru", "x", x);
SDValidation.validateNumerical("gru", "hLast", hLast);
GRUCell c = new GRUCell(sd,x, hLast, GRUWeights);
return new GRUCellOutputs(c.outputVariables(name));
}
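// Illustrative usage sketch; the GRUWeights builder wiring and the sizes
// (inSize=3, numUnits=2, batch=1) are assumptions:
//   SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 3);
//   SDVariable hLast = sd.placeHolder("hLast", DataType.FLOAT, 1, 2);
//   GRUWeights w = GRUWeights.builder()
//       .ruWeight(sd.var("ruW", Nd4j.rand(DataType.FLOAT, 5, 4)))  // [inSize+numUnits, 2*numUnits]
//       .cWeight(sd.var("cW", Nd4j.rand(DataType.FLOAT, 5, 2)))    // [inSize+numUnits, numUnits]
//       .ruBias(sd.var("ruB", Nd4j.rand(DataType.FLOAT, 4)))
//       .cBias(sd.var("cB", Nd4j.rand(DataType.FLOAT, 2)))
//       .build();
//   GRUCellOutputs out = sd.rnn().gru("gru", x, hLast, w);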
/**
* The LSTM cell. Does a single time step operation.<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
* @return output The cell's outputs (NUMERIC type)
*/
public LSTMCellOutputs lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast,
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
SDValidation.validateNumerical("lstmCell", "x", x);
SDValidation.validateNumerical("lstmCell", "cLast", cLast);
SDValidation.validateNumerical("lstmCell", "yLast", yLast);
LSTMBlockCell c = new LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration);
return new LSTMCellOutputs(c.outputVariables());
}
/**
* The LSTM cell. Does a single time step operation.<br>
*
* @param name name May be null. Name for the output variable
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param cLast Previous cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast Previous cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
* @return output The cell's outputs (NUMERIC type)
*/
public LSTMCellOutputs lstmCell(String name, SDVariable x, SDVariable cLast, SDVariable yLast,
LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
SDValidation.validateNumerical("lstmCell", "x", x);
SDValidation.validateNumerical("lstmCell", "cLast", cLast);
SDValidation.validateNumerical("lstmCell", "yLast", yLast);
LSTMBlockCell c = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell(sd,x, cLast, yLast, LSTMWeights, LSTMConfiguration);
return new LSTMCellOutputs(c.outputVariables(name));
}
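// Illustrative usage sketch; the LSTMWeights/LSTMConfiguration builder calls
// and sizes (inSize=3, numUnits=2, batch=1) are assumptions:
//   SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 3);
//   SDVariable cLast = sd.placeHolder("cLast", DataType.FLOAT, 1, 2);
//   SDVariable yLast = sd.placeHolder("yLast", DataType.FLOAT, 1, 2);
//   LSTMWeights w = LSTMWeights.builder()
//       .weights(sd.var("w", Nd4j.rand(DataType.FLOAT, 5, 8)))  // [inSize+numUnits, 4*numUnits]
//       .bias(sd.var("b", Nd4j.rand(DataType.FLOAT, 8)))
//       .inputPeepholeWeights(sd.var("wci", Nd4j.rand(DataType.FLOAT, 2)))
//       .forgetPeepholeWeights(sd.var("wcf", Nd4j.rand(DataType.FLOAT, 2)))
//       .outputPeepholeWeights(sd.var("wco", Nd4j.rand(DataType.FLOAT, 2)))
//       .build();
//   LSTMConfiguration conf = LSTMConfiguration.builder().peepHole(false).build();
//   LSTMCellOutputs out = sd.rnn().lstmCell("lstm", x, cLast, yLast, w, conf);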
/**
* The LSTM layer. Does multiple time steps.<br>
*
* @param maxTSLength (NUMERIC type)
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
* @return output The layer's outputs. (NUMERIC type)
*/
public SDVariable lstmLayer(SDVariable maxTSLength, SDVariable x, SDVariable cLast,
SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
SDValidation.validateNumerical("lstmLayer", "x", x);
SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
}
/**
* The LSTM layer. Does multiple time steps.<br>
*
* @param name name May be null. Name for the output variable
* @param maxTSLength (NUMERIC type)
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
* @param cLast Previous/initial cell state, with shape [batchSize, numUnits] (NUMERIC type)
* @param yLast Previous/initial cell output, with shape [batchSize, numUnits] (NUMERIC type)
* @param LSTMWeights Configuration Object
* @param LSTMConfiguration Configuration Object
* @return output The layer's outputs. (NUMERIC type)
*/
public SDVariable lstmLayer(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast,
SDVariable yLast, LSTMWeights LSTMWeights, LSTMConfiguration LSTMConfiguration) {
SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
SDValidation.validateNumerical("lstmLayer", "x", x);
SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer(sd,maxTSLength, x, cLast, yLast, LSTMWeights, LSTMConfiguration).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* The SRU layer. Does a single time step operation.<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
* @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type)
* @param SRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public SDVariable sru(SDVariable x, SDVariable initialC, SDVariable mask, SRUWeights SRUWeights) {
SDValidation.validateNumerical("sru", "x", x);
SDValidation.validateNumerical("sru", "initialC", initialC);
SDValidation.validateNumerical("sru", "mask", mask);
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, mask, SRUWeights).outputVariable();
}
/**
* The SRU layer. Does a single time step operation.<br>
*
* @param name name May be null. Name for the output variable
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
* @param mask An optional dropout mask, with shape [batchSize, inSize] (NUMERIC type)
* @param SRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public SDVariable sru(String name, SDVariable x, SDVariable initialC, SDVariable mask,
SRUWeights SRUWeights) {
SDValidation.validateNumerical("sru", "initialC", initialC);
SDValidation.validateNumerical("sru", "mask", mask);
SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, mask, SRUWeights).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* The SRU layer. Does a single time step operation.<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
* @param SRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public SDVariable sru(SDVariable x, SDVariable initialC, SRUWeights SRUWeights) {
SDValidation.validateNumerical("sru", "x", x);
SDValidation.validateNumerical("sru", "initialC", initialC);
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, null, SRUWeights).outputVariable();
}
/**
* The SRU layer. Does a single time step operation.<br>
*
* @param name name May be null. Name for the output variable
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param initialC Initial cell state, with shape [batchSize, inSize] (NUMERIC type)
* @param SRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public SDVariable sru(String name, SDVariable x, SDVariable initialC, SRUWeights SRUWeights) {
SDValidation.validateNumerical("sru", "x", x);
SDValidation.validateNumerical("sru", "initialC", initialC);
SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU(sd,x, initialC, null, SRUWeights).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* The SRU cell. Does a single time step operation.<br>
*
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type)
* @param SRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public SDVariable sruCell(SDVariable x, SDVariable cLast, SRUWeights SRUWeights) {
SDValidation.validateNumerical("sruCell", "x", x);
SDValidation.validateNumerical("sruCell", "cLast", cLast);
return new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(sd,x, cLast, SRUWeights).outputVariable();
}
/**
* The SRU cell. Does a single time step operation.<br>
*
* @param name name May be null. Name for the output variable
* @param x Input, with shape [batchSize, inSize] (NUMERIC type)
* @param cLast Previous cell state, with shape [batchSize, inSize] (NUMERIC type)
* @param SRUWeights Configuration Object
* @return output The cell's outputs. (NUMERIC type)
*/
public SDVariable sruCell(String name, SDVariable x, SDVariable cLast, SRUWeights SRUWeights) {
SDValidation.validateNumerical("sruCell", "x", x);
SDValidation.validateNumerical("sruCell", "cLast", cLast);
SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell(sd,x, cLast, SRUWeights).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
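// Illustrative usage sketch; the SRUWeights wiring and sizes (inSize=3,
// batch=1) are assumptions:
//   SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 3);
//   SDVariable cLast = sd.placeHolder("cLast", DataType.FLOAT, 1, 3);
//   SRUWeights w = SRUWeights.builder()
//       .weights(sd.var("sw", Nd4j.rand(DataType.FLOAT, 3, 9)))  // [inSize, 3*inSize]
//       .bias(sd.var("sb", Nd4j.rand(DataType.FLOAT, 6)))        // [2*inSize]
//       .build();
//   SDVariable out = sd.rnn().sruCell("sruCell", x, cLast, w);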
}
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at

@@ -14,324 +14,253 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;

import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;

import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
public class SDRandom extends SDOps {
public SDRandom(SameDiff sameDiff) {
super(sameDiff);
}

/**
* Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,<br>
* with the specified probability. Array values will have value 1 with probability P and value 0 with probability<br>
* 1-P.<br>
*
* @param p Probability of value 1
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a Bernoulli distribution (NUMERIC type)
*/
public SDVariable bernoulli(double p, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(sd,p, datatype, shape).outputVariable();
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a Bernoulli distribution,<br>
* with the specified probability. Array values will have value 1 with probability P and value 0 with probability<br>
* 1-P.<br>
*
* @param name name May be null. Name for the output variable
* @param p Probability of value 1
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a Bernoulli distribution (NUMERIC type)
*/
public SDVariable bernoulli(String name, double p, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution(sd,p, datatype, shape).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
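// Illustrative usage sketch: a [3, 4] array of 0s and 1s, each element 1 with
// probability 0.5. The sd.random() accessor and names are assumptions:
//   SameDiff sd = SameDiff.create();
//   SDVariable flips = sd.random().bernoulli("flips", 0.5, DataType.FLOAT, 3, 4);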
/**
* Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,<br>
* with the specified number of trials and probability.<br>
*
* @param nTrials Number of trials parameter for the binomial distribution
* @param p Probability of success for each trial
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a Binomial distribution (NUMERIC type)
*/
public SDVariable binomial(int nTrials, double p, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(sd,nTrials, p, datatype, shape).outputVariable();
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a Binomial distribution,<br>
* with the specified number of trials and probability.<br>
*
* @param name name May be null. Name for the output variable
* @param p Probability of success for each trial
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a Binomial distribution (NUMERIC type)
*/
public SDVariable binomial(String name, int nTrials, double p, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.BinomialDistribution(sd,nTrials, p, datatype, shape).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Generate a new random INDArray, where values are randomly sampled according to an exponential distribution:<br>
* P(x) = lambda * exp(-lambda * x)<br>
*
* Inputs must satisfy the following constraints: <br>
* Must be positive: lambda > 0<br>
*
* @param lambda lambda parameter
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to an exponential distribution (NUMERIC type)
*/
public SDVariable exponential(double lambda, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
Preconditions.checkArgument(lambda > 0, "Must be positive");
return new org.nd4j.linalg.api.ops.random.custom.RandomExponential(sd,lambda, datatype, shape).outputVariable();
}
/** /**
* Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, * Generate a new random INDArray, where values are randomly sampled according to a exponential distribution:<br>
* with the specified number of trials and probability. * P(x) = lambda * exp(-lambda * x)<br>
* *
* @param nTrials Number of trials parameter for the binomial distribution * Inputs must satisfy the following constraints: <br>
* @param p Probability of success for each trial * Must be positive: lambda > 0<br>
* @param shape Shape of the new random SDVariable, as a 1D array *
* @return New SDVariable * @param name name May be null. Name for the output variable
*/ * @param lambda lambda parameter
public SDVariable binomial(int nTrials, double p, long... shape) { * @param datatype Data type of the output variable
return binomial(null, nTrials, p, shape); * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
} * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public SDVariable exponential(String name, double lambda, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
Preconditions.checkArgument(lambda > 0, "Must be positive");
SDVariable out = new org.nd4j.linalg.api.ops.random.custom.RandomExponential(sd,lambda, datatype, shape).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/** /**
* Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution, * Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,<br>
* with the specified number of trials and probability. * i.e., {@code log(x) ~ N(mean, stdev)}<br>
* *
* @param name Name of the new SDVariable * @param mean Mean value for the random array
* @param nTrials Number of trials parameter for the binomial distribution * @param stddev Standard deviation for the random array
* @param p Probability of success for each trial * @param datatype Data type of the output variable
* @param shape Shape of the new random SDVariable, as a 1D array * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return New SDVariable * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/ */
public SDVariable binomial(String name, int nTrials, double p, long... shape) { public SDVariable logNormal(double mean, double stddev, DataType datatype, long... shape) {
SDVariable ret = f().randomBinomial(nTrials, p, shape); Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return updateVariableNameAndReference(ret, name); return new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable();
} }
/** /**
* Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution: * Generate a new random INDArray, where values are randomly sampled according to a Log Normal distribution,<br>
* P(x) = lambda * exp(-lambda * x) * i.e., {@code log(x) ~ N(mean, stdev)}<br>
* *
* @param lambda Must be > 0 * @param name name May be null. Name for the output variable
* @param shape Shape of the output * @param mean Mean value for the random array
* @return new SDVariable * @param stddev Standard deviation for the random array
*/ * @param datatype Data type of the output variable
public SDVariable exponential(double lambda, SDVariable shape) { * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
return exponential(null, lambda, shape); * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
} */
public SDVariable logNormal(String name, double mean, double stddev, DataType datatype,
long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/** /**
* Generate a new random SDVariable, where values are randomly sampled according to a exponential distribution: * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
* P(x) = lambda * exp(-lambda * x) * N(mean, stdev)<br>
* *
* @param name Name of the output variable * @param mean Mean value for the random array
* @param lambda Must be > 0 * @param stddev Standard deviation for the random array
* @param shape Shape of the new variable * @param datatype Data type of the output variable
* @return new SDVaribale * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
*/ * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
public SDVariable exponential(String name, double lambda, SDVariable shape) { */
validateInteger("exponential random", shape); public SDVariable normal(double mean, double stddev, DataType datatype, long... shape) {
SDVariable ret = f().randomExponential(lambda, shape); Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return updateVariableNameAndReference(ret, name); return new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(sd,mean, stddev, datatype, shape).outputVariable();
} }
/** /**
* @see #logNormal(String, double, double, long...) * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
*/ * N(mean, stdev)<br>
public SDVariable logNormal(double mean, double stddev, long... shape) { *
return logNormal(null, mean, stddev, shape); * @param name name May be null. Name for the output variable
} * @param mean Mean value for the random array
* @param stddev Standard deviation for the random array
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public SDVariable normal(String name, double mean, double stddev, DataType datatype,
long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.GaussianDistribution(sd,mean, stddev, datatype, shape).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/** /**
* Generate a new random SDVariable, where values are randomly sampled according to a Log Normal distribution, * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
* i.e., {@code log(x) ~ N(mean, stdev)}<br> * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled<br>
* *
* @param name Name of the new SDVariable * @param mean Mean value for the random array
* @param mean Mean value for the random array * @param stddev Standard deviation for the random array
* @param stddev Standard deviation for the random array * @param datatype Data type of the output variable
* @param shape Shape of the new random SDVariable * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return New SDVariable * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/ */
public SDVariable logNormal(String name, double mean, double stddev, long... shape) { public SDVariable normalTruncated(double mean, double stddev, DataType datatype, long... shape) {
SDVariable ret = f().randomLogNormal(mean, stddev, shape); Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
return updateVariableNameAndReference(ret, name); return new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable();
} }
/** /**
* @see #normal(String, double, double, SDVariable) * Generate a new random INDArray, where values are randomly sampled according to a Gaussian (normal) distribution,<br>
*/ * N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled<br>
public SDVariable normal(double mean, double stddev, SDVariable shape) { *
return normal(null, mean, stddev, shape); * @param name name May be null. Name for the output variable
} * @param mean Mean value for the random array
* @param stddev Standard deviation for the random array
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public SDVariable normalTruncated(String name, double mean, double stddev, DataType datatype,
long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution(sd,mean, stddev, datatype, shape).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
/** /**
* Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution, * Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,<br>
* N(mean, stdev)<br> * U(min,max)<br>
* See {@link #normal(String, double, double, long...)} for the equivalent function where the shape is *
* specified as a long[] instead * @param min Minimum value
* * @param max Maximum value.
* @param name Name of the new SDVariable * @param datatype Data type of the output variable
* @param mean Mean value for the random array * @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @param stddev Standard deviation for the random array * @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
* @param shape Shape of the new random SDVariable, as a 1D array */
* @return New SDVariable public SDVariable uniform(double min, double max, DataType datatype, long... shape) {
*/ Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
public SDVariable normal(String name, double mean, double stddev, SDVariable shape) { return new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(sd,min, max, datatype, shape).outputVariable();
validateInteger("normal (Gaussian) random", shape); }
SDVariable ret = f().randomNormal(mean, stddev, shape);
return updateVariableNameAndReference(ret, name);
}
/**
* @see #normal(String, double, double, long...)
*/
public SDVariable normal(double mean, double stddev, long... shape) {
return normal(null, mean, stddev, shape);
}
/**
* Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution,
* N(mean, stdev)<br>
* See {@link #normal(String, double, double, SDVariable)} for the equivalent function where the shape is
* specified as a long[] instead
*
* @param name Name of the new SDVariable
* @param mean Mean value for the random array
* @param stddev Standard deviation for the random array
* @param shape Shape of the new random SDVariable
* @return New SDVariable
*/
public SDVariable normal(String name, double mean, double stddev, long... shape) {
SDVariable ret = f().randomNormal(mean, stddev, shape);
return updateVariableNameAndReference(ret, name);
}
/**
* @see #normalTruncated(String, double, double, long...)
*/
public SDVariable normalTruncated(double mean, double stddev, long... shape) {
return normalTruncated(null, mean, stddev, shape);
}
/**
* Generate a new random SDVariable, where values are randomly sampled according to a Gaussian (normal) distribution,
* N(mean, stdev). However, any values more than 1 standard deviation from the mean are dropped and re-sampled<br>
*
* @param name Name of the new SDVariable
* @param mean Mean value for the random array
* @param stddev Standard deviation for the random array
* @param shape Shape of the new random SDVariable
* @return New SDVariable
*/
public SDVariable normalTruncated(String name, double mean, double stddev, long... shape) {
SDVariable ret = f().randomNormalTruncated(mean, stddev, shape);
return updateVariableNameAndReference(ret, name);
}
/**
* @see #uniform(String, double, double, SDVariable)
*/
public SDVariable uniform(double min, double max, SDVariable shape) {
return uniform(null, min, max, shape);
}
/**
* @see #uniform(String, double, double, SDVariable)
*/
public SDVariable uniform(double min, double max, SDVariable shape, DataType dataType) {
return uniform(null, min, max, shape, dataType);
}
/**
* As per {@link #uniform(double, double, SDVariable, DataType)} but with Float32 output
*/
public SDVariable uniform(String name, double min, double max, SDVariable shape) {
return uniform(name, min, max, shape, null);
}
/**
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
* U(min,max). Note that the output datatype may optionally be specified. If not specified (null) - float32 output is returned<br>
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
* specified as a long[] instead
*
* @param name Name of the new SDVariable
* @param min Minimum value
* @param max Maximum value. Must satisfy max >= min
* @param shape Shape of the new random SDVariable, as a 1D array
* @param dataType Data type of the output array (if null: Float32 output is returned)
* @return New SDVariable, of the specified data type
*/
public SDVariable uniform(String name, double min, double max, SDVariable shape, DataType dataType) {
validateInteger("uniform random", shape);
SDVariable ret = f().randomUniform(min, max, shape, dataType);
return updateVariableNameAndReference(ret, name);
}
/**
* @see #uniform(String, double, double, long...)
*/
public SDVariable uniform(double min, double max, long... shape) {
return uniform(null, min, max, shape);
}
/**
* Generate a new random SDVariable, where values are randomly sampled according to a uniform distribution,
* U(min,max)<br>
* See {@link #uniform(double, double, long...)} for the equivalent function where the shape is
* specified as a SDVariable instead
*
* @param name Name of the new SDVariable
* @param min Minimum value
* @param max Maximum value. Must satisfy max >= min
* @param shape Shape of the new random SDVariable
* @return New SDVariable
*/
public SDVariable uniform(String name, double min, double max, long... shape) {
SDVariable ret = f().randomUniform(min, max, shape);
return updateVariableNameAndReference(ret, name);
}
/**
* Generate a new random SDVariable with Gamma distribution
*
* @param name Name of the output variable
* @param alpha distribution parameter
* @param beta distribution parameter
* @param shape Shape of the new variable
* @return new SDVariable
*/
public SDVariable gamma(String name, SDVariable shape, SDVariable alpha, SDVariable beta) {
SDVariable ret = f().randomGamma(alpha, beta, shape);
return updateVariableNameAndReference(ret, name);
}
/**
* Generate a new random SDVariable with Poission distribution
*
* @param name Name of the output variable
* @param lambda rate distribution parameter
* @param shape Shape of the new variable
* @return new SDVariable
*/
public SDVariable poisson(String name, SDVariable lambda, SDVariable shape, int... seeds) {
SDVariable ret = f().randomPoisson(shape, lambda, seeds);
return updateVariableNameAndReference(ret, name);
}
/**
* Generate a new random SDVariable by random shuffle
*
* @param name Name of the output variable
* @param value array to shuffle
* @return new SDVariable
*/
public SDVariable shuffle(String name, SDVariable value, int... seeds) {
SDVariable ret = f().randomShuffle(value, seeds);
return updateVariableNameAndReference(ret, name);
}
/**
* Generate a new random INDArray, where values are randomly sampled according to a uniform distribution,<br>
* U(min,max)<br>
*
* @param name name May be null. Name for the output variable
* @param min Minimum value
* @param max Maximum value.
* @param datatype Data type of the output variable
* @param shape Shape of the new random INDArray, as a 1D array (Size: AtLeast(min=0))
* @return output Tensor with the given shape where values are randomly sampled according to a %OP_NAME% distribution (NUMERIC type)
*/
public SDVariable uniform(String name, double min, double max, DataType datatype, long... shape) {
Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length);
SDVariable out = new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(sd,min, max, datatype, shape).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
} }
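Note: a minimal sketch of calling the generated signatures above; the variable names and the 3x4 shape are illustrative, not part of this change.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class GeneratedRandomSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // The data type is now an explicit argument on every generated random factory method
            SDVariable u = sd.random().uniform("u", 0.0, 1.0, DataType.FLOAT, 3, 4);
            SDVariable n = sd.random().normal("n", 0.0, 1.0, DataType.FLOAT, 3, 4);
            System.out.println(u.eval());
            System.out.println(n.eval());
        }
    }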
@@ -55,6 +55,15 @@ public class SDValidation {
                 v.name() + "\" with non-integer data type " + v.dataType());
     }
 
+    protected static void validateNumerical(String opName, String inputName, SDVariable[] vars) {
+        for (SDVariable v : vars) {
+            if (v == null) continue;
+            if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
+                throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an numerical type type; got variable \"" +
+                        v.name() + "\" with non-integer data type " + v.dataType());
+        }
+    }
+
     /**
      * Validate that the operation is being applied on numerical SDVariables (not boolean or utf8).
      * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays
@@ -97,6 +106,16 @@
                 v.name() + "\" with non-integer data type " + v.dataType());
     }
 
+    protected static void validateInteger(String opName, String inputName, SDVariable[] vars) {
+        for (SDVariable v : vars) {
+            if (v == null)
+                return;
+            if (!v.dataType().isIntType())
+                throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" +
+                        v.name() + "\" with non-integer data type " + v.dataType());
+        }
+    }
+
     /**
      * Validate that the operation is being applied on an floating point type SDVariable
      *
@@ -200,4 +219,18 @@
         }
     }
 
+    public static boolean isSameType(SDVariable x, SDVariable y) {
+        return x.dataType() == y.dataType();
+    }
+
+    public static boolean isSameType(SDVariable[] x) {
+        DataType firstDataType = x[0].dataType();
+        if (x.length > 1) {
+            for (int i = 1; i < x.length; ++i) {
+                if (firstDataType != x[i].dataType())
+                    return false;
+            }
+        }
+        return true;
+    }
 }
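Note: the validate* overloads above are protected helpers for the generated ops; only isSameType is public. A small sketch, assuming SDValidation's usual org.nd4j.autodiff.samediff.ops package in this codebase:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.samediff.ops.SDValidation;
    import org.nd4j.linalg.api.buffer.DataType;

    public class SameTypeSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable a = sd.var("a", DataType.FLOAT, 2, 2);
            SDVariable b = sd.var("b", DataType.DOUBLE, 2, 2);
            // isSameType(SDVariable[]) compares every element against the first data type
            System.out.println(SDValidation.isSameType(new SDVariable[]{a, b})); // false
        }
    }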
@@ -16,7 +16,7 @@
 //================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
 
-package org.nd4j.linalg.factory.enums;
+package org.nd4j.enums;
 
 /**
  * Data format: "NCHW" or "NHWC" */
@@ -633,7 +633,9 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.custom.Lu.class,
             org.nd4j.linalg.api.ops.custom.TriangularSolve.class,
             org.nd4j.linalg.api.ops.custom.LinearSolve.class,
-            org.nd4j.linalg.api.ops.custom.Lstsq.class
+            org.nd4j.linalg.api.ops.custom.Lstsq.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.Qr.class,
+            org.nd4j.linalg.api.ops.custom.Logdet.class
     );
 
     static {
@@ -85,6 +85,12 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
         this(x, null, dimensions);
     }
 
+    public BaseIndexAccumulation(INDArray x, boolean keepDims, int[] dimensions) {
+        this(x, null, dimensions);
+        this.keepDims = keepDims;
+        defineDimensions(dimensions);
+    }
+
     public BaseIndexAccumulation(INDArray x, INDArray z, int[] dimensions) {
         super(x, z);
         defineDimensions(dimensions);
@@ -29,12 +29,17 @@ public class AdjustContrast extends BaseAdjustContrast {
         super(in, factor, out);
     }
 
+    public AdjustContrast(@NonNull INDArray in, double factor) {
+        this(in, factor, null);
+    }
+
     public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable factor) {
         super(sameDiff,new SDVariable[]{in,factor});
     }
 
-    public AdjustContrast(@NonNull INDArray in, double factor) {
-        this(in, factor, null);
+    public AdjustContrast(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) {
+        super(sameDiff,new SDVariable[]{in});
+        addTArgument(factor);
     }
 
     @Override
@@ -50,6 +50,11 @@ public class AdjustHue extends DynamicCustomOp {
         super(sameDiff,new SDVariable[]{in,factor});
     }
 
+    public AdjustHue(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) {
+        super(sameDiff,new SDVariable[]{in});
+        addTArgument(factor);
+    }
+
     @Override
     public String opName() {
         return "adjust_hue";
@@ -49,6 +49,11 @@ public class AdjustSaturation extends DynamicCustomOp {
         super(sameDiff, new SDVariable[]{in, factor});
     }
 
+    public AdjustSaturation(@NonNull SameDiff sameDiff, @NonNull SDVariable in, double factor) {
+        super(sameDiff, new SDVariable[]{in});
+        addTArgument(factor);
+    }
+
     @Override
     public String opName() {
         return "adjust_saturation";
@@ -0,0 +1,52 @@
+/*******************************************************************************
+ * Copyright (c) 2019-2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
+@NoArgsConstructor
+public class Logdet extends DynamicCustomOp {
+
+    public Logdet(INDArray input) {
+        addInputArgument(input);
+    }
+
+    public Logdet(SameDiff sameDiff, SDVariable input) {
+        super(sameDiff, new SDVariable[]{input});
+    }
+
+    @Override
+    public String opName() {
+        return "logdet";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        int n = args().length;
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
+}
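Note: a usage sketch for the new op; the input values are arbitrary but form a positive-definite matrix, so the log-determinant is well defined.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.custom.Logdet;
    import org.nd4j.linalg.factory.Nd4j;

    public class LogdetSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.create(new double[][]{{4, 12, -16}, {12, 37, -43}, {-16, -43, 98}});
            // log(det(in)); det of this matrix is 36, so the result is log(36)
            INDArray out = Nd4j.exec(new Logdet(in))[0];
            System.out.println(out);
        }
    }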
@@ -17,9 +17,17 @@ package org.nd4j.linalg.api.ops.custom;
 
 import lombok.NoArgsConstructor;
 import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
 @NoArgsConstructor
 public class Lstsq extends DynamicCustomOp {
@@ -33,8 +41,21 @@ public class Lstsq extends DynamicCustomOp {
         this(matrix, rhs, 0.0, true);
     }
 
+    public Lstsq(@NonNull SameDiff sameDiff, @NonNull SDVariable matrix, @NonNull SDVariable rhs, double l2_regularizer, boolean fast) {
+        super(sameDiff, new SDVariable[]{matrix,rhs});
+        addTArgument(l2_regularizer);
+        addBArgument(fast);
+    }
+
     @Override
     public String opName() {
         return "lstsq";
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        int n = args().length;
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
 }
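Note: a hedged eager-mode sketch, assuming the four-argument (matrix, rhs, l2_regularizer, fast) INDArray constructor implied by the delegation visible above.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.custom.Lstsq;
    import org.nd4j.linalg.factory.Nd4j;

    public class LstsqSketch {
        public static void main(String[] args) {
            INDArray a = Nd4j.rand(DataType.FLOAT, 4, 2); // overdetermined system
            INDArray b = Nd4j.rand(DataType.FLOAT, 4, 1);
            // l2_regularizer = 0.0 and fast = true mirror the defaults used by this class
            INDArray x = Nd4j.exec(new Lstsq(a, b, 0.0, true))[0];
            System.out.println(x); // least-squares solution, shape [2, 1]
        }
    }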
@@ -15,6 +15,7 @@
  ******************************************************************************/
 package org.nd4j.linalg.api.ops.custom;
 
+import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
@@ -26,10 +27,9 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import java.util.Collections;
 import java.util.List;
 
+@NoArgsConstructor
 public class MatrixBandPart extends DynamicCustomOp {
 
-    public MatrixBandPart() {}
-
     public MatrixBandPart(@NonNull INDArray input, int minLower, int maxUpper) {
         Preconditions.checkArgument(input.rank() >= 2, "MatrixBandPart: Input rank should be 2 or higher");
         long N = input.size(-2);
@@ -1,6 +1,5 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
- * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -37,7 +36,6 @@ import java.util.*;
  */
 @NoArgsConstructor
 public class CropAndResize extends DynamicCustomOp {
-
     public enum Method {BILINEAR, NEAREST};
     protected Method method = Method.BILINEAR;
     protected double extrapolationValue = 0.0;
@@ -50,6 +48,10 @@ public class CropAndResize extends DynamicCustomOp {
         addArgs();
     }
 
+    public CropAndResize(@NonNull SameDiff sameDiff, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices,
+                         SDVariable cropOutSize, double extrapolationValue) {
+        this(sameDiff, image, cropBoxes, boxIndices, cropOutSize, null, extrapolationValue);
+    }
+
     public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
                          @NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue,
@@ -65,12 +67,10 @@ public class CropAndResize extends DynamicCustomOp {
         outputArguments.add(output);
     }
 
-    public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
-                         @NonNull INDArray cropOutSize, double extrapolationValue) {
-        this(image, cropBoxes, boxIndices, cropOutSize, Method.BILINEAR, extrapolationValue, null);
+    public CropAndResize(INDArray image, INDArray cropBoxes, INDArray boxIndices, INDArray cropOutSize, double extrapolationValue ) {
+        this(image, cropBoxes, boxIndices, cropOutSize, null, extrapolationValue, null);
     }
 
     @Override
     public String opName() {
         return "crop_and_resize";
@@ -46,6 +46,12 @@ public class ExtractImagePatches extends DynamicCustomOp {
 
     public ExtractImagePatches(){ }
 
+    public ExtractImagePatches(@NonNull SameDiff samediff, @NonNull SDVariable input,
+                               int kH, int kW, int sH, int sW, int rH, int rW,
+                               boolean sameMode) {
+        this(samediff, input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH,rW}, sameMode);
+    }
+
     public ExtractImagePatches(@NonNull SameDiff samediff, @NonNull SDVariable input, @NonNull int[] kSizes,
                                @NonNull int[] strides, @NonNull int[] rates, boolean sameMode){
         super(samediff, input);
@@ -72,16 +78,8 @@ public class ExtractImagePatches extends DynamicCustomOp {
         addArgs();
     }
 
     public ExtractImagePatches(INDArray input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode) {
-        super(new INDArray[]{input},null);
-        int[] kSises = {kH,kW};
-        int[] strides = {sH,sW};
-        int[] rates = {rH, rW};
-        this.kSizes = kSises;
-        this.strides = strides;
-        this.rates = rates;
-        this.isSameMode = sameMode;
-        addArgs();
+        this(input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH, rW}, sameMode);
     }
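Note: a sketch of the scalar-argument constructor, which now simply delegates to the array-based one; the input shape and patch parameters are illustrative, and the class is assumed to live under org.nd4j.linalg.api.ops.impl.image.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
    import org.nd4j.linalg.factory.Nd4j;

    public class PatchesSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.rand(DataType.FLOAT, 1, 6, 6, 3); // NHWC input
            // 2x2 patches, stride 2, rate 1, VALID padding
            INDArray patches = Nd4j.exec(new ExtractImagePatches(in, 2, 2, 2, 2, 1, 1, false))[0];
            System.out.println(patches.shapeInfoToString());
        }
    }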
@@ -42,6 +42,13 @@ public class NonMaxSuppression extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false);
     }
 
+    public NonMaxSuppression(SameDiff sameDiff, SDVariable boxes, SDVariable scores, int maxOutSize,
+                             double iouThreshold, double scoreThreshold) {
+        super(null, sameDiff, new SDVariable[]{boxes, scores}, false);
+        addIArgument(maxOutSize);
+        addTArgument(iouThreshold, scoreThreshold);
+    }
+
     public NonMaxSuppression(INDArray boxes, INDArray scores, int maxOutSize, double iouThreshold, double scoreThreshold) {
         addInputArgument(boxes,scores);
         addIArgument(maxOutSize);
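Note: a sketch of the INDArray constructor above with toy boxes and scores; the values are illustrative, and the op returns the indices of the retained boxes.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
    import org.nd4j.linalg.factory.Nd4j;

    public class NmsSketch {
        public static void main(String[] args) {
            INDArray boxes = Nd4j.create(new float[][]{{0, 0, 1, 1}, {0, 0.1f, 1, 1.1f}, {0, 10, 1, 11}});
            INDArray scores = Nd4j.create(new float[]{0.9f, 0.75f, 0.6f});
            // Keep at most 2 boxes, IoU threshold 0.5, score threshold 0.0
            INDArray kept = Nd4j.exec(new NonMaxSuppression(boxes, scores, 2, 0.5, 0.0))[0];
            System.out.println(kept);
        }
    }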
@@ -54,10 +54,18 @@ public class FirstIndex extends BaseIndexAccumulation {
         this.extraArgs = new Object[] {compare, eps, (double) mode};
     }
 
+    public FirstIndex(SameDiff sameDiff, SDVariable i_v, boolean keepDims, Condition condition, int... dimensions) {
+        this(sameDiff, i_v, condition, keepDims, dimensions);
+    }
+
     public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) {
         this(x, condition, false, dimension);
     }
 
+    public FirstIndex(INDArray x, boolean keepDims, @NonNull Condition condition, int... dimension) {
+        this(x,condition,keepDims,dimension);
+    }
+
     public FirstIndex(INDArray x, @NonNull Condition condition, boolean keepDims, int... dimension) {
         this(x, condition, Nd4j.EPS_THRESHOLD, dimension);
         this.keepDims = keepDims;
@@ -72,7 +80,6 @@ public class FirstIndex extends BaseIndexAccumulation {
         this.extraArgs = new Object[] {compare, eps, (double) mode};
     }
 
-
     @Override
     public int opNum() {
         return 4;
@@ -45,6 +45,11 @@ public class IMax extends BaseIndexAccumulation {
         super(x, z, dimensions);
     }
 
+    public IMax(INDArray x, boolean keepDims, int... dimensions) {
+        super(x, keepDims, dimensions);
+    }
+
     public IMax(INDArray x, int... dimensions) {
         super(x, null, dimensions);
     }
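Note: the new constructor routes keepDims through the BaseIndexAccumulation constructor added earlier in this PR; a sketch with an illustrative shape and dimension.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
    import org.nd4j.linalg.factory.Nd4j;

    public class IMaxSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.rand(DataType.FLOAT, 3, 4);
            // argmax along dimension 1, keeping the reduced dimension: shape [3, 1] rather than [3]
            INDArray argmax = Nd4j.exec(new IMax(x, true, 1));
            System.out.println(argmax.shapeInfoToString());
        }
    }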
@@ -44,6 +44,10 @@ public class IMin extends BaseIndexAccumulation {
         super(x, dimensions);
     }
 
+    public IMin(INDArray x, boolean keepDims, int... dimensions) {
+        super(x, keepDims, dimensions);
+    }
+
     public IMin(INDArray x, INDArray z, int... dimensions) {
         super(x, z, dimensions);
     }
@@ -17,6 +17,7 @@
 package org.nd4j.linalg.api.ops.impl.indexaccum;
 
 import lombok.Data;
+import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
@@ -38,12 +39,16 @@ import java.util.Map;
  * @author raver119@gmail.com
 */
 @Data
+@NoArgsConstructor
 public class LastIndex extends BaseIndexAccumulation {
     protected Condition condition;
     protected double compare;
     protected double eps;
     protected int mode;
 
+    public LastIndex(SameDiff sameDiff, SDVariable i_v, boolean keepDims, Condition condition, int... dimensions) {
+        this(sameDiff, i_v, condition, keepDims, dimensions);
+    }
+
     public LastIndex(SameDiff sameDiff, SDVariable i_v, Condition condition, boolean keepDims, int... dimensions) {
         super(sameDiff, i_v, keepDims, dimensions);
         this.condition = condition;
@@ -53,13 +58,19 @@ public class LastIndex extends BaseIndexAccumulation {
         this.extraArgs = new Object[] {compare, eps, (double) mode};
     }
 
-    public LastIndex() {}
+    public LastIndex(SameDiff sameDiff, SDVariable x, @NonNull Condition condition, int... dimensions) {
+        super(sameDiff, x, false, dimensions);
+        this.condition = condition;
+    }
 
     public LastIndex(INDArray x, @NonNull Condition condition, int... dimensions) {
         this(x, condition, Nd4j.EPS_THRESHOLD, dimensions);
     }
 
+    public LastIndex(INDArray in, boolean keepDim, Condition condition, int... dimensions) {
+        this(in, condition, keepDim, dimensions);
+    }
+
     public LastIndex(INDArray x, @NonNull Condition condition, boolean keepDim, int... dimensions) {
         this(x, condition, Nd4j.EPS_THRESHOLD, dimensions);
         this.keepDims = keepDim;
@@ -47,10 +47,6 @@ public class AvgPooling3D extends Pooling3D {
         super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.AVG);
     }
 
-    public AvgPooling3D(SameDiff sameDiff,INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) {
-        super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.AVG);
-    }
-
     public AvgPooling3D(@NonNull INDArray input, Pooling3DConfig pooling3DConfig) {
         super(null,null,new INDArray[]{input},null,false, pooling3DConfig, Pooling3DType.AVG);
     }
@@ -76,6 +76,19 @@ public class BatchNorm extends DynamicCustomOp {
         addArgs();
     }
 
+    public BatchNorm(SameDiff sameDiff, SDVariable input, SDVariable mean, SDVariable variance,
+                     SDVariable gamma, SDVariable beta, double epsilon, int[] axis) {
+        super(null,sameDiff, wrapFilterNull(input, mean, variance, gamma, beta), false);
+        Preconditions.checkState(axis != null && axis.length > 0, "Invalid axis argument: axis must be specified" +
+                "and length > 0. Got %s", axis);
+        this.sameDiff = sameDiff;
+        this.applyBeta = beta != null;
+        this.applyGamma = gamma != null;
+        this.epsilon = epsilon;
+        this.jaxis = axis;
+        addArgs();
+    }
+
     public BatchNorm(INDArray input, INDArray mean, INDArray variance, INDArray gamma, INDArray beta, double epsilon, int... axis){
         super(wrapFilterNull(input, mean, variance, gamma, beta), null);
         this.jaxis = axis;
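Note: a hedged sketch of the new array-based SameDiff constructor; direct op construction is shown only for illustration, since normal user code reaches batch norm through the higher-level SameDiff APIs.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;

    public class BatchNormSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3, 8, 8); // NCHW
            SDVariable mean = sd.var("mean", DataType.FLOAT, 3);
            SDVariable var = sd.var("var", DataType.FLOAT, 3);
            SDVariable gamma = sd.var("gamma", DataType.FLOAT, 3);
            SDVariable beta = sd.var("beta", DataType.FLOAT, 3);
            // epsilon plus an explicit int[] axis (channel axis 1 for NCHW)
            SDVariable out = new BatchNorm(sd, in, mean, var, gamma, beta, 1e-5, new int[]{1}).outputVariable();
        }
    }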
@@ -46,6 +46,10 @@ public class Conv1D extends DynamicCustomOp {
     protected Conv1DConfig config;
     private static final String INVALID_CONFIGURATION = "Invalid Conv1D configuration : s = %s p = %s ";
 
+    public Conv1D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv1DConfig conv1DConfig) {
+        this(sameDiff, wrapFilterNull(input, weights, bias), conv1DConfig);
+    }
+
     @Builder(builderMethodName = "sameDiffBuilder")
     public Conv1D(SameDiff sameDiff,
                   SDVariable[] inputFunctions,
@@ -64,12 +68,8 @@ public class Conv1D extends DynamicCustomOp {
         this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
     }
 
-    public Conv1D( @NonNull INDArray input, @NonNull INDArray weights, INDArray bias, Conv1DConfig conv1DConfig) {
-        this(wrapFilterNull(input, weights, bias), null, conv1DConfig);
-    }
-
-    public Conv1D(@NonNull INDArray input, @NonNull INDArray weights, Conv1DConfig conv1DConfig) {
-        this(new INDArray[]{input, weights}, null, conv1DConfig);
+    public Conv1D(INDArray input, INDArray weights, INDArray bias, Conv1DConfig config) {
+        this(input, weights, bias, null, config);
     }
 
     private void initConfig(Conv1DConfig config){
@@ -56,6 +56,11 @@ public class Conv2D extends DynamicCustomOp {
     protected Conv2DConfig config;
     private static final String INVALID_CONFIGURATION = "Invalid Conv2D configuration : sW = %s pH = %s dW = %s ";
 
+    public Conv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights,
+                  SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
+        this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig);
+    }
+
     @Builder(builderMethodName = "sameDiffBuilder")
     public Conv2D(SameDiff sameDiff,
                   SDVariable[] inputFunctions,
@@ -75,12 +80,8 @@ public class Conv2D extends DynamicCustomOp {
         this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
     }
 
-    public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, @NonNull Conv2DConfig conv2DConfig) {
-        this(new INDArray[]{layerInput, weights}, null, conv2DConfig);
-    }
-
-    public Conv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, INDArray bias, @NonNull Conv2DConfig conv2DConfig) {
-        this(wrapFilterNull(layerInput, weights,bias), null, conv2DConfig);
+    public Conv2D(INDArray layerInput, INDArray weights, INDArray bias, Conv2DConfig config) {
+        this(layerInput, weights, bias, null, config);
     }
 
     protected void initConfig(Conv2DConfig config){
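Note: a sketch of the canonical (input, weights, bias, config) constructor with a null bias; NCHW input and [kH, kW, inC, outC] weight layout are assumptions of this example, not something this diff specifies.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
    import org.nd4j.linalg.factory.Nd4j;

    public class Conv2DSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.rand(DataType.FLOAT, 1, 3, 8, 8);
            INDArray w = Nd4j.rand(DataType.FLOAT, 2, 2, 3, 4);
            Conv2DConfig cfg = Conv2DConfig.builder()
                    .kH(2).kW(2).sH(1).sW(1).pH(0).pW(0).dH(1).dW(1).build();
            INDArray out = Nd4j.exec(new Conv2D(in, w, null, cfg))[0]; // bias may be null
            System.out.println(out.shapeInfoToString());
        }
    }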
@@ -55,6 +55,11 @@ public class Conv3D extends DynamicCustomOp {
     public Conv3D() {
     }
 
+    public Conv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights,
+                  SDVariable bias, @NonNull Conv3DConfig config) {
+        this(sameDiff, wrapFilterNull(input, weights, bias), config);
+    }
+
     @Builder(builderMethodName = "sameDiffBuilder")
     public Conv3D(SameDiff sameDiff, SDVariable[] inputFunctions, Conv3DConfig config) {
         super(sameDiff, inputFunctions);
@@ -70,12 +75,12 @@ public class Conv3D extends DynamicCustomOp {
         this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
     }
 
-    public Conv3D(@NonNull INDArray input,@NonNull INDArray weights, @NonNull Conv3DConfig conv3DConfig) {
-        this(new INDArray[]{input, weights}, null, conv3DConfig);
+    public Conv3D(INDArray input, INDArray weights, INDArray bias, Conv3DConfig config) {
+        this(wrapFilterNull(input, weights, bias), null, config);
     }
 
-    public Conv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull Conv3DConfig conv3DConfig) {
-        this(wrapFilterNull(input, weights, bias) , null, conv3DConfig);
+    public Conv3D(INDArray input, INDArray weights, Conv3DConfig config) {
+        this(wrapFilterNull(input, weights), null, config);
     }
 
     private void initConfig(Conv3DConfig config){
@@ -52,6 +52,11 @@ public class DeConv2D extends DynamicCustomOp {
 
     protected DeConv2DConfig config;
 
+    public DeConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights,
+                    SDVariable bias, DeConv2DConfig config) {
+        this(sameDiff, wrapFilterNull(input, weights, bias), config);
+    }
+
     @Builder(builderMethodName = "sameDiffBuilder")
     public DeConv2D(SameDiff sameDiff,
                     SDVariable[] inputs,
@@ -73,15 +78,10 @@ public class DeConv2D extends DynamicCustomOp {
         this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
     }
 
-    public DeConv2D(@NonNull INDArray layerInput, @NonNull INDArray weights, DeConv2DConfig deConv2DConfig) {
-        this(wrapFilterNull(layerInput, weights), null, deConv2DConfig);
-    }
-
-    public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig deConv2DConfig) {
-        this(wrapFilterNull(layerInput, weights, bias), null, deConv2DConfig);
+    public DeConv2D(INDArray layerInput, INDArray weights, INDArray bias, DeConv2DConfig config) {
+        this(layerInput, weights, bias, null, config);
     }
 
     @Override
     public long[] iArgs() {
         if (iArguments.size() == 0)
@@ -48,12 +48,18 @@ public class DeConv3D extends DynamicCustomOp {
 
     protected DeConv3DConfig config;
 
-    public DeConv3D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
+    public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull DeConv3DConfig config) {
         super(sameDiff, toArr(input, weights, bias));
         this.config = config;
         addArgs();
     }
 
+    public DeConv3D(SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull DeConv3DConfig config) {
+        super(sameDiff, toArr(input, weights, null));
+        this.config = config;
+        addArgs();
+    }
+
     public DeConv3D(INDArray[] inputs, INDArray[] outputs, DeConv3DConfig config){
         super(inputs, outputs);
@@ -65,12 +71,8 @@ public class DeConv3D extends DynamicCustomOp {
         this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
     }
 
-    public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, @NonNull DeConv3DConfig deConv3DConfig) {
-        this(new INDArray[]{input, weights}, null, deConv3DConfig);
-    }
-
-    public DeConv3D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, @NonNull DeConv3DConfig deConv3DConfig) {
-        this(wrapFilterNull(input, weights, bias), null, deConv3DConfig);
+    public DeConv3D(INDArray input, INDArray weights, INDArray bias, DeConv3DConfig config) {
+        this(input, weights, bias, null, config);
     }
 
     private static SDVariable[] toArr(SDVariable input, SDVariable weights, SDVariable bias){
@@ -16,16 +16,15 @@
 
 package org.nd4j.linalg.api.ops.impl.layers.convolution;
 
-import lombok.NonNull;
 import lombok.val;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.enums.DataFormat;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.factory.enums.DataFormat;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;
@@ -46,45 +45,48 @@ import java.util.*;
  * @author raver119@gmail.com, Max Pumperla
 */
 public class DepthToSpace extends DynamicCustomOp {
-    private String dataFormat = "NHWC";
+    private DataFormat dataFormat = DataFormat.NHWC;
     private int blockSize;
 
     public DepthToSpace() {
     }
 
-    public DepthToSpace(SameDiff sameDiff, SDVariable[] args, int blockSize, String dataFormat) {
+    public DepthToSpace(SameDiff sameDiff, SDVariable args, int blockSize, DataFormat dataFormat) {
+        this(sameDiff, new SDVariable[]{args}, blockSize, dataFormat);
+    }
+
+    public DepthToSpace(SameDiff sameDiff, SDVariable[] args, int blockSize, DataFormat dataFormat) {
         super(null, sameDiff, args, false);
         this.blockSize = blockSize;
         this.dataFormat = dataFormat;
-        boolean isNHWC = dataFormat.equals("NHWC");
+        boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }
 
-    public DepthToSpace(@NonNull INDArray in, INDArray out, int blockSize, @NonNull String dataFormat) {
+    public DepthToSpace(INDArray in, INDArray out, int blockSize, DataFormat dataFormat) {
         super(null, in, out, null, null);
         this.blockSize = blockSize;
         this.dataFormat = dataFormat;
-        boolean isNHWC = dataFormat.equals("NHWC");
+        boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }
 
-    public DepthToSpace(@NonNull INDArray x, int blockSize, DataFormat dataFormat) {
-        this(x, null, blockSize, dataFormat.toString());
+    public DepthToSpace(INDArray in, int blockSize, DataFormat dataFormat) {
+        this(in, null, blockSize, dataFormat);
    }
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         // Gradient to DepthToSpace is just SpaceToDepth of same block size and data format.
         SDVariable gradient = i_v.get(0);
-        SDVariable ret = sameDiff.cnn().spaceToDepth(gradient, blockSize, dataFormat);
+        SDVariable ret = new SpaceToDepth(sameDiff, new SDVariable[]{gradient}, blockSize, dataFormat).outputVariable();
         return Arrays.asList(ret);
     }
 
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
         TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
-        boolean isNHWC = dataFormat.equals("NHWC");
+        boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }
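Note: a sketch using the DataFormat enum from its new org.nd4j.enums package; the input values are arbitrary.

    import org.nd4j.enums.DataFormat;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace;
    import org.nd4j.linalg.factory.Nd4j;

    public class DepthToSpaceSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.rand(DataType.FLOAT, 1, 2, 2, 4); // NHWC, depth 4
            // Block size 2: depth 4 -> 1, spatial 2x2 -> 4x4
            INDArray out = Nd4j.exec(new DepthToSpace(in, 2, DataFormat.NHWC))[0];
            System.out.println(out.shapeInfoToString()); // expect shape [1, 4, 4, 1]
        }
    }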
@@ -16,8 +16,11 @@
 
 package org.nd4j.linalg.api.ops.impl.layers.convolution;
 
-import lombok.*;
+import lombok.Builder;
+import lombok.Getter;
+import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
+import lombok.val;
 import onnx.Onnx;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
@@ -49,11 +52,15 @@ import java.util.*;
 */
 @Slf4j
 @Getter
-@NoArgsConstructor
 public class DepthwiseConv2D extends DynamicCustomOp {
 
     protected Conv2DConfig config;
 
+    public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input,
+                           @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
+        this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig);
+    }
+
     @Builder(builderMethodName = "sameDiffBuilder")
     public DepthwiseConv2D(SameDiff sameDiff,
                            SDVariable[] inputFunctions,
@@ -75,16 +82,11 @@ public class DepthwiseConv2D extends DynamicCustomOp {
         this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config);
     }
 
-    public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, Conv2DConfig conv2DConfig) {
-        this(wrapFilterNull(layerInput, depthWeights), null, conv2DConfig);
-    }
-
-    public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig conv2DConfig) {
-        this(wrapFilterNull(layerInput, depthWeights, bias), null, conv2DConfig);
-    }
-
-    public DepthwiseConv2D(INDArray inputs, Conv2DConfig conv2DConfig) {
-        this(wrapFilterNull(inputs), null, conv2DConfig);
+    public DepthwiseConv2D(INDArray layerInput, INDArray depthWeights, INDArray bias, Conv2DConfig config) {
+        this(layerInput, depthWeights, bias, null, config);
+    }
+
+    public DepthwiseConv2D() {
     }
 
     @Override
@@ -58,6 +58,10 @@ public class LocalResponseNormalization extends DynamicCustomOp {
         addArgs();
     }
 
+    public LocalResponseNormalization(SameDiff sameDiff, SDVariable input, LocalResponseNormalizationConfig config) {
+        this(sameDiff, new SDVariable[]{input}, false, config);
+    }
+
     public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config){
         super(new INDArray[]{input}, wrapOrNull(output));
@@ -60,15 +60,16 @@ public class MaxPooling2D extends DynamicCustomOp {
         addArgs();
     }
 
-    public MaxPooling2D(@NonNull INDArray input, INDArray output, @NonNull Pooling2DConfig config){
+    public MaxPooling2D(INDArray input, INDArray output, @NonNull Pooling2DConfig config){
         super(null, new INDArray[]{input}, wrapOrNull(output));
         config.setType(Pooling2D.Pooling2DType.MAX);
         this.config = config;
         addArgs();
     }
 
-    public MaxPooling2D(@NonNull INDArray input, @NonNull Pooling2DConfig pooling2DConfig) {
-        this(input, null, pooling2DConfig);
+    public MaxPooling2D(INDArray input, @NonNull Pooling2DConfig config){
+        this(input, null, config);
     }
 
     @Override
@@ -47,8 +47,12 @@ public class MaxPooling3D extends Pooling3D {
         super(sameDiff, new SDVariable[]{input}, null, null, false, config, Pooling3DType.MAX);
     }
 
-    public MaxPooling3D(SameDiff sameDiff, INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) {
-        super(sameDiff, null, new INDArray[]{arrayInput}, wrapOrNull(arrayOutput), false, config, Pooling3DType.MAX);
+    public MaxPooling3D(INDArray arrayInput, INDArray arrayOutput, Pooling3DConfig config) {
+        addInputArgument(arrayInput);
+        if (arrayOutput != null)
+            addOutputArgument(arrayOutput);
+        this.config = config;
+        addArgs();
     }
 
     public MaxPooling3D(INDArray input, Pooling3DConfig pooling3DConfig) {


@@ -44,18 +44,23 @@ public class SConv2D extends Conv2D {
         super(sameDiff, inputFunctions, conv2DConfig);
     }
 
+    public SConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable layerInput, @NonNull SDVariable depthWeights,
+                   @NonNull SDVariable pointWeights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) {
+        this(sameDiff, wrapFilterNull(layerInput, depthWeights, pointWeights, bias), conv2DConfig);
+    }
+
     public SConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){
         super(inputs, outputs, config);
     }
 
+    public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, INDArray bias, @NonNull Conv2DConfig Conv2DConfig){
+        this(wrapFilterNull(layerInput, depthWeights, pointWeights, bias), null, Conv2DConfig);
+    }
+
     public SConv2D(@NonNull INDArray layerInput, @NonNull INDArray depthWeights, INDArray pointWeights, @NonNull Conv2DConfig Conv2DConfig){
         this(wrapFilterNull(layerInput, depthWeights, pointWeights), null, Conv2DConfig);
     }
 
-    public SConv2D(INDArray layerInput, INDArray depthWeights, INDArray pointWeights, INDArray bias, Conv2DConfig config) {
-        this(wrapFilterNull(layerInput, depthWeights, pointWeights, bias), null, config);
-    }
-
     public SConv2D() {}


@@ -16,16 +16,15 @@
 package org.nd4j.linalg.api.ops.impl.layers.convolution;
 
-import lombok.NonNull;
 import lombok.val;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.enums.DataFormat;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.factory.enums.DataFormat;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;
@@ -45,47 +44,48 @@ import java.util.*;
  * @author raver119@gmail.com, Max Pumperla
  */
 public class SpaceToDepth extends DynamicCustomOp {
-    private String dataFormat;
+    private DataFormat dataFormat;
     private int blockSize;
 
     public SpaceToDepth() {
     }
 
-    public SpaceToDepth(SameDiff sameDiff, SDVariable[] args, int blockSize, String dataFormat) {
+    public SpaceToDepth(SameDiff sameDiff, SDVariable[] args, int blockSize, DataFormat dataFormat) {
         super(null, sameDiff, args, false);
         this.blockSize = blockSize;
         this.dataFormat = dataFormat;
-        boolean isNHWC = dataFormat.equals("NHWC");
+        boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }
 
-    public SpaceToDepth(INDArray in, INDArray out, int blockSize, String dataFormat){
+    public SpaceToDepth(SameDiff sameDiff, SDVariable x, int blockSize, DataFormat dataFormat) {
+        this(sameDiff, new SDVariable[]{x}, blockSize, dataFormat);
+    }
+
+    public SpaceToDepth(INDArray in, INDArray out, int blockSize, DataFormat dataFormat){
         super(null, in, out, null, null);
         this.blockSize = blockSize;
         this.dataFormat = dataFormat;
-        boolean isNHWC = dataFormat.equals("NHWC");
+        boolean isNHWC = dataFormat.equals(DataFormat.NHWC);
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }
 
-    public SpaceToDepth(@NonNull INDArray x, int blockSize, @NonNull DataFormat dataFormat) {
-        this(x, null, blockSize, dataFormat.toString());
+    public SpaceToDepth(INDArray x, int blockSize, DataFormat dataFormat) {
+        this(x, null, blockSize, dataFormat);
     }
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         // Gradient to SpaceToDepth is just DepthToSpace of same block size and data format.
         SDVariable gradient = i_v.get(0);
-        SDVariable ret = sameDiff.cnn().depthToSpace(gradient, blockSize, dataFormat);
+        SDVariable ret = new DepthToSpace(sameDiff, gradient, blockSize, dataFormat).outputVariable();
         return Arrays.asList(ret);
     }
 
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
         TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
-        boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC");
+        boolean isNHWC = dataFormat == null ? true : dataFormat.equals(DataFormat.NHWC);
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }

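With the switch from raw "NHWC"/"NCHW" strings to the DataFormat enum, construction now looks like the following sketch (input shape assumed):

    INDArray in  = Nd4j.rand(DataType.FLOAT, 1, 4, 4, 1);                  // NHWC layout
    INDArray out = Nd4j.exec(new SpaceToDepth(in, 2, DataFormat.NHWC))[0]; // [1, 2, 2, 4]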

@@ -56,6 +56,14 @@ public class Upsampling2d extends DynamicCustomOp {
         addIArgument(nchw ? 1 : 0);
     }
 
+    public Upsampling2d(SameDiff sameDiff, SDVariable input, int scaleH, int scaleW, boolean nchw) {
+        this(sameDiff, input, nchw, scaleH, scaleW);
+    }
+
+    public Upsampling2d(SameDiff sameDiff, SDVariable input, int scale) {
+        super(null, sameDiff, new SDVariable[]{input});
+        addIArgument(scale);
+    }
+
     public Upsampling2d(INDArray input, int scale) {
         this(input, scale, scale, true);


@@ -38,6 +38,11 @@ public class AbsoluteDifferenceLoss extends BaseLoss {
         super(sameDiff, lossReduce, predictions, weights, labels);
     }
 
+    public AbsoluteDifferenceLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights,
+                                  LossReduce lossReduce) {
+        this(sameDiff, lossReduce, predictions, weights, label);
+    }
+
     public AbsoluteDifferenceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
         super(lossReduce, predictions, weights, labels);
     }


@@ -33,9 +33,9 @@ public abstract class BaseLoss extends DynamicCustomOp {
     protected LossReduce lossReduce;
 
-    public BaseLoss(@NonNull SameDiff sameDiff, @NonNull LossReduce lossReduce, @NonNull SDVariable predictions, @NonNull SDVariable weights,
+    public BaseLoss(@NonNull SameDiff sameDiff, @NonNull LossReduce lossReduce, @NonNull SDVariable predictions, SDVariable weights,
                     @NonNull SDVariable labels){
-        super(null, sameDiff, new SDVariable[]{predictions, weights, labels});
+        super(null, sameDiff, new SDVariable[]{predictions, getWeights(sameDiff, weights, predictions), labels});
         this.lossReduce = lossReduce;
         addArgs();
     }
@@ -50,6 +50,10 @@ public abstract class BaseLoss extends DynamicCustomOp {
         return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0);
     }
 
+    protected static SDVariable getWeights(SameDiff sd, SDVariable weights, SDVariable predictions){
+        return weights != null ? weights : sd.constant(Nd4j.scalar(predictions.dataType(), 1.0));
+    }
+
     protected BaseLoss(){ }
 
     protected void addArgs(){
@@ -62,7 +66,7 @@ public abstract class BaseLoss extends DynamicCustomOp {
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
-        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() >= 2, "Expected exactly 2 or more input datatypes for %s, got %s", getClass(), inputDataTypes);
         return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions
     }
 }

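The practical effect of the getWeights helper is that callers may now pass a null weights variable and get an implicit scalar weight of 1.0. A sketch under assumed names and shapes:

    SameDiff sd = SameDiff.create();
    SDVariable labels      = sd.placeHolder("labels", DataType.FLOAT, -1, 10);
    SDVariable predictions = sd.placeHolder("predictions", DataType.FLOAT, -1, 10);
    // null weights -> getWeights substitutes sd.constant(Nd4j.scalar(FLOAT, 1.0))
    SDVariable loss = new MeanSquaredErrorLoss(sd, LossReduce.MEAN_BY_WEIGHT, predictions, null, labels).outputVariable();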

@@ -39,6 +39,11 @@ public class CosineDistanceLoss extends BaseLoss {
         this.addIArgument(dimension);
     }
 
+    public CosineDistanceLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
+                              LossReduce lossReduce, int dimension) {
+        this(sameDiff, lossReduce, predictions, weights, labels, dimension);
+    }
+
     public CosineDistanceLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, int dimension){
         super(lossReduce, predictions, weights, labels);
         this.dimension = dimension;


@@ -36,6 +36,11 @@ public class HingeLoss extends BaseLoss {
         super(sameDiff, lossReduce, predictions, weights, labels);
     }
 
+    public HingeLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
+                     LossReduce lossReduce) {
+        this(sameDiff, lossReduce, predictions, weights, labels);
+    }
+
     public HingeLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
         super(lossReduce, predictions, weights, labels);
     }


@@ -41,6 +41,11 @@ public class HuberLoss extends BaseLoss {
         tArguments.add(delta);
     }
 
+    public HuberLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
+                     LossReduce lossReduce, double delta) {
+        this(sameDiff, lossReduce, predictions, weights, labels, delta);
+    }
+
     public HuberLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double delta){
         super(lossReduce, predictions, weights, labels);
         this.delta = delta;


@@ -41,6 +41,11 @@ public class LogLoss extends BaseLoss {
         addTArgument(epsilon);
     }
 
+    public LogLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
+                   LossReduce lossReduce, double epsilon) {
+        this(sameDiff, lossReduce, predictions, weights, labels, epsilon);
+    }
+
     public LogLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce, double epsilon){
         super(lossReduce, predictions, weights, labels);
         this.epsilon = epsilon;


@@ -38,6 +38,11 @@ public class LogPoissonLoss extends BaseLoss {
         this(sameDiff, lossReduce, predictions, weights, labels, false);
     }
 
+    public LogPoissonLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
+                          LossReduce lossReduce, boolean full) {
+        this(sameDiff, lossReduce, predictions, weights, labels, full);
+    }
+
     public LogPoissonLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels, boolean full){
         super(sameDiff, lossReduce, predictions, weights, labels);
         this.full = full;


@@ -34,6 +34,11 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss {
         super(sameDiff, lossReduce, predictions, weights, labels);
     }
 
+    public MeanPairwiseSquaredErrorLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions,
+                                        SDVariable weights, LossReduce lossReduce) {
+        this(sameDiff, lossReduce, predictions, weights, labels);
+    }
+
     public MeanPairwiseSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
         super(lossReduce, predictions, weights, labels);
     }


@@ -36,6 +36,11 @@ public class MeanSquaredErrorLoss extends BaseLoss {
         super(sameDiff, lossReduce, predictions, weights, labels);
     }
 
+    public MeanSquaredErrorLoss(SameDiff sameDiff, SDVariable labels, SDVariable predictions, SDVariable weights,
+                                LossReduce lossReduce) {
+        this(sameDiff, lossReduce, predictions, weights, labels);
+    }
+
     public MeanSquaredErrorLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
         super(lossReduce, predictions, weights, labels);
     }


@@ -44,6 +44,11 @@ public class SigmoidCrossEntropyLoss extends BaseLoss {
     public static final double DEFAULT_LABEL_SMOOTHING = 0.0;
     private double labelSmoothing = 0.0;
 
+    public SigmoidCrossEntropyLoss(SameDiff sameDiff, SDVariable labels, SDVariable logits, SDVariable weights,
+                                   LossReduce lossReduce, double labelSmoothing) {
+        this(sameDiff, lossReduce, logits, weights, labels, labelSmoothing);
+    }
+
     public SigmoidCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights,
                                    SDVariable labels, double labelSmoothing) {
         super(sameDiff, lossReduce, logits, weights, labels);


@@ -45,6 +45,11 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss {
     private double labelSmoothing = 0.0;
 
+    public SoftmaxCrossEntropyLoss(SameDiff sameDiff, SDVariable labels, SDVariable logits,
+                                   SDVariable weights, LossReduce lossReduce, double labelSmoothing) {
+        this(sameDiff, lossReduce, logits, weights, labels, labelSmoothing);
+    }
+
     public SoftmaxCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights, SDVariable labels,
                                    double labelSmoothing) {
         super(sameDiff, lossReduce, logits, weights, labels);


@@ -93,6 +93,24 @@ public class Mmul extends DynamicCustomOp {
         }
     }
 
+    public Mmul(INDArray x, INDArray y, boolean transposeX, boolean transposeY, boolean transposeZ) {
+        addInputArgument(x, y);
+        addIArgument(ArrayUtil.fromBoolean(transposeX),
+                ArrayUtil.fromBoolean(transposeY),
+                ArrayUtil.fromBoolean(transposeZ));
+    }
+
+    public Mmul(INDArray x, INDArray y) {
+        this(x, y, null, null);
+    }
+
+    public Mmul(SameDiff sameDiff, SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
+                boolean transposeZ) {
+        super(null, sameDiff, new SDVariable[]{x, y});
+        addIArgument(ArrayUtil.fromBoolean(transposeX),
+                ArrayUtil.fromBoolean(transposeY),
+                ArrayUtil.fromBoolean(transposeZ));
+    }
+
     public Mmul() {}

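A sketch of the new transpose-flag constructor (shapes assumed); the three booleans become the op's iArgs:

    INDArray x = Nd4j.rand(DataType.FLOAT, 2, 3);
    INDArray y = Nd4j.rand(DataType.FLOAT, 2, 3);
    INDArray z = Nd4j.exec(new Mmul(x, y, false, true, false))[0]; // x * y^T -> [2, 2]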

@@ -77,6 +77,18 @@ public class TensorMmul extends DynamicCustomOp {
         addIArgument(dimensions[1]);
     }
 
+    public TensorMmul(SameDiff sameDiff, SDVariable x, SDVariable y, int[] dimensionsX,
+                      int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) {
+        super(null, sameDiff, new SDVariable[]{x, y});
+        this.sameDiff = sameDiff;
+        this.axes = new int[][]{dimensionsX, dimensionsY};
+        addIArgument(dimensionsX.length);
+        addIArgument(dimensionsX[0]);
+        addIArgument(dimensionsY.length);
+        addIArgument(dimensionsY[0]);
+        addBArgument(transposeX, transposeY, transposeZ);
+    }
+
     @Override
     public List<LongShapeDescriptor> calculateOutputShape() {
         List<LongShapeDescriptor> ret = new ArrayList<>(1);
@@ -242,6 +254,13 @@ public class TensorMmul extends DynamicCustomOp {
         this.axes = axes;
     }
 
+    public TensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY,
+                      boolean transposeX, boolean transposeY, boolean transposeZ) {
+        super(null, new INDArray[]{x, y}, null);
+        this.axes = new int[][]{dimensionsX, dimensionsY};
+        addBArgument(transposeX, transposeY, transposeZ);
+    }
+
     @Override
     public String opName() {
         return "tensordot";

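A tensordot sketch using the new constructor; contracting dimension 1 of x with dimension 0 of y reproduces a plain matrix multiply (shapes assumed):

    INDArray x = Nd4j.rand(DataType.FLOAT, 2, 3);
    INDArray y = Nd4j.rand(DataType.FLOAT, 3, 4);
    INDArray z = Nd4j.exec(new TensorMmul(x, y, new int[]{1}, new int[]{0}, false, false, false))[0]; // [2, 4]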

@@ -41,6 +41,10 @@ public class Any extends BaseReduceBoolOp {
         super(x);
     }
 
+    public Any(INDArray x, int... dimensions) {
+        super(x, dimensions);
+    }
+
     @Override
     public int opNum() {
         return 0;


@@ -45,6 +45,10 @@ public class LogSumExp extends DynamicCustomOp {
         this.keepDims = keepDims;
     }
 
+    public LogSumExp(SameDiff sameDiff, SDVariable i_v, int[] dimensions) {
+        this(sameDiff, i_v, false, dimensions);
+    }
+
     public LogSumExp() {}
 
     public LogSumExp(INDArray x, int... dimensions) {


@@ -41,6 +41,10 @@ public class SquaredNorm extends BaseReduceFloatOp {
         super(input, output, keepDims, dimensions);
     }
 
+    public SquaredNorm(INDArray input, boolean keepDims, int... dimensions){
+        this(input, null, keepDims, dimensions);
+    }
+
     public SquaredNorm(){}
 
     @Override


@@ -38,6 +38,10 @@ public class MatchCondition extends BaseReduceLongOp {
     private double eps;
     private int mode;
 
+    public MatchCondition(SameDiff sameDiff, SDVariable in, Condition condition) {
+        this(sameDiff, in, condition, false, null);
+    }
+
     public MatchCondition(SameDiff sameDiff, SDVariable in, Condition condition, boolean keepDims, int... dimensions) {
         super(sameDiff, in, dimensions, keepDims);
         this.compare = condition.getValue();
@@ -64,6 +68,10 @@ public class MatchCondition extends BaseReduceLongOp {
         defineDimensions(dimensions);
     }
 
+    public MatchCondition(INDArray in, Condition condition, boolean keepDim, int... dimensions) {
+        this(in, condition, dimensions);
+    }
+
     @Override
     public int opNum() {
         return 2;

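Usage sketch for the INDArray form, counting entries that satisfy a condition over the whole array (threshold assumed):

    INDArray in = Nd4j.rand(DataType.FLOAT, 3, 4);
    long n = Nd4j.exec(new MatchCondition(in, Conditions.greaterThan(0.5))).getLong(0);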

@@ -56,6 +56,10 @@ public class Sum extends BaseReduceSameOp {
         super(x, z, keepDims, dimensions);
     }
 
+    public Sum(INDArray x, boolean keepDims, int... dimensions) {
+        this(x, null, keepDims, dimensions);
+    }
+
     @Override
     public int opNum() {
         return 0;


@@ -50,6 +50,10 @@ public class LeakyReLU extends BaseScalarOp {
     }
 
+    public LeakyReLU(SameDiff sameDiff, SDVariable i_v, double alpha) {
+        this(sameDiff, i_v, false, alpha);
+    }
+
     public LeakyReLU(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double alpha) {
         super(sameDiff, i_v, alpha, extraArgs);
         this.alpha = alpha;


@@ -42,6 +42,10 @@ public class Pow extends BaseScalarOp {
         this.extraArgs = new Object[]{pow};
     }
 
+    public Pow(SameDiff sameDiff, SDVariable i_v, double pow) {
+        this(sameDiff, i_v, false, pow);
+    }
+
     public Pow(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, double pow) {
         super(sameDiff, i_v, pow, extraArgs);


@@ -35,6 +35,10 @@ public class RectifiedLinear extends BaseScalarOp {
         super(sameDiff, i_v, cutoff, inPlace);
     }
 
+    public RectifiedLinear(SameDiff sameDiff, SDVariable i_v, double cutoff) {
+        this(sameDiff, i_v, false, cutoff);
+    }
+
     public RectifiedLinear() {
         super();
     }


@@ -42,6 +42,10 @@ public class Relu6 extends BaseScalarOp {
         super(sameDiff, i_v, cutoff, inPlace);
     }
 
+    public Relu6(SameDiff sameDiff, SDVariable i_v, double cutoff) {
+        this(sameDiff, i_v, false, cutoff);
+    }
+
     public Relu6() {
         //
     }


@@ -41,6 +41,10 @@ public class Step extends BaseScalarOp {
         this.extraArgs = new Object[] {cutoff};
     }
 
+    public Step(SameDiff sameDiff, SDVariable i_v, double cutoff) {
+        this(sameDiff, i_v, false, cutoff);
+    }
+
     public Step() {
         cutoff = 0.0;
         this.extraArgs = new Object[] {cutoff};

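LeakyReLU, Pow, RectifiedLinear, Relu6 and Step all gain the same (SameDiff, SDVariable, double) convenience overload with inPlace defaulting to false. A sketch of the pattern (variable values assumed):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.createFromArray(-1f, -0.5f, 0f, 0.5f, 1f));
    SDVariable relu  = new RectifiedLinear(sd, x, 0.0).outputVariable();
    SDVariable leaky = new LeakyReLU(sd, x, 0.01).outputVariable();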

@@ -46,6 +46,9 @@ public class ScalarLessThan extends BaseScalarBoolOp {
         super(sameDiff, i_v, scalar, inPlace);
     }
 
+    public ScalarLessThan(SameDiff sameDiff, SDVariable i_v, double scalar) {
+        super(sameDiff, i_v, scalar, false);
+    }
+
     @Override
     public int opNum() {


@@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -43,6 +44,10 @@ public class ScatterAdd extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
     }
 
+    public ScatterAdd(INDArray ref, INDArray indices, INDArray updates) {
+        addInputArgument(ref, indices, updates);
+    }
+
     public ScatterAdd(){}
 
     @Override


@@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -43,6 +44,10 @@ public class ScatterDiv extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
     }
 
+    public ScatterDiv(INDArray ref, INDArray indices, INDArray updates) {
+        addInputArgument(ref, indices, updates);
+    }
+
     public ScatterDiv() {}
 
     @Override


@@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -41,6 +42,10 @@ public class ScatterMax extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
     }
 
+    public ScatterMax(INDArray ref, INDArray indices, INDArray updates) {
+        addInputArgument(ref, indices, updates);
+    }
+
     public ScatterMax() {}
 
     @Override


@@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -41,6 +42,10 @@ public class ScatterMin extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
     }
 
+    public ScatterMin(INDArray ref, INDArray indices, INDArray updates) {
+        addInputArgument(ref, indices, updates);
+    }
+
     public ScatterMin() {}
 
     @Override


@@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -43,6 +44,10 @@ public class ScatterMul extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
     }
 
+    public ScatterMul(INDArray ref, INDArray indices, INDArray updates) {
+        addInputArgument(ref, indices, updates);
+    }
+
     public ScatterMul() {}
 
     @Override


@@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -43,6 +44,10 @@ public class ScatterSub extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
     }
 
+    public ScatterSub(INDArray ref, INDArray indices, INDArray updates) {
+        addInputArgument(ref, indices, updates);
+    }
+
     public ScatterSub() {}
 
     @Override


@@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -53,6 +54,10 @@ public class ScatterUpdate extends DynamicCustomOp {
         super(null, sameDiff, new SDVariable[]{ref, indices, updates}, false);
     }
 
+    public ScatterUpdate(INDArray ref, INDArray indices, INDArray updates) {
+        addInputArgument(ref, indices, updates);
+    }
+
     public ScatterUpdate(){}
 
     @Override

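All seven scatter ops gain the same three-array constructor; one sketch stands in for the rest (shapes and values assumed):

    INDArray ref     = Nd4j.zeros(DataType.FLOAT, 5, 3);
    INDArray indices = Nd4j.createFromArray(0, 2);
    INDArray updates = Nd4j.ones(DataType.FLOAT, 2, 3);
    INDArray out = Nd4j.exec(new ScatterAdd(ref, indices, updates))[0]; // rows 0 and 2 become 1.0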

@@ -49,6 +49,14 @@ public class Concat extends DynamicCustomOp {
         addIArgument(concatDimension);
     }
 
+    public Concat(INDArray[] arrays, int concatDimension) {
+        this(concatDimension, arrays);
+    }
+
+    public Concat(SameDiff sameDiff, SDVariable[] inputs, int concatDimension){
+        this(sameDiff, concatDimension, inputs);
+    }
+
     public Concat(SameDiff sameDiff, int concatDimension, SDVariable... inputs){
         super(null, sameDiff, inputs);
         addIArgument(concatDimension);

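Usage sketch for the array-first overload (shapes assumed):

    INDArray a = Nd4j.rand(DataType.FLOAT, 2, 3);
    INDArray b = Nd4j.rand(DataType.FLOAT, 2, 5);
    INDArray c = Nd4j.exec(new Concat(new INDArray[]{a, b}, 1))[0]; // [2, 8]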

@@ -68,6 +68,12 @@ public class ConfusionMatrix extends DynamicCustomOp {
         }
     }
 
+    public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, DataType dataType){
+        this(sameDiff, labels, pred, weights);
+        this.outputType = dataType;
+    }
+
     public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){
         super(null, sameDiff, new SDVariable[]{labels, pred});
         this.outputType = dataType;
@@ -82,6 +88,11 @@ public class ConfusionMatrix extends DynamicCustomOp {
         addIArgument(numClasses);
     }
 
+    public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights, Integer numClasses){
+        super(null, sameDiff, new SDVariable[]{labels, pred, weights});
+        addIArgument(numClasses);
+    }
+
     public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights){
         super(null, sameDiff, new SDVariable[]{labels, pred, weights});
         if(numClasses != null) {


@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.shape;
 
+import lombok.NoArgsConstructor;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
@@ -39,15 +40,17 @@ import java.util.List;
  *
  * @author Max Pumperla
  */
+@NoArgsConstructor
 public class Cross extends DynamicCustomOp {
 
-    public Cross() {
-    }
-
     public Cross(SameDiff sameDiff, SDVariable[] args) {
         super(null, sameDiff, args, false);
     }
 
+    public Cross(SameDiff sameDiff, SDVariable a, SDVariable b) {
+        this(sameDiff, new SDVariable[]{a,b});
+    }
+
     public Cross(INDArray a, INDArray b){
         this(a,b,null);
     }


@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.shape;
 
+import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import onnx.Onnx;
 import org.nd4j.autodiff.samediff.SDVariable;
@@ -39,11 +40,9 @@ import java.util.Map;
  *
  * @author Max Pumperla
 */
+@NoArgsConstructor
 public class Diag extends DynamicCustomOp {
 
-    public Diag() {
-    }
-
     public Diag(@NonNull INDArray input) {
         this(input, null);
     }
@@ -52,6 +51,10 @@ public class Diag extends DynamicCustomOp {
         super(null, new INDArray[]{input}, wrapOrNull(output));
     }
 
+    public Diag(SameDiff sameDiff, SDVariable input) {
+        this(sameDiff, new SDVariable[]{input}, false);
+    }
+
     public Diag(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
         super(null, sameDiff, args, inPlace);

@@ -50,6 +50,10 @@ public class DiagPart extends DynamicCustomOp {
         super(null, sameDiff, args, inPlace);
     }
 
+    public DiagPart(SameDiff sameDiff, SDVariable in) {
+        this(sameDiff, new SDVariable[]{in}, false);
+    }
+
     public DiagPart(INDArray in){
         this(in, null);
     }


@@ -46,6 +46,10 @@ public class ExpandDims extends DynamicCustomOp {
     public ExpandDims() {
     }
 
+    public ExpandDims(SameDiff sameDiff, SDVariable args, int axis) {
+        this(sameDiff, new SDVariable[]{args}, axis);
+    }
+
     public ExpandDims(SameDiff sameDiff, SDVariable[] args, int axis) {
         super(null, sameDiff, args);
         if (axis == Integer.MAX_VALUE) {
@@ -63,6 +67,11 @@ public class ExpandDims extends DynamicCustomOp {
         super(null, inputs, outputs);
     }
 
+    public ExpandDims(INDArray input, int axis) {
+        addInputArgument(input);
+        addIArgument(axis);
+    }
+
     public ExpandDims(SameDiff sameDiff, SDVariable[] args, boolean inPlace) {
         super(null, sameDiff, args, inPlace);
     }


@@ -122,6 +122,13 @@ public class Eye extends DynamicCustomOp {
         addArgs();
     }
 
+    public Eye(SameDiff sameDiff, SDVariable numRows, SDVariable numCols, DataType dataType, int[] batchDimension) {
+        super(null, sameDiff, new SDVariable[] {numRows, numCols}, false);
+        this.batchDimension = batchDimension;
+        this.dataType = dataType;
+        addArgs();
+    }
+
     protected void addArgs() {
         iArguments.clear();
         tArguments.clear();


@@ -24,6 +24,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -40,6 +41,13 @@ public class Gather extends DynamicCustomOp {
     protected int[] indices;
     protected int jaxis = 0;
 
+    public Gather(SameDiff sameDiff, SDVariable df, SDVariable indices, int axis) {
+        this(sameDiff, df, indices, axis, false);
+    }
+
+    public Gather(SameDiff sameDiff, SDVariable df, int[] indices, int axis) {
+        this(sameDiff, df, indices, axis, false);
+    }
+
     public Gather(SameDiff sameDiff, SDVariable input, int[] indices, int axis, boolean inPlace) {
         super(null, sameDiff, new SDVariable[] {input}, inPlace);
@@ -56,6 +64,21 @@ public class Gather extends DynamicCustomOp {
         this.jaxis = axis;
     }
 
+    public Gather(INDArray df, int[] indexes, int axis) {
+        addInputArgument(df);
+        addIArgument(axis);
+        addIArgument(indexes);
+        this.jaxis = axis;
+        this.indices = indexes;
+    }
+
+    public Gather(INDArray df, INDArray indexes, int axis) {
+        addInputArgument(df, indexes);
+        addIArgument(axis);
+        this.jaxis = axis;
+    }
+
     @Override
     public String onnxName() {
         return "Gather";

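A sketch of the new INDArray constructors, gathering rows 0 and 2 along axis 0 (shapes assumed):

    INDArray df  = Nd4j.rand(DataType.FLOAT, 4, 3);
    INDArray out = Nd4j.exec(new Gather(df, new int[]{0, 2}, 0))[0]; // [2, 3]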

@@ -17,10 +17,13 @@
 package org.nd4j.linalg.api.ops.impl.shape;
 
 import lombok.NoArgsConstructor;
+import org.apache.commons.lang3.ArrayUtils;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.util.ArrayUtil;
 
 import java.util.Collections;
 import java.util.List;
@@ -31,11 +34,19 @@ import java.util.List;
 @NoArgsConstructor
 public class GatherNd extends DynamicCustomOp {
 
+    public GatherNd(SameDiff sameDiff, SDVariable[] inputs, SDVariable[] indices) {
+        super(null, sameDiff, ArrayUtils.addAll(inputs, indices), false);
+    }
+
     public GatherNd(SameDiff sameDiff, SDVariable input, SDVariable indices, boolean inPlace) {
         super(null, sameDiff, new SDVariable[] {input, indices}, inPlace);
     }
 
+    public GatherNd(INDArray[] df, INDArray[] indices) {
+        addInputArgument(df);
+        addInputArgument(indices);
+    }
+
     @Override
     public String opName() {
         return "gather_nd";


@@ -16,6 +16,7 @@
 package org.nd4j.linalg.api.ops.impl.shape;
 
+import org.apache.commons.lang3.NotImplementedException;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
@@ -39,11 +40,24 @@ public class Linspace extends DynamicCustomOp {
     private DataType dataType;
 
+    public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) {
+        super(sameDiff, new SDVariable[0]);
+        addTArgument(start, stop);
+        addIArgument(number);
+        addDArgument(dataType);
+    }
+
     public Linspace(SameDiff sameDiff, SDVariable from, SDVariable to, SDVariable length, DataType dataType){
         super(sameDiff, new SDVariable[]{from, to, length});
         this.dataType = dataType;
     }
 
+    public Linspace(DataType dataType, double start, double stop, long number) {
+        addDArgument(dataType);
+        addTArgument(start, stop);
+        addIArgument(number);
+    }
+
     public Linspace(){ }
 
     @Override

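A sketch of the scalar-argument form, which needs no input variables (endpoints assumed):

    INDArray out = Nd4j.exec(new Linspace(DataType.FLOAT, 0.0, 1.0, 5))[0]; // [0, 0.25, 0.5, 0.75, 1]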

@@ -37,6 +37,10 @@ public class MeshGrid extends DynamicCustomOp {
         addIArgument(cartesian ? 1 : 0);
     }
 
+    public MeshGrid(SameDiff sd, SDVariable[] inputs, boolean cartesian) {
+        this(sd, cartesian, inputs);
+    }
+
     public MeshGrid(){ }
 
     @Override


@@ -66,6 +66,11 @@ public class OneHot extends DynamicCustomOp {
         this(indices, output, depth, -1, 1, 0);
     }
 
+    public OneHot(INDArray indices, int depth) {
+        addInputArgument(indices);
+        addIArgument(depth);
+    }
+
     public OneHot(INDArray indices, INDArray output, int depth, int axis, double on, double off) {
         super(null, indices, output, null, null);
         this.depth = depth;
@@ -75,6 +80,12 @@ public class OneHot extends DynamicCustomOp {
         addArgs();
     }
 
+    public OneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) {
+        addInputArgument(indices);
+        addIArgument(depth, axis);
+        addTArgument(on, off);
+        addDArgument(dataType);
+    }

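Usage sketch of the fully specified overload (on/off values assumed):

    INDArray indices = Nd4j.createFromArray(0, 2);
    INDArray oneHot  = Nd4j.exec(new OneHot(indices, 3, -1, 1.0, 0.0, DataType.FLOAT))[0];
    // [[1, 0, 0],
    //  [0, 0, 1]]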

@@ -48,10 +48,18 @@ public class OnesLike extends DynamicCustomOp {
     public OnesLike() {
     }
 
+    public OnesLike(SameDiff sameDiff, SDVariable input) {
+        this(null, sameDiff, input);
+    }
+
     public OnesLike(String name, SameDiff sameDiff, SDVariable input) {
         this(name, sameDiff, input, input.dataType());
     }
 
+    public OnesLike(SameDiff sameDiff, SDVariable input, DataType dataType) {
+        this(null, sameDiff, input, dataType);
+    }
+
     public OnesLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
         super(name, sameDiff, new SDVariable[]{input}, false);
         this.outputType = dataType;


@@ -55,6 +55,11 @@ public class Permute extends Transpose {
         addIArgument(permuteDims);
     }
 
+    public Permute(INDArray input, int... permuteDims){
+        addInputArgument(input);
+        addIArgument(permuteDims);
+    }
+
     public Permute(SameDiff sd, SDVariable input, SDVariable permuteDims){
         super(sd, input, permuteDims);
     }


@@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
 import java.util.*;
@@ -39,10 +40,18 @@ public class Rank extends DynamicCustomOp {
     public Rank() {
     }
 
+    public Rank(SameDiff sameDiff, SDVariable input) {
+        this(sameDiff, input, false);
+    }
+
     public Rank(SameDiff sameDiff, SDVariable input, boolean inPlace) {
         super(null, sameDiff, new SDVariable[] {input}, inPlace);
     }
 
+    public Rank(INDArray indArray) {
+        addInputArgument(indArray);
+    }
+
     @Override
     public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {


@@ -59,6 +59,10 @@ public class Reshape extends DynamicCustomOp {
         super(null, new INDArray[]{in, shape}, new INDArray[]{out}, null, (List<Integer>)null);
     }
 
+    public Reshape(INDArray in, INDArray shape) {
+        addInputArgument(in, shape);
+    }
+
     public Reshape() {
     }


@@ -71,6 +71,12 @@ public class SequenceMask extends DynamicCustomOp {
         addDArgument(dataType);
     }
 
+    public SequenceMask(INDArray input, DataType dataType) {
+        addInputArgument(input);
+        this.dataType = dataType;
+        addDArgument(dataType);
+    }
+
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {

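Usage sketch of the new (input, dataType) overload, with the mask length inferred from the data (lengths assumed):

    INDArray lengths = Nd4j.createFromArray(1, 3);
    INDArray mask = Nd4j.exec(new SequenceMask(lengths, DataType.FLOAT))[0];
    // [[1, 0, 0],
    //  [1, 1, 1]]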