diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index 621dac941..1a40fbd11 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -215,6 +215,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
+import org.nd4j.linalg.api.ops.impl.scalar.PRelu;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
@@ -1628,6 +1629,13 @@ public class DifferentialFunctionFactory {
return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable();
}
+ public SDVariable prelu(SDVariable x, SDVariable alpha, int... sharedAxes){
+ return new PRelu(sameDiff(), x, alpha, sharedAxes).outputVariable();
+ }
+
+ public SDVariable[] preluBp(SDVariable in, SDVariable alpha, SDVariable epsilon, int... sharedAxes){
+ return new PReluBp(sameDiff(), in, alpha, epsilon, sharedAxes).outputVariables();
+ }
public SDVariable reshape(SDVariable iX, int[] shape) {
return new Reshape(sameDiff(), iX, ArrayUtil.toLongArray(shape)).outputVariable();
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
index 0d2700b43..a97668e8e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
@@ -73,6 +73,8 @@ public class SDVariable implements Serializable {
@Getter
@Setter
protected WeightInitScheme weightInitScheme;
+
+ @Setter(AccessLevel.NONE)
protected long[] shape;
@Getter (AccessLevel.NONE)
@@ -237,6 +239,10 @@ public class SDVariable implements Serializable {
return initialShape;
}
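+ /**
+ * Set the shape of this SDVariable. Note this updates only the stored shape information; it does not modify any associated array.
+ */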
+ public void setShape(long... shape){
+ this.shape = shape;
+ }
+
public long[] placeholderShape(){
if(variableType != VariableType.PLACEHOLDER){
throw new IllegalStateException("placeholderShape() can only be used for placeholder variables: variable \"" + getVarName()
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
index 625912f50..ddd9ecbb2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
@@ -3236,6 +3236,7 @@ public class SameDiff extends SDBaseOps {
/**
* See {@link #one(String, DataType, int...)}.
+ * Creates a VARIABLE type SDVariable.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable one(String name, int... shape) {
@@ -3244,6 +3245,7 @@ public class SameDiff extends SDBaseOps {
/**
* See {@link #one(String, DataType, long...)}.
+ * Creates a VARIABLE type SDVariable.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable one(String name, long... shape) {
@@ -3252,7 +3254,8 @@ public class SameDiff extends SDBaseOps {
/**
- * Create a new variable with the specified shape, with all values initialized to 1.0
+ * Create a new variable with the specified shape, with all values initialized to 1.0.
+ * Creates a VARIABLE type SDVariable.
*
* @param name the name of the variable to create
* @param shape the shape of the array to be created
@@ -3263,7 +3266,8 @@ public class SameDiff extends SDBaseOps {
}
/**
- * Create a new variable with the specified shape, with all values initialized to 1.0
+ * Create a new variable with the specified shape, with all values initialized to 1.0.
+ * Creates a VARIABLE type SDVariable.
*
* @param name the name of the variable to create
* @param shape the shape of the array to be created
@@ -3275,6 +3279,7 @@ public class SameDiff extends SDBaseOps {
/**
* See {@link #zero(String, DataType, long...)}.
+ * Creates a VARIABLE type SDVariable.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable zero(String name, long... shape) {
@@ -3283,6 +3288,7 @@ public class SameDiff extends SDBaseOps {
/**
* See {@link #zero(String, DataType, int...)}.
+ * Creates a VARIABLE type SDVariable.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable zero(String name, int... shape) {
@@ -3290,7 +3296,8 @@ public class SameDiff extends SDBaseOps {
}
/**
- * Create a new variable with the specified shape, with all values initialized to 0
+ * Create a new variable with the specified shape, with all values initialized to 0.
+ * Creates a VARIABLE type SDVariable.
*
* @param name the name of the variable to create
* @param shape the shape of the array to be created
@@ -3301,7 +3308,8 @@ public class SameDiff extends SDBaseOps {
}
/**
- * Create a new variable with the specified shape, with all values initialized to 0
+ * Create a new variable with the specified shape, with all values initialized to 0.
+ * Creates a VARIABLE type SDVariable.
*
* @param name the name of the variable to create
* @param shape the shape of the array to be created
@@ -3522,6 +3530,19 @@ public class SameDiff extends SDBaseOps {
return var(name, Nd4j.defaultFloatingPointType(), shape);
}
+ /**
+ * Variable initialization with a specified {@link WeightInitScheme}. The data type is given by {@link Nd4j#defaultFloatingPointType()}.
+ * This method creates a VARIABLE type SDVariable - i.e., it must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
+ *
+ * @param name the name of the variable
+ * @param shape the shape of the array to be created
+ * @param weightInitScheme the weight initialization scheme
+ * @return the created variable
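+ *
+ * <p>A minimal usage sketch (assuming {@code org.nd4j.weightinit.impl.XavierInitScheme}; names are illustrative):
+ * <pre>{@code
+ * SameDiff sd = SameDiff.create();
+ * // 784 -> 10 weight matrix, Xavier-initialized, 'c' ordering, default floating point type
+ * SDVariable w = sd.var("w", new XavierInitScheme('c', 784, 10), 784, 10);
+ * }</pre>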
+ */
+ public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull long... shape) {
+ return var(name, weightInitScheme, Nd4j.defaultFloatingPointType(), shape);
+ }
+
/**
* Creates a {@link SDVariable} with the given shape and name
* Any array will be generated with all zeros for the values
@@ -5223,7 +5244,7 @@ public class SameDiff extends SDBaseOps {
* @param variableName the vertex id for the original shape
* @param shape the shape of the place holder
*/
- public void setOriginalPlaceHolderShape(String variableName, long[] shape) {
+ public void setOriginalPlaceHolderShape(String variableName, @NonNull long... shape) {
if (!isPlaceHolder(variableName)) {
throw new ND4JIllegalStateException("Vertex id " + variableName + " does not appear to be a place holder. Did you forget to call addPlaceHolder?");
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
index cd9d7ffd2..a17cb41b1 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
@@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff.ops;
+import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
@@ -490,6 +491,34 @@ public class SDNN extends SDOps {
return updateVariableNameAndReference(res, name);
}
+ /**
+ * See {@link #prelu(String, SDVariable, SDVariable, int...)}.
+ */
+ public SDVariable prelu(@NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){
+ return f().prelu(input, alpha, sharedAxes);
+ }
+
+ /**
+ * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
+ * out[i] = in[i] if in[i] >= 0
+ * out[i] = in[i] * alpha[i] otherwise
+ *
+ * sharedAxes allows you to share learnable parameters along axes.
+ * For example, if the input has shape [batchSize, channels, height, width]
+ * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
+ * alpha with shape [channels].
+ *
+ * @param name Name of the output variable
+ * @param input Input data
+ * @param alpha The cutoff variable. Note that dimension 0 (the batch dimension, whether or not the data is actually batched) is never part of alpha's shape.
+ * @param sharedAxes Which axes to share cutoff parameters along.
+ * @return Output variable
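+ *
+ * <p>A usage sketch, assuming NCHW input and one alpha value per channel ({@code sd} is the owning SameDiff instance; names and shapes are illustrative):
+ * <pre>{@code
+ * SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3, 32, 32);
+ * SDVariable alpha = sd.var("alpha", Nd4j.valueArrayOf(new long[]{3}, 0.01).castTo(DataType.FLOAT));
+ * SDVariable out = sd.nn.prelu("out", in, alpha, 2, 3); // share alpha across height and width
+ * }</pre>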
+ */
+ public SDVariable prelu(String name, @NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){
+ SDVariable res = f().prelu(input, alpha, sharedAxes);
+ return updateVariableNameAndReference(res, name);
+ }
+
/**
* Element-wise SeLU function - Scaled Exponential Linear Unit: see Self-Normalizing Neural Networks
*
@@ -568,7 +597,7 @@ public class SDNN extends SDOps {
}
/**
- * Softmax activation
+ * Softmax activation on dimension 1.
*
* @param x Input variable
* @return Output variable
@@ -578,7 +607,7 @@ public class SDNN extends SDOps {
}
/**
- * Softmax activation
+ * Softmax activation on dimension 1.
*
* @param x Input variable
* @return Output variable
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java
index f52fbc2d9..74c1d868d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java
@@ -894,6 +894,7 @@ public class OpValidation {
RationalTanhDerivative.class,
RectifiedTanhDerivative.class,
Relu6Derivative.class,
+ PReluBp.class,
SELUDerivative.class,
SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 19b534a97..fcf3fe630 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -231,6 +231,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class,
org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
+ org.nd4j.linalg.api.ops.impl.scalar.PRelu.class,
org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision.class,
@@ -434,6 +435,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative.class,
+ org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java
index a5008eb0f..5ae8f85ea 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java
@@ -16,11 +16,15 @@
package org.nd4j.linalg.api.ops.impl.image;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
@@ -36,8 +40,27 @@ import java.util.Map;
* ResizeBilinear op wrapper
* @author raver119@gmail.com
*/
+@NoArgsConstructor
public class ResizeBilinear extends DynamicCustomOp {
protected boolean alignCorners = false;
+ protected Integer height = null;
+ protected Integer width = null;
+
+ public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, boolean alignCorners){
+ super(sd, input);
+ this.alignCorners = alignCorners;
+ this.height = height;
+ this.width = width;
+ addArgs();
+ }
+
+ public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, boolean alignCorners){
+ super(new INDArray[]{x}, wrapFilterNull(z));
+ this.alignCorners = alignCorners;
+ this.height = height;
+ this.width = width;
+ addArgs();
+ }
@Override
public String opName() {
@@ -60,13 +83,20 @@ public class ResizeBilinear extends DynamicCustomOp {
protected void addArgs() {
iArguments.clear();
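+ // iArgs layout: optional [height, width] (when both are set), followed by the alignCorners flag (1 = align corners)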
+ if(height != null && width != null){
+ iArguments.add(Long.valueOf(height));
+ iArguments.add(Long.valueOf(width));
+ }
iArguments.add(alignCorners ? 1L : 0L);
}
@Override
public Map<String, Object> propertiesForFunction() {
Map<String, Object> ret = new LinkedHashMap<>();
ret.put("alignCorners", alignCorners);
+ ret.put("height", height);
+ ret.put("width", width);
return ret;
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java
new file mode 100644
index 000000000..32c07ad96
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java
@@ -0,0 +1,81 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.api.ops.impl.scalar;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.imports.NoOpNameFoundException;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Parameterized ReLU op
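+ *
+ * <p>Computes out = in for in >= 0, and out = in * alpha otherwise, with alpha broadcast according to sharedAxes.
+ * A minimal INDArray sketch (values illustrative; axis 0 is treated as the batch dimension):
+ * <pre>{@code
+ * INDArray in = Nd4j.createFromArray(new float[][]{{-2, 2}, {4, -4}});
+ * INDArray alpha = Nd4j.createFromArray(0.1f, 0.5f); // one alpha per step of axis 1
+ * INDArray out = Nd4j.exec(new PRelu(in, in.ulike(), alpha))[0]; // [[-0.2, 2.0], [4.0, -2.0]]
+ * }</pre>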
+ */
+@NoArgsConstructor
+public class PRelu extends DynamicCustomOp {
+
+ @Getter
+ protected int[] sharedAxes;
+
+ public PRelu(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable alpha, @NonNull int... sharedAxes) {
+ super(sameDiff, new SDVariable[]{x, alpha});
+ this.sharedAxes = sharedAxes;
+ addIArgument(sharedAxes);
+ }
+
+ public PRelu(@NonNull INDArray x, INDArray z, @NonNull INDArray alpha, @NonNull int... sharedAxes) {
+ super(new INDArray[]{x, alpha}, wrapFilterNull(z));
+ this.sharedAxes = sharedAxes;
+ addIArgument(sharedAxes);
+ }
+
+ @Override
+ public String opName() {
+ return "prelu";
+ }
+
+ @Override
+ public String onnxName() {
+ throw new NoOpNameFoundException("No onnx op opName found for " + opName());
+ }
+
+ @Override
+ public String tensorflowName() {
+ throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
+ }
+
+ @Override
+ public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+ Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+ Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+ return Collections.singletonList(dataTypes.get(0));
+ }
+
+ @Override
+ public List<SDVariable> doDiff(List<SDVariable> i_v) {
+ return Arrays.asList(f().preluBp(arg(0), arg(1), i_v.get(0), sharedAxes));
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/PReluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/PReluBp.java
new file mode 100644
index 000000000..703b44d61
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/PReluBp.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * PRelu backpropagation op: computes dL/dIn and dL/dAlpha from the input, alpha, and dL/dOut
+ */
+@NoArgsConstructor
+public class PReluBp extends DynamicCustomOp {
+
+ @Getter
+ protected int[] sharedAxes;
+
+ public PReluBp(@NonNull SameDiff sd, @NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull SDVariable gradient, int... sharedAxes){
+ super(sd, new SDVariable[]{input, alpha, gradient});
+ this.sharedAxes = sharedAxes;
+ addIArgument(sharedAxes);
+ }
+
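+ /**
+ * @param input Input activations
+ * @param alpha Alpha parameter array
+ * @param gradient Gradient at the output, dL/dOut
+ * @param dLdI Pre-allocated output array for dL/dInput; may be null
+ * @param dLdA Pre-allocated output array for dL/dAlpha; may be null
+ * @param sharedAxes Axes along which the alpha parameters are shared
+ */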
+ public PReluBp(@NonNull INDArray input, @NonNull INDArray alpha, @NonNull INDArray gradient, INDArray dLdI, INDArray dLdA, int... sharedAxes){
+ super(new INDArray[]{input, alpha, gradient}, wrapFilterNull(dLdI, dLdA));
+ this.sharedAxes = sharedAxes;
+ addIArgument(sharedAxes);
+ }
+
+ @Override
+ public String opName(){
+ return "prelu_bp";
+ }
+
+ @Override
+ public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+ Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatypes, got %s", dataTypes);
+ Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType() && dataTypes.get(2).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+ return Arrays.asList(dataTypes.get(0), dataTypes.get(1));
+ }
+
+ @Override
+ public List<SDVariable> doDiff(List<SDVariable> f1) {
+ throw new UnsupportedOperationException("Not supported");
+ }
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
index 8656cb46f..7d17b3604 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
@@ -3585,4 +3585,28 @@ public class SameDiffTests extends BaseNd4jTest {
assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0"));
}
}
+
+ @Test
+ public void testPReLU(){
+ SameDiff sd = SameDiff.create();
+
+ SDVariable input = sd.constant(Nd4j.createFromArray(
+ new int[][][]{{
+ {-10, 10, 10, -10},
+ {10, 10, -10, -10}
+ }}
+ ).castTo(DataType.DOUBLE));
+
+ SDVariable alpha = sd.var(Nd4j.createFromArray(0.01, 0.1).castTo(DataType.DOUBLE));
+
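+ // sharedAxes = 2: alpha values are shared along axis 2, so alpha has shape [2] (one value per step of axis 1)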
+ SDVariable out = sd.nn.prelu("out", input, alpha, 2);
+
+ TestCase tc = new TestCase(sd).expected("out", Nd4j.createFromArray(new double[][][]{{
+ {-0.1, 10, 10, -0.1},
+ {10, 10, -1, -1}
+ }}).castTo(DataType.DOUBLE)).gradientCheck(true);
+
+ String err = OpValidation.validate(tc);
+ assertNull(err);
+ }
}