parent e42c34ca55
commit 82c9dc5743
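Taken together, the hunks below delete the deprecated ELUDerivative transform op and route all ELU backpropagation through EluBp, which takes the upstream gradient as an explicit second input and, after this change, also accepts a configurable alpha. For reference (alpha = 1): ELU computes f(x) = x for x >= 0 and e^x - 1 otherwise, so dOut/dIn is 1 for x >= 0 and e^x otherwise. A minimal scalar sketch of the fused backprop computation the op performs element-wise (illustrative only, not the native kernel):

    // Chain rule for ELU backprop: dL/dIn = dL/dOut * dOut/dIn.
    static double eluBp(double x, double dLdOut, double alpha) {
        double dOutdIn = x >= 0 ? 1.0 : alpha * Math.exp(x); // derivative of ELU
        return dLdOut * dOutdIn;
    }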
@@ -207,7 +207,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
@@ -1567,14 +1566,6 @@ public class DifferentialFunctionFactory {
         return new EluBp(sameDiff(), in, epsilon).outputVariable();
     }
-
-    /**
-     * @deprecated Use {@link #eluBp(SDVariable, SDVariable)}
-     */
-    @Deprecated
-    public SDVariable eluDerivative(SDVariable iX) {
-        return new ELUDerivative(sameDiff(), iX, false).outputVariable();
-    }


     public SDVariable leakyRelu(SDVariable iX, double alpha) {
         return new LeakyReLU(sameDiff(), iX, false, alpha).outputVariable();
@@ -163,31 +163,6 @@ public class SDNN extends SDOps {
         return updateVariableNameAndReference(result, name);
     }
-
-    /**
-     * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
-     * {@link #elu(SDVariable)}
-     *
-     * @param x Input variable
-     * @return Output variable
-     */
-    public SDVariable eluDerivative(SDVariable x) {
-        return eluDerivative(null, x);
-    }
-
-    /**
-     * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
-     * {@link #elu(SDVariable)}
-     *
-     * @param name Output variable name
-     * @param x Input variable
-     * @return Output variable
-     */
-    public SDVariable eluDerivative(String name, SDVariable x) {
-        validateFloatingPoint("eluDerivative", x);
-        SDVariable result = f().eluDerivative(x);
-        return updateVariableNameAndReference(result, name);
-    }

     /**
      * GELU activation function - Gaussian Error Linear Units<br>
      * For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a>
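With both SDNN#eluDerivative overloads removed, the derivative is no longer a user-facing forward op in the SameDiff API; gradients for elu are produced internally through DifferentialFunctionFactory#eluBp during differentiation. A hedged sketch of the remaining user-facing path (the sd.nn() accessor and the variable setup are assumptions, not shown in this diff):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.linspace(-3, 3, 7));
    SDVariable out = sd.nn().elu(in); // forward only; EluBp is inserted automatically on backprop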
@@ -255,8 +255,6 @@ public class LegacyOpMapper {
                 return Abs.class;
             case 2:
                 return LogSoftMax.class;
-            case 3:
-                return ELUDerivative.class;
             case 4:
                 return org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class;
             case 5:
@@ -881,7 +881,6 @@ public class OpValidation {
                 SoftmaxBp.class,

                 CubeDerivative.class,
-                ELUDerivative.class,
                 GELUDerivative.class,
                 PreciseGELUDerivative.class,
                 HardSigmoidDerivative.class,
@@ -422,7 +422,6 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp.class,
-            org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative.class,
@@ -23,7 +23,6 @@ import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.BooleanIndexing;
 import org.nd4j.linalg.indexing.conditions.Conditions;
@@ -75,20 +74,8 @@ public class ActivationELU extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        // no support in ELU native to override alpha
-        if (alpha != 1.00) {
-            INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in.dup()));
-            dLdz.muli(alpha);
-            BooleanIndexing.replaceWhere(dLdz, 1, Conditions.equals(alpha));
-
-            dLdz.muli(epsilon);
-            return new Pair<>(dLdz, null);
-        }
-
-        else {
-            Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
-            return new Pair<>(in, null);
-        }
+        Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
+        return new Pair<>(in, null);
     }

     @Override
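The rewritten ActivationELU.backprop collapses the old derivative/scale/mask sequence into a single in-place EluBp execution. Note it uses the three-argument EluBp constructor, which (per the EluBp hunk below) delegates with alpha = 1.0, matching the old fast path. A usage sketch under those assumptions:

    ActivationELU act = new ActivationELU();              // default alpha = 1.0
    INDArray preOut = Nd4j.create(new double[]{-2.0, -0.5, 0.0, 1.5});
    INDArray epsilon = Nd4j.ones(4);                      // dL/dOut from the layer above
    Pair<INDArray, INDArray> grads = act.backprop(preOut, epsilon);
    INDArray dLdz = grads.getFirst();                     // epsilon * dOut/dIn, written into preOut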
@@ -1,87 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.gradient;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-/**
- *
- * Derivative of ELU: Exponential Linear Unit (alpha=1.0)<br>
- * Introduced in paper:<br>
- * Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)<br>
- * Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter (2015)<br>
- * <a href="http://arxiv.org/abs/1511.07289">http://arxiv.org/abs/1511.07289</a>
- *
- * @deprecated Use {@link EluBp}
- *
- * @author Alex Black
- */
-@Deprecated
-public class ELUDerivative extends BaseTransformStrictOp {
-    public ELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public ELUDerivative() {
-
-    }
-
-    public ELUDerivative(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public ELUDerivative(INDArray x) {
-        super(x);
-    }
-
-    @Override
-    public int opNum() {
-        return 3;
-    }
-
-    @Override
-    public String opName() {
-        return "eluderivative";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = sameDiff.zerosLike(arg());
-        return Collections.singletonList(ret);
-    }
-}
@@ -37,8 +37,13 @@ public class EluBp extends DynamicCustomOp {
         super(sd, new SDVariable[]{input, gradient});
     }

-    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
+        this(input, gradient, output, 1.0);
+    }
+
+    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){
         super(new INDArray[]{input, gradient}, wrapOrNull(output));
+        addTArgument(alpha);
     }

     @Override
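The three-argument EluBp constructor now chains to a new four-argument overload that records alpha as a t-argument, so a non-default alpha reaches the native kernel. Direct use might look like this (a sketch; shapes and values are arbitrary):

    INDArray in = Nd4j.linspace(-3, 3, 7);
    INDArray grad = Nd4j.ones(7);
    INDArray out = in.ulike();                            // output buffer, same shape as the input
    Nd4j.getExecutioner().exec(new EluBp(in, grad, out, 2.0)); // alpha = 2.0 via addTArgument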
@@ -71,11 +71,6 @@ public class ELU extends DynamicCustomOp {
         return "Elu";
     }

-    @Override
-    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
-    }
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         //ELU: e^x-1 if x<0, x otherwise
@@ -37,7 +37,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.floating.*;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
@@ -441,13 +441,13 @@ public class Transforms {
         return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0];
     }

-    public static INDArray eluDerivative(INDArray arr) {
-        return eluDerivative(arr, true);
+    public static INDArray eluDerivative(INDArray arr, INDArray grad) {
+        return eluDerivative(arr, grad, true);
     }


-    public static INDArray eluDerivative(INDArray in, boolean copy) {
-        return Nd4j.getExecutioner().exec(new ELUDerivative(in, (copy ? in.ulike() : in)));
+    public static INDArray eluDerivative(INDArray in, INDArray grad, boolean copy) {
+        return Nd4j.getExecutioner().exec(new EluBp(in, grad, (copy ? in.ulike() : in)))[0];
     }

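Transforms.eluDerivative now requires the upstream gradient and executes EluBp in place of the removed op; callers that previously passed only the input must supply it explicitly, and an all-ones gradient reproduces the bare derivative. Sketch:

    INDArray x = Nd4j.create(new double[]{-2.0, -0.5, 0.0, 1.5});
    INDArray grad = Nd4j.ones(4);                         // ones => result is plain dOut/dIn
    INDArray dOutdIn = Transforms.eluDerivative(x, grad); // copy variant; x is left untouched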
@@ -52,7 +52,8 @@ public class TFGraphTestList {
     public TemporaryFolder testDir = new TemporaryFolder();

     public static String[] modelNames = new String[]{
-            "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
+            // "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
+            "cnn2d_layers/channels_last_b1_k2_s1_d1_SAME_elu"
     };

     @After
@@ -305,44 +305,6 @@ public class DerivativeTests extends BaseNd4jTest {
         }
     }
-
-    @Test
-    public void testELUDerivative() {
-
-        //f(x) = x if x>=0
-        //f(x) = 1.0*(exp(x)-1)
-        INDArray z = Nd4j.zeros(100);
-        double[] out = new double[100];
-        double[] outDeriv = new double[100];
-        for (int i = 0; i < 100; i++) {
-            double x = 0.1 * (i - 50);
-            z.putScalar(i, x);
-            if (x >= 0) {
-                out[i] = x;
-                outDeriv[i] = 1.0;
-            } else {
-                out[i] = FastMath.exp(x) - 1.0;
-                outDeriv[i] = FastMath.exp(x);
-            }
-        }
-
-        INDArray act = Transforms.elu(z, true);
-        INDArray actDeriv = Nd4j.getExecutioner().exec(new ELUDerivative(z.dup()));
-
-        System.out.println(act);
-
-        for (int i = 0; i < 100; i++) {
-            double relError1 = Math.abs(out[i] - act.getDouble(i)) / (Math.abs(out[i]) + Math.abs(act.getDouble(i)));
-            if (out[i] == 0.0 && act.getDouble(i) == 0.0)
-                relError1 = 0.0;
-            double relError2 = Math.abs(outDeriv[i] - actDeriv.getDouble(i))
-                            / (Math.abs(outDeriv[i]) + Math.abs(actDeriv.getDouble(i)));
-            if (outDeriv[i] == 0.0 && actDeriv.getDouble(i) == 0.0)
-                relError2 = 0.0;
-            assertTrue(relError1 < REL_ERROR_TOLERANCE);
-            assertTrue(relError2 < REL_ERROR_TOLERANCE);
-        }
-    }

     @Override
     public char ordering() {
         return 'f';
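The removed testELUDerivative validated the old standalone op against the closed-form derivative. An equivalent check against EluBp, had one been added, might look like this sketch (hypothetical replacement, not part of this commit; assumes JUnit's assertEquals):

    INDArray z = Nd4j.linspace(-5, 4, 100);               // sample across the kink at x = 0
    INDArray grad = Nd4j.ones(100);
    INDArray actDeriv = Nd4j.getExecutioner().exec(new EluBp(z, grad, z.ulike()))[0];
    for (int i = 0; i < 100; i++) {
        double x = z.getDouble(i);
        double expected = x >= 0 ? 1.0 : Math.exp(x);     // alpha = 1 ELU derivative
        assertEquals(expected, actDeriv.getDouble(i), 1e-5);
    }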