SameDiff ops (#8247)
* update javadocs and a few method signatures
* add PRelu op
* test and fixes
* add PRelu op
* test and fixes
* slightly better test

Signed-off-by: Ryan Nett <rnett@skymind.io>
parent 59f1cbf0c6
commit f98f8be7b6
In DifferentialFunctionFactory.java:

@@ -215,6 +215,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;

@@ -1628,6 +1629,13 @@ public class DifferentialFunctionFactory {
         return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable();
     }

+    public SDVariable prelu(SDVariable x, SDVariable alpha, int... sharedAxes){
+        return new PRelu(sameDiff(), x, alpha, sharedAxes).outputVariable();
+    }
+
+    public SDVariable[] preluBp(SDVariable in, SDVariable alpha, SDVariable epsilon, int... sharedAxes){
+        return new PReluBp(sameDiff(), in, alpha, epsilon, sharedAxes).outputVariables();
+    }

     public SDVariable reshape(SDVariable iX, int[] shape) {
         return new Reshape(sameDiff(), iX, ArrayUtil.toLongArray(shape)).outputVariable();
In SDVariable.java:

@@ -73,6 +73,8 @@ public class SDVariable implements Serializable {
     @Getter
     @Setter
     protected WeightInitScheme weightInitScheme;

+    @Setter(AccessLevel.NONE)
     protected long[] shape;

     @Getter (AccessLevel.NONE)

@@ -237,6 +239,10 @@ public class SDVariable implements Serializable {
         return initialShape;
     }

+    public void setShape(long... shape){
+        this.shape = shape;
+    }
+
     public long[] placeholderShape(){
         if(variableType != VariableType.PLACEHOLDER){
             throw new IllegalStateException("placeholderShape() can only be used for placeholder variables: variable \"" + getVarName()
In SameDiff.java:

@@ -3236,6 +3236,7 @@ public class SameDiff extends SDBaseOps {

     /**
      * See {@link #one(String, DataType, int...)}.
+     * Creates a VARIABLE type SDVariable.
      * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
      */
     public SDVariable one(String name, int... shape) {

@@ -3244,6 +3245,7 @@ public class SameDiff extends SDBaseOps {

     /**
      * See {@link #one(String, DataType, long...)}.
+     * Creates a VARIABLE type SDVariable.
      * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
      */
     public SDVariable one(String name, long... shape) {

@@ -3252,7 +3254,8 @@ public class SameDiff extends SDBaseOps {


     /**
-     * Create a new variable with the specified shape, with all values initialized to 1.0
+     * Create a new variable with the specified shape, with all values initialized to 1.0.
+     * Creates a VARIABLE type SDVariable.
      *
      * @param name the name of the variable to create
      * @param shape the shape of the array to be created

@@ -3263,7 +3266,8 @@ public class SameDiff extends SDBaseOps {
     }

     /**
-     * Create a new variable with the specified shape, with all values initialized to 1.0
+     * Create a new variable with the specified shape, with all values initialized to 1.0.
+     * Creates a VARIABLE type SDVariable.
      *
      * @param name the name of the variable to create
      * @param shape the shape of the array to be created

@@ -3275,6 +3279,7 @@ public class SameDiff extends SDBaseOps {

     /**
      * See {@link #zero(String, DataType, long...)}.
+     * Creates a VARIABLE type SDVariable.
      * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
      */
     public SDVariable zero(String name, long... shape) {

@@ -3283,6 +3288,7 @@ public class SameDiff extends SDBaseOps {

     /**
      * See {@link #zero(String, DataType, int...)}.
+     * Creates a VARIABLE type SDVariable.
      * Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
      */
     public SDVariable zero(String name, int... shape) {

@@ -3290,7 +3296,8 @@ public class SameDiff extends SDBaseOps {
     }

     /**
-     * Create a new variable with the specified shape, with all values initialized to 0
+     * Create a new variable with the specified shape, with all values initialized to 0.
+     * Creates a VARIABLE type SDVariable.
      *
      * @param name the name of the variable to create
      * @param shape the shape of the array to be created

@@ -3301,7 +3308,8 @@ public class SameDiff extends SDBaseOps {
     }

     /**
-     * Create a new variable with the specified shape, with all values initialized to 0
+     * Create a new variable with the specified shape, with all values initialized to 0.
+     * Creates a VARIABLE type SDVariable.
      *
      * @param name the name of the variable to create
      * @param shape the shape of the array to be created
@@ -3522,6 +3530,19 @@ public class SameDiff extends SDBaseOps {
         return var(name, Nd4j.defaultFloatingPointType(), shape);
     }

+    /**
+     * Variable initialization with a specified {@link WeightInitScheme}. Data type will be given by {@link Nd4j#defaultFloatingPointType()}<br>
+     * This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
+     *
+     * @param name the name of the variable
+     * @param shape the shape of the array to be created
+     * @param weightInitScheme the weight initialization scheme
+     * @return the created variable
+     */
+    public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull long... shape) {
+        return var(name, weightInitScheme, Nd4j.defaultFloatingPointType(), shape);
+    }
+
     /**
      * Creates a {@link SDVariable} with the given shape and name<br>
      * Any array will be generated with all zeros for the values<br>
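Usage sketch for the new var overload, which delegates to the existing (name, scheme, dataType, shape) variant. Not part of this commit; the XavierInitScheme choice and its fan-in/fan-out values are illustrative assumptions:

    // Trainable 128x10 weight matrix, Xavier-initialized, default floating point type
    SameDiff sd = SameDiff.create();
    SDVariable w = sd.var("w", new XavierInitScheme('c', 128, 10), 128, 10);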
@@ -5223,7 +5244,7 @@ public class SameDiff extends SDBaseOps {
      * @param variableName the vertex id for the original shape
      * @param shape the shape of the place holder
      */
-    public void setOriginalPlaceHolderShape(String variableName, long[] shape) {
+    public void setOriginalPlaceHolderShape(String variableName, @NonNull long... shape) {
         if (!isPlaceHolder(variableName)) {
             throw new ND4JIllegalStateException("Vertex id " + variableName + " does not appear to be a place holder. Did you forget to call addPlaceHolder?");
         }
In SDNN.java:

@@ -16,6 +16,7 @@

 package org.nd4j.autodiff.samediff.ops;

+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ops.impl.transforms.Pad;

@@ -490,6 +491,34 @@ public class SDNN extends SDOps {
         return updateVariableNameAndReference(res, name);
     }

+    /**
+     * See {@link #prelu(String, SDVariable, SDVariable, int...)}.
+     */
+    public SDVariable prelu(@NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){
+        return f().prelu(input, alpha, sharedAxes);
+    }
+
+    /**
+     * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:<br>
+     * out[i] = in[i] if in[i] >= 0<br>
+     * out[i] = in[i] * alpha[i] otherwise<br>
+     *
+     * sharedAxes allows you to share learnable parameters along axes.
+     * For example, if the input has shape [batchSize, channels, height, width]
+     * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
+     * alpha with shape [channels].
+     *
+     * @param name Name of the output variable
+     * @param input Input data
+     * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha.
+     * @param sharedAxes Which axes to share cutoff parameters along.
+     * @return Output variable
+     */
+    public SDVariable prelu(String name, @NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){
+        SDVariable res = f().prelu(input, alpha, sharedAxes);
+        return updateVariableNameAndReference(res, name);
+    }
+
     /**
      * Element-wise SeLU function - Scaled exponential Lineal Unit: see <a href="https://arxiv.org/abs/1706.02515">Self-Normalizing Neural Networks</a>
      * <br>

@@ -568,7 +597,7 @@ public class SDNN extends SDOps {
     }

     /**
-     * Softmax activation
+     * Softmax activation on dimension 1.
      *
      * @param x Input variable
      * @return Output variable

@@ -578,7 +607,7 @@ public class SDNN extends SDOps {
     }

     /**
-     * Softmax activation
+     * Softmax activation on dimension 1.
      *
      * @param x Input variable
      * @return Output variable
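Usage sketch for the new prelu API. Not part of this commit; the NCHW shape, variable names, and initial alpha value are illustrative assumptions:

    SameDiff sd = SameDiff.create();
    // NCHW input: [batchSize, channels, height, width]
    SDVariable in = sd.placeHolder("in", DataType.DOUBLE, -1, 3, 32, 32);
    // One learnable alpha per channel, shared along height and width (axes 2 and 3)
    SDVariable alpha = sd.var("alpha", Nd4j.createFromArray(0.01, 0.01, 0.01));
    SDVariable out = sd.nn.prelu("out", in, alpha, 2, 3);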
In OpValidation.java:

@@ -894,6 +894,7 @@ public class OpValidation {
                 RationalTanhDerivative.class,
                 RectifiedTanhDerivative.class,
                 Relu6Derivative.class,
+                PReluBp.class,
                 SELUDerivative.class,
                 SigmoidDerivative.class,
                 org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class,
In ImportClassMapping.java:

@@ -231,6 +231,7 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class,
             org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
+            org.nd4j.linalg.api.ops.impl.scalar.PRelu.class,
             org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
             org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
             org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision.class,

@@ -434,6 +435,7 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative.class,
+            org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class,
In ResizeBilinear.java:

@@ -16,11 +16,15 @@

 package org.nd4j.linalg.api.ops.impl.image;

+import lombok.NoArgsConstructor;
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.factory.Nd4j;
 import org.tensorflow.framework.AttrValue;

@@ -36,8 +40,27 @@ import java.util.Map;
  * ResizeBilinear op wrapper
  * @author raver119@gmail.com
  */
+@NoArgsConstructor
 public class ResizeBilinear extends DynamicCustomOp {
     protected boolean alignCorners = false;
+    protected Integer height = null;
+    protected Integer width = null;
+
+    public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, boolean alignCorners){
+        super(sd, input);
+        this.alignCorners = alignCorners;
+        this.height = height;
+        this.width = width;
+        addArgs();
+    }
+
+    public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, boolean alignCorners){
+        super(new INDArray[]{x}, new INDArray[]{z});
+        this.alignCorners = alignCorners;
+        this.height = height;
+        this.width = width;
+        addArgs();
+    }

     @Override
     public String opName() {

@@ -60,13 +83,20 @@ public class ResizeBilinear extends DynamicCustomOp {
     protected void addArgs() {
         // to be implemented
         iArguments.clear();
+        if(height != null && width != null){
+            iArguments.add(Long.valueOf(height));
+            iArguments.add(Long.valueOf(width));
+        }
         iArguments.add(alignCorners ? 1L : 0L);
     }

     @Override
     public Map<String, Object> propertiesForFunction() {
         Map<String,Object> ret = new LinkedHashMap<>();
         ret.put("alignCorners", alignCorners);
+        ret.put("height", height);
+        ret.put("width", width);
         return ret;
     }
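Construction sketch for the new SameDiff-side constructor. Not part of this commit; the variable names and NHWC shape are illustrative assumptions:

    SameDiff sd = SameDiff.create();
    // NHWC image batch: [batch, height, width, channels]
    SDVariable img = sd.placeHolder("img", DataType.FLOAT, 1, 8, 8, 3);
    // Resize to 16x16; addArgs() packs height and width ahead of the alignCorners flag
    SDVariable resized = new ResizeBilinear(sd, img, 16, 16, false).outputVariable();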
New file PRelu.java (org.nd4j.linalg.api.ops.impl.scalar):

@@ -0,0 +1,81 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.api.ops.impl.scalar;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.imports.NoOpNameFoundException;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Parameterized ReLU op
+ */
+@NoArgsConstructor
+public class PRelu extends DynamicCustomOp {
+
+    @Getter
+    protected int[] sharedAxes;
+
+    public PRelu(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable alpha, @NonNull int... sharedAxes) {
+        super(sameDiff, new SDVariable[]{x, alpha});
+        this.sharedAxes = sharedAxes;
+        addIArgument(sharedAxes);
+    }
+
+    public PRelu(@NonNull INDArray x, INDArray z, @NonNull INDArray alpha, @NonNull int... sharedAxes) {
+        super(new INDArray[]{x, alpha}, new INDArray[]{z});
+        this.sharedAxes = sharedAxes;
+        addIArgument(sharedAxes);
+    }
+
+    @Override
+    public String opName() {
+        return "prelu";
+    }
+
+    @Override
+    public String onnxName() {
+        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
+    }
+
+    @Override
+    public String tensorflowName() {
+        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2,
+                "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(),
+                "Input datatypes must be floating point, got %s", dataTypes);
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> i_v) {
+        return Arrays.asList(f().preluBp(arg(0), arg(1), i_v.get(0), sharedAxes));
+    }
+}
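Execution sketch mirroring the values verified in testPReLU below. Not part of this commit; direct Nd4j.exec usage here is an illustrative assumption:

    // Input [1, 2, 4]; alpha has one value per row (dimension 1), shared along dimension 2
    INDArray in = Nd4j.createFromArray(new double[][][]{{
            {-10, 10, 10, -10},
            {10, 10, -10, -10}
    }});
    INDArray alpha = Nd4j.createFromArray(0.01, 0.1);
    INDArray out = in.ulike();
    Nd4j.exec(new PRelu(in, out, alpha, 2));
    // out: [[[-0.1, 10, 10, -0.1], [10, 10, -1, -1]]]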
New file PReluBp.java (org.nd4j.linalg.api.ops.impl.transforms.gradient):

@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * PRelu backpropagation op - computes dL/dIn and dL/dAlpha from in, alpha, and dL/dOut
+ */
+@NoArgsConstructor
+public class PReluBp extends DynamicCustomOp {
+
+    @Getter
+    protected int[] sharedAxes;
+
+    public PReluBp(SameDiff sd, SDVariable input, SDVariable alpha, SDVariable gradient, int... sharedAxes){
+        super(sd, new SDVariable[]{input, alpha, gradient});
+        this.sharedAxes = sharedAxes;
+        addIArgument(sharedAxes);
+    }
+
+    public PReluBp(@NonNull INDArray input, @NonNull INDArray alpha, @NonNull INDArray gradient, INDArray dLdI, INDArray dLdA, int... sharedAxes){
+        super(new INDArray[]{input, alpha, gradient}, wrapFilterNull(dLdI, dLdA));
+        this.sharedAxes = sharedAxes;
+        addIArgument(sharedAxes);
+    }
+
+    @Override
+    public String opName(){
+        return "prelu_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 3,
+                "Expected exactly 3 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType() && dataTypes.get(2).isFPType(),
+                "Input datatypes must be floating point, got %s", dataTypes);
+        return Arrays.asList(dataTypes.get(0), dataTypes.get(1));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}
In SameDiffTests.java:

@@ -3585,4 +3585,28 @@ public class SameDiffTests extends BaseNd4jTest {
             assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0"));
         }
     }
+
+    @Test
+    public void testPReLU(){
+        SameDiff sd = SameDiff.create();
+
+        SDVariable input = sd.constant(Nd4j.createFromArray(new int[][][]{{
+                {-10, 10, 10, -10},
+                {10, 10, -10, -10}
+        }}).castTo(DataType.DOUBLE));
+
+        SDVariable alpha = sd.var(Nd4j.createFromArray(0.01, 0.1).castTo(DataType.DOUBLE));
+
+        SDVariable out = sd.nn.prelu("out", input, alpha, 2);
+
+        TestCase tc = new TestCase(sd).expected("out", Nd4j.createFromArray(new double[][][]{{
+                {-0.1, 10, 10, -0.1},
+                {10, 10, -1, -1}
+        }}).castTo(DataType.DOUBLE)).gradientCheck(true);
+
+        String err = OpValidation.validate(tc);
+        assertNull(err);
+    }
 }