SameDiff ops (#8247)

* update javadocs and a few method signatures

Signed-off-by: Ryan Nett <rnett@skymind.io>

* add PRelu op

Signed-off-by: Ryan Nett <rnett@skymind.io>

* test and fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* add PRelu op

Signed-off-by: Ryan Nett <rnett@skymind.io>

* test and fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* slightly better test

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Alex Black 2019-09-19 11:52:20 +10:00 committed by GitHub
parent 59f1cbf0c6
commit f98f8be7b6
10 changed files with 280 additions and 7 deletions

DifferentialFunctionFactory.java

@@ -215,6 +215,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
@@ -1628,6 +1629,13 @@ public class DifferentialFunctionFactory {
return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable();
}
public SDVariable prelu(SDVariable x, SDVariable alpha, int... sharedAxes){
return new PRelu(sameDiff(), x, alpha, sharedAxes).outputVariable();
}
public SDVariable[] preluBp(SDVariable in, SDVariable alpha, SDVariable epsilon, int... sharedAxes){
return new PReluBp(sameDiff(), in, alpha, epsilon, sharedAxes).outputVariables();
}
public SDVariable reshape(SDVariable iX, int[] shape) {
return new Reshape(sameDiff(), iX, ArrayUtil.toLongArray(shape)).outputVariable();

SDVariable.java

@@ -73,6 +73,8 @@ public class SDVariable implements Serializable {
@Getter
@Setter
protected WeightInitScheme weightInitScheme;
@Setter(AccessLevel.NONE)
protected long[] shape;
@Getter (AccessLevel.NONE)
@@ -237,6 +239,10 @@ public class SDVariable implements Serializable {
return initialShape;
}
public void setShape(long... shape){
this.shape = shape;
}
public long[] placeholderShape(){
if(variableType != VariableType.PLACEHOLDER){
throw new IllegalStateException("placeholderShape() can only be used for placeholder variables: variable \"" + getVarName()

SameDiff.java

@@ -3236,6 +3236,7 @@ public class SameDiff extends SDBaseOps {
/**
* See {@link #one(String, DataType, int...)}.
* Creates a VARIABLE type SDVariable.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable one(String name, int... shape) {
@@ -3244,6 +3245,7 @@ public class SameDiff extends SDBaseOps {
/**
* See {@link #one(String, DataType, long...)}.
* Creates a VARIABLE type SDVariable.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable one(String name, long... shape) {
@@ -3252,7 +3254,8 @@ public class SameDiff extends SDBaseOps {
/**
* Create a new variable with the specified shape, with all values initialized to 1.0.
* Creates a VARIABLE type SDVariable.
*
* @param name the name of the variable to create
* @param shape the shape of the array to be created
@@ -3263,7 +3266,8 @@ public class SameDiff extends SDBaseOps {
}
/**
* Create a new variable with the specified shape, with all values initialized to 1.0.
* Creates a VARIABLE type SDVariable.
*
* @param name the name of the variable to create
* @param shape the shape of the array to be created
@@ -3275,6 +3279,7 @@ public class SameDiff extends SDBaseOps {
/**
* See {@link #zero(String, DataType, long...)}.
* Creates a VARIABLE type SDVariable.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable zero(String name, long... shape) {
@@ -3283,6 +3288,7 @@ public class SameDiff extends SDBaseOps {
/**
* See {@link #zero(String, DataType, int...)}.
* Creates a VARIABLE type SDVariable.
* Uses the DataType of the Nd4j default floating point type ({@link Nd4j#defaultFloatingPointType()}).
*/
public SDVariable zero(String name, int... shape) {
@@ -3290,7 +3296,8 @@ public class SameDiff extends SDBaseOps {
}
/**
* Create a new variable with the specified shape, with all values initialized to 0.
* Creates a VARIABLE type SDVariable.
*
* @param name the name of the variable to create
* @param shape the shape of the array to be created
@@ -3301,7 +3308,8 @@ public class SameDiff extends SDBaseOps {
}
/**
* Create a new variable with the specified shape, with all values initialized to 0.
* Creates a VARIABLE type SDVariable.
*
* @param name the name of the variable to create
* @param shape the shape of the array to be created
@@ -3522,6 +3530,19 @@ public class SameDiff extends SDBaseOps {
return var(name, Nd4j.defaultFloatingPointType(), shape);
}
/**
* Variable initialization with a specified {@link WeightInitScheme}. Data type will be given by {@link Nd4j#defaultFloatingPointType()}<br>
* This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. See {@link VariableType} for more details.
*
* @param name the name of the variable
* @param shape the shape of the array to be created
* @param weightInitScheme the weight initialization scheme
* @return the created variable
*/
public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull long... shape) {
return var(name, weightInitScheme, Nd4j.defaultFloatingPointType(), shape);
}
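For context, the new var(name, WeightInitScheme, shape) overload might be used roughly as in the sketch below; XavierInitScheme, its import path, and the layer sizes are assumptions for illustration, not part of this change.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.weightinit.impl.XavierInitScheme;

    public class VarInitSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            int nIn = 784, nOut = 10;
            // Trainable VARIABLE, Xavier-initialized; data type defaults to Nd4j.defaultFloatingPointType()
            SDVariable w = sd.var("w", new XavierInitScheme('c', nIn, nOut), nIn, nOut);
            System.out.println(w.getVarName());
        }
    }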
/**
* Creates a {@link SDVariable} with the given shape and name<br>
* Any array will be generated with all zeros for the values<br>
@@ -5223,7 +5244,7 @@ public class SameDiff extends SDBaseOps {
* @param variableName the vertex id for the original shape
* @param shape the shape of the place holder
*/
public void setOriginalPlaceHolderShape(String variableName, @NonNull long... shape) {
if (!isPlaceHolder(variableName)) {
throw new ND4JIllegalStateException("Vertex id " + variableName + " does not appear to be a place holder. Did you forget to call addPlaceHolder?");
}

SDNN.java

@@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
@@ -490,6 +491,34 @@ public class SDNN extends SDOps {
return updateVariableNameAndReference(res, name);
}
/**
* See {@link #prelu(String, SDVariable, SDVariable, int...)}.
*/
public SDVariable prelu(@NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){
return f().prelu(input, alpha, sharedAxes);
}
/**
* PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:<br>
* out[i] = in[i] if in[i] >= 0<br>
* out[i] = in[i] * alpha[i] otherwise<br>
*
* sharedAxes allows you to share learnable parameters along axes.
* For example, if the input has shape [batchSize, channels, height, width]
* and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
* alpha with shape [channels].
*
* @param name Name of the output variable
* @param input Input data
* @param alpha The cutoff variable. Note that dimension 0 of the input (the batch dimension, whether or not it actually represents a batch) is never covered by alpha and should not be part of its shape.
* @param sharedAxes Which axes to share cutoff parameters along.
* @return Output variable
*/
public SDVariable prelu(String name, @NonNull SDVariable input, @NonNull SDVariable alpha, @NonNull int... sharedAxes){
SDVariable res = f().prelu(input, alpha, sharedAxes);
return updateVariableNameAndReference(res, name);
}
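To illustrate the sharedAxes semantics described in the javadoc above, a SameDiff-level call might look like the sketch below; the NCHW placeholder shape and the alpha values are illustrative assumptions, not taken from this change.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    public class PReluUsageSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // NCHW input: [minibatch, channels, height, width]
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3, 8, 8);
            // One learnable cutoff per channel; height and width (axes 2 and 3) share it
            SDVariable alpha = sd.var("alpha", Nd4j.createFromArray(0.01f, 0.01f, 0.01f));
            SDVariable out = sd.nn.prelu("out", in, alpha, 2, 3);
        }
    }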
/**
* Element-wise SeLU function - Scaled exponential Linear Unit: see <a href="https://arxiv.org/abs/1706.02515">Self-Normalizing Neural Networks</a>
* <br>
@@ -568,7 +597,7 @@ public class SDNN extends SDOps {
}
/**
* Softmax activation on dimension 1.
*
* @param x Input variable
* @return Output variable
@@ -578,7 +607,7 @@ public class SDNN extends SDOps {
}
/**
* Softmax activation on dimension 1.
*
* @param x Input variable
* @return Output variable

OpValidation.java

@@ -894,6 +894,7 @@ public class OpValidation {
RationalTanhDerivative.class,
RectifiedTanhDerivative.class,
Relu6Derivative.class,
PReluBp.class,
SELUDerivative.class,
SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class,

ImportClassMapping.java

@@ -231,6 +231,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class,
org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
org.nd4j.linalg.api.ops.impl.scalar.PRelu.class,
org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision.class,
@@ -434,6 +435,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class,

ResizeBilinear.java

@@ -16,11 +16,15 @@
package org.nd4j.linalg.api.ops.impl.image;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
@@ -36,8 +40,27 @@ import java.util.Map;
* ResizeBilinear op wrapper
* @author raver119@gmail.com
*/
@NoArgsConstructor
public class ResizeBilinear extends DynamicCustomOp {
protected boolean alignCorners = false;
protected Integer height = null;
protected Integer width = null;
public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, boolean alignCorners){
super(sd, input);
this.alignCorners = alignCorners;
this.height = height;
this.width = width;
addArgs();
}
public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, boolean alignCorners){
super(new INDArray[]{x}, new INDArray[]{z});
this.alignCorners = alignCorners;
this.height = height;
this.width = width;
addArgs();
}
@Override
public String opName() {
@@ -60,13 +83,20 @@ public class ResizeBilinear extends DynamicCustomOp {
protected void addArgs() {
// to be implemented
iArguments.clear();
if(height != null && width != null){
iArguments.add(Long.valueOf(height));
iArguments.add(Long.valueOf(width));
}
iArguments.add(alignCorners ? 1L : 0L);
}
@Override
public Map<String, Object> propertiesForFunction() {
Map<String,Object> ret = new LinkedHashMap<>();
ret.put("alignCorners", alignCorners);
ret.put("height", height);
ret.put("width", width);
return ret;
}
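For reference, the new array-level constructor could be exercised roughly as in the sketch below; the NHWC layout, the 4x4 to 8x8 sizes, and the executioner call are assumptions for illustration, not part of this change.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
    import org.nd4j.linalg.factory.Nd4j;

    public class ResizeBilinearSketch {
        public static void main(String[] args) {
            // NHWC input resized from 4x4 to 8x8, writing into a preallocated output
            INDArray x = Nd4j.rand(DataType.FLOAT, 1, 4, 4, 3);
            INDArray z = Nd4j.create(DataType.FLOAT, 1, 8, 8, 3);
            Nd4j.getExecutioner().exec(new ResizeBilinear(x, z, 8, 8, false));
        }
    }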

PRelu.java (new file)

@@ -0,0 +1,81 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.scalar;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Parameterized ReLU op
*/
@NoArgsConstructor
public class PRelu extends DynamicCustomOp {
@Getter
protected int[] sharedAxes;
public PRelu(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable alpha, @NonNull int... sharedAxes) {
super(sameDiff, new SDVariable[]{x, alpha});
this.sharedAxes = sharedAxes;
addIArgument(sharedAxes);
}
public PRelu(@NonNull INDArray x, INDArray z, @NonNull INDArray alpha, @NonNull int... sharedAxes) {
super(new INDArray[]{x, alpha}, new INDArray[]{z});
this.sharedAxes = sharedAxes;
addIArgument(sharedAxes);
}
@Override
public String opName() {
return "prelu";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
return Arrays.asList(f().preluBp(arg(0), arg(1), i_v.get(0), sharedAxes));
}
}
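A minimal array-level sketch of running this op directly, reusing the values from the new SameDiffTests case below; the preallocated output buffer and the executioner call are illustrative assumptions.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.scalar.PRelu;
    import org.nd4j.linalg.factory.Nd4j;

    public class PReluArraySketch {
        public static void main(String[] args) {
            // Shape [1, 2, 4]; alpha has one value per element of dimension 1, axis 2 is shared
            INDArray x = Nd4j.createFromArray(new double[][][]{{{-10, 10, 10, -10}, {10, 10, -10, -10}}});
            INDArray alpha = Nd4j.createFromArray(0.01, 0.1);
            INDArray z = Nd4j.create(DataType.DOUBLE, 1, 2, 4);
            Nd4j.getExecutioner().exec(new PRelu(x, z, alpha, 2));
            // Expected: [[[-0.1, 10, 10, -0.1], [10, 10, -1, -1]]]
            System.out.println(z);
        }
    }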

PReluBp.java (new file)

@@ -0,0 +1,71 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* PRelu backpropagation op - computes dL/dIn and dL/dAlpha from the input, alpha and dL/dOut
*/
@NoArgsConstructor
public class PReluBp extends DynamicCustomOp {
@Getter
protected int[] sharedAxes;
public PReluBp(SameDiff sd, SDVariable input, SDVariable alpha, SDVariable gradient, int... sharedAxes){
super(sd, new SDVariable[]{input, alpha, gradient});
this.sharedAxes = sharedAxes;
addIArgument(sharedAxes);
}
public PReluBp(@NonNull INDArray input, @NonNull INDArray alpha, @NonNull INDArray gradient, INDArray dLdI, INDArray dLdA, int... sharedAxes){
super(new INDArray[]{input, alpha, gradient}, wrapFilterNull(dLdI, dLdA));
this.sharedAxes = sharedAxes;
addIArgument(sharedAxes);
}
@Override
public String opName(){
return "prelu_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType() && dataTypes.get(2).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Arrays.asList(dataTypes.get(0), dataTypes.get(1));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
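Similarly, the array-level PReluBp constructor might be driven as in the sketch below; the all-ones dL/dOut and the preallocated gradient buffers are illustrative assumptions, not taken from this change.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp;
    import org.nd4j.linalg.factory.Nd4j;

    public class PReluBpSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.createFromArray(new double[][][]{{{-10, 10, 10, -10}, {10, 10, -10, -10}}});
            INDArray alpha = Nd4j.createFromArray(0.01, 0.1);
            INDArray dLdOut = Nd4j.create(DataType.DOUBLE, 1, 2, 4).assign(1.0);
            // Outputs: gradient w.r.t. the input and w.r.t. alpha
            INDArray dLdIn = Nd4j.create(DataType.DOUBLE, 1, 2, 4);
            INDArray dLdAlpha = Nd4j.create(DataType.DOUBLE, 2);
            Nd4j.getExecutioner().exec(new PReluBp(in, alpha, dLdOut, dLdIn, dLdAlpha, 2));
        }
    }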

SameDiffTests.java

@@ -3585,4 +3585,28 @@ public class SameDiffTests extends BaseNd4jTest {
assertTrue(m, m.contains("variable") && m.contains("empty") && m.contains("0"));
}
}
@Test
public void testPReLU(){
SameDiff sd = SameDiff.create();
SDVariable input = sd.constant(Nd4j.createFromArray(
new int[][][]{{
{-10, 10, 10, -10},
{10, 10, -10, -10}
}}
).castTo(DataType.DOUBLE));
SDVariable alpha = sd.var(Nd4j.createFromArray(0.01, 0.1).castTo(DataType.DOUBLE));
SDVariable out = sd.nn.prelu("out", input, alpha, 2);
TestCase tc = new TestCase(sd).expected("out", Nd4j.createFromArray(new double[][][]{{
{-0.1, 10, 10, -0.1},
{10, 10, -1, -1}
}}).castTo(DataType.DOUBLE)).gradientCheck(true);
String err = OpValidation.validate(tc);
assertNull(err);
}
}