Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-09-02 18:37:05 +10:00 committed by GitHub
parent e42c34ca55
commit 82c9dc5743
13 changed files with 16 additions and 191 deletions

View File

@@ -207,7 +207,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
@@ -1567,14 +1566,6 @@ public class DifferentialFunctionFactory {
         return new EluBp(sameDiff(), in, epsilon).outputVariable();
     }
-    /**
-     * @deprecated Use {@link #eluBp(SDVariable, SDVariable)}
-     */
-    @Deprecated
-    public SDVariable eluDerivative(SDVariable iX) {
-        return new ELUDerivative(sameDiff(), iX, false).outputVariable();
-    }
     public SDVariable leakyRelu(SDVariable iX, double alpha) {
         return new LeakyReLU(sameDiff(), iX, false, alpha).outputVariable();
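Note: the deprecated eluDerivative factory method is dropped in favour of the eluBp method its javadoc already pointed to. A minimal migration sketch for internal callers (variable names are illustrative, not taken from this commit):

    // Before: the legacy op produced only dOut/dIn, so the chain rule was applied by hand
    SDVariable dOutdIn = f().eluDerivative(in);
    SDVariable dLdInOld = dOutdIn.mul(epsilon);

    // After: the fused backprop op consumes the upstream gradient (epsilon) directly
    SDVariable dLdIn = f().eluBp(in, epsilon);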

View File

@@ -163,31 +163,6 @@ public class SDNN extends SDOps {
         return updateVariableNameAndReference(result, name);
     }
-    /**
-     * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
-     * {@link #elu(SDVariable)}
-     *
-     * @param x Input variable
-     * @return Output variable
-     */
-    public SDVariable eluDerivative(SDVariable x) {
-        return eluDerivative(null, x);
-    }
-    /**
-     * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
-     * {@link #elu(SDVariable)}
-     *
-     * @param name Output variable name
-     * @param x Input variable
-     * @return Output variable
-     */
-    public SDVariable eluDerivative(String name, SDVariable x) {
-        validateFloatingPoint("eluDerivative", x);
-        SDVariable result = f().eluDerivative(x);
-        return updateVariableNameAndReference(result, name);
-    }
     /**
      * GELU activation function - Gaussian Error Linear Units<br>
      * For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a>

View File

@@ -255,8 +255,6 @@ public class LegacyOpMapper {
                 return Abs.class;
             case 2:
                 return LogSoftMax.class;
-            case 3:
-                return ELUDerivative.class;
             case 4:
                 return org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class;
             case 5:

View File

@@ -881,7 +881,6 @@ public class OpValidation {
                 SoftmaxBp.class,
                 CubeDerivative.class,
-                ELUDerivative.class,
                 GELUDerivative.class,
                 PreciseGELUDerivative.class,
                 HardSigmoidDerivative.class,

View File

@@ -422,7 +422,6 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp.class,
-            org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative.class,

View File

@@ -23,7 +23,6 @@ import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.indexing.BooleanIndexing;
 import org.nd4j.linalg.indexing.conditions.Conditions;
@@ -75,20 +74,8 @@ public class ActivationELU extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        // no support in ELU native to override alpha
-        if (alpha != 1.00) {
-            INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in.dup()));
-            dLdz.muli(alpha);
-            BooleanIndexing.replaceWhere(dLdz, 1, Conditions.equals(alpha));
-            dLdz.muli(epsilon);
-            return new Pair<>(dLdz, null);
-        }
-        else {
-            Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
-            return new Pair<>(in, null);
-        }
+        Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
+        return new Pair<>(in, null);
     }
     @Override
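With the alpha-specific Java fallback gone, backprop always goes through the native EluBp kernel. The commit also adds an EluBp constructor that takes alpha (see the EluBp.java hunk below), so a non-default alpha could be forwarded to the kernel rather than patched up on the Java side. A hedged sketch of that variant, not code from this commit:

    // Hypothetical alternative using the new EluBp(INDArray, INDArray, INDArray, double)
    // overload introduced in this commit; the committed code calls the three-argument
    // form, which defaults alpha to 1.0.
    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        assertShape(in, epsilon);
        Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in, alpha));
        return new Pair<>(in, null);
    }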

View File

@@ -1,87 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
*
* Derivative of ELU: Exponential Linear Unit (alpha=1.0)<br>
* Introduced in paper:<br>
* Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)<br>
* Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter (2015)<br>
* <a href="http://arxiv.org/abs/1511.07289">http://arxiv.org/abs/1511.07289</a>
*
* @deprecated Use {@link EluBp}
*
* @author Alex Black
*/
@Deprecated
public class ELUDerivative extends BaseTransformStrictOp {
public ELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public ELUDerivative() {
}
public ELUDerivative(INDArray x, INDArray z) {
super(x, z);
}
public ELUDerivative(INDArray x) {
super(x);
}
@Override
public int opNum() {
return 3;
}
@Override
public String opName() {
return "eluderivative";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = sameDiff.zerosLike(arg());
return Collections.singletonList(ret);
}
}
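For reference, the deleted op evaluated the elementwise ELU derivative with alpha fixed at 1.0. A minimal sketch of the function it computed (the same formulas appear in the DerivativeTests case removed later in this commit):

    // ELU (alpha = 1.0): f(x) = x for x >= 0, f(x) = exp(x) - 1 for x < 0
    // Its derivative, which ELUDerivative computed elementwise:
    static double eluDerivative(double x) {
        return x >= 0 ? 1.0 : Math.exp(x);   // d/dx (exp(x) - 1) = exp(x)
    }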

View File

@@ -37,8 +37,13 @@ public class EluBp extends DynamicCustomOp {
         super(sd, new SDVariable[]{input, gradient});
     }
-    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
+        this(input, gradient, output, 1.0);
+    }
+
+    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){
         super(new INDArray[]{input, gradient}, wrapOrNull(output));
+        addTArgument(alpha);
     }
     @Override
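A short usage sketch of the two overloads (in and grad are placeholder arrays, not names from this commit); as shown above, the three-argument form delegates with alpha = 1.0:

    // Default alpha (1.0):
    INDArray dLdIn = Nd4j.getExecutioner().exec(new EluBp(in, grad, in.ulike()))[0];
    // Non-default alpha, passed to the native kernel as a T argument:
    INDArray dLdInAlpha = Nd4j.getExecutioner().exec(new EluBp(in, grad, in.ulike(), 0.5))[0];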

View File

@@ -71,11 +71,6 @@ public class ELU extends DynamicCustomOp {
         return "Elu";
     }
-    @Override
-    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
-    }
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         //ELU: e^x-1 if x<0, x otherwise

View File

@@ -37,7 +37,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.floating.*;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
@@ -441,13 +441,13 @@ public class Transforms {
         return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0];
     }
-    public static INDArray eluDerivative(INDArray arr) {
-        return eluDerivative(arr, true);
-    }
-    public static INDArray eluDerivative(INDArray in, boolean copy) {
-        return Nd4j.getExecutioner().exec(new ELUDerivative(in, (copy ? in.ulike() : in)));
-    }
+    public static INDArray eluDerivative(INDArray arr, INDArray grad) {
+        return eluDerivative(arr, grad,true);
+    }
+    public static INDArray eluDerivative(INDArray in, INDArray grad, boolean copy) {
+        return Nd4j.getExecutioner().exec(new EluBp(in, grad, (copy ? in.ulike() : in)))[0];
+    }
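The public Transforms helper now requires the upstream gradient explicitly. A hedged usage sketch (x is a placeholder input, not a name from this commit); passing an all-ones gradient recovers the old dOut/dIn behaviour for alpha = 1.0:

    INDArray x = Nd4j.linspace(-5, 5, 100);
    // dL/dIn = dL/dOut * dOut/dIn; with dL/dOut = 1 elementwise this is just dOut/dIn,
    // i.e. what the removed single-argument overload used to return.
    INDArray dOutdIn = Transforms.eluDerivative(x, Nd4j.onesLike(x));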

View File

@@ -12859,7 +12859,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
     /**
      * This is Concatenated RELU implementation.
      * What happens inside: RELU(Concat((x, -x, {-1})))
-     * 
+     *
      * PLEASE NOTE: Concatenation will double amount of features available in input
      */
     // #if NOT_EXCLUDED(OP_crelu)

View File

@@ -52,7 +52,8 @@ public class TFGraphTestList {
     public TemporaryFolder testDir = new TemporaryFolder();
     public static String[] modelNames = new String[]{
-            "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
+            // "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
+            "cnn2d_layers/channels_last_b1_k2_s1_d1_SAME_elu"
     };
     @After

View File

@@ -305,44 +305,6 @@ public class DerivativeTests extends BaseNd4jTest {
         }
     }
-    @Test
-    public void testELUDerivative() {
-        //f(x) = x if x>=0
-        //f(x) = 1.0*(exp(x)-1)
-        INDArray z = Nd4j.zeros(100);
-        double[] out = new double[100];
-        double[] outDeriv = new double[100];
-        for (int i = 0; i < 100; i++) {
-            double x = 0.1 * (i - 50);
-            z.putScalar(i, x);
-            if (x >= 0) {
-                out[i] = x;
-                outDeriv[i] = 1.0;
-            } else {
-                out[i] = FastMath.exp(x) - 1.0;
-                outDeriv[i] = FastMath.exp(x);
-            }
-        }
-        INDArray act = Transforms.elu(z, true);
-        INDArray actDeriv = Nd4j.getExecutioner().exec(new ELUDerivative(z.dup()));
-        System.out.println(act);
-        for (int i = 0; i < 100; i++) {
-            double relError1 = Math.abs(out[i] - act.getDouble(i)) / (Math.abs(out[i]) + Math.abs(act.getDouble(i)));
-            if (out[i] == 0.0 && act.getDouble(i) == 0.0)
-                relError1 = 0.0;
-            double relError2 = Math.abs(outDeriv[i] - actDeriv.getDouble(i))
-                    / (Math.abs(outDeriv[i]) + Math.abs(actDeriv.getDouble(i)));
-            if (outDeriv[i] == 0.0 && actDeriv.getDouble(i) == 0.0)
-                relError2 = 0.0;
-            assertTrue(relError1 < REL_ERROR_TOLERANCE);
-            assertTrue(relError2 < REL_ERROR_TOLERANCE);
-        }
-    }
     @Override
     public char ordering() {
         return 'f';
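The removed test exercised the deleted ELUDerivative op directly. A hedged sketch of an equivalent check written against EluBp (not part of this commit; the test name and tolerance are illustrative):

    @Test
    public void testEluBpMatchesAnalyticDerivative() {
        // ELU derivative for alpha = 1.0: f'(x) = 1 if x >= 0, exp(x) otherwise.
        // With an all-ones upstream gradient, EluBp should return exactly dOut/dIn.
        INDArray z = Nd4j.zeros(100);
        INDArray expected = Nd4j.zeros(100);
        for (int i = 0; i < 100; i++) {
            double x = 0.1 * (i - 50);
            z.putScalar(i, x);
            expected.putScalar(i, x >= 0 ? 1.0 : FastMath.exp(x));
        }
        INDArray actual = Nd4j.getExecutioner().exec(new EluBp(z, Nd4j.onesLike(z), z.ulike()))[0];
        assertTrue(expected.equalsWithEps(actual, 1e-4));
    }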