DL4J SameDiff loss function (#251)
* Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6 * SameDiffLoss draft * very very draft * Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6 * temporary commit for clarification Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com> * temporary commit for clarification v2 Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com> * temporary commit for clarification v3 Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com> * temporary commit for clarification v3 Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com> * Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6 * very very draft * temporary commit for clarification Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com> * temporary commit for clarification v2 Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com> * temporary commit for clarification v3 Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com> * temporary commit for clarification v3 Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com> * SDLoss after requested changes but with questions in comments Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * added requested changes * small fixes Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Javadoc Signed-off-by: Alex Black <blacka101@gmail.com> * Test fix Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Andrii Tuzhykov <andrew@unrealists.com> Co-authored-by: atuzhykov <andrewtuzhukov@gmail.com> Co-authored-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>master
parent
5fbb04531d
commit
18d4eaa68d
|
@ -19,6 +19,8 @@ package org.deeplearning4j.gradientcheck;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.TestUtils;
|
||||
import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMAE;
|
||||
import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMSE;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
|
@ -83,7 +85,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
|||
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
|
||||
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
|
||||
new LossMultiLabel(), new LossWasserstein(),
|
||||
new LossSparseMCXENT()
|
||||
new LossSparseMCXENT(),
|
||||
new SDLossMAE(), new SDLossMSE()
|
||||
};
|
||||
|
||||
Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
|
||||
|
@ -119,6 +122,12 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
|||
Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
|
||||
Activation.IDENTITY, // Wasserstein
|
||||
Activation.SOFTMAX, //sparse MCXENT
|
||||
Activation.SOFTMAX, // SDLossMAE
|
||||
Activation.SIGMOID, // SDLossMAE
|
||||
Activation.TANH, // SDLossMAE
|
||||
Activation.SOFTMAX, // SDLossMSE
|
||||
Activation.SIGMOID, // SDLossMSE
|
||||
Activation.TANH //SDLossMSE
|
||||
};
|
||||
|
||||
int[] nOut = new int[] {1, //xent
|
||||
|
@ -154,6 +163,12 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
|||
10, // MultiLabel
|
||||
2, // Wasserstein
|
||||
4, //sparse MCXENT
|
||||
3, // SDLossMAE
|
||||
3, // SDLossMAE
|
||||
3, // SDLossMAE
|
||||
3, // SDLossMSE
|
||||
3, // SDLossMSE
|
||||
3, // SDLossMSE
|
||||
};
|
||||
|
||||
int[] minibatchSizes = new int[] {1, 3};
|
||||
|
@ -520,6 +535,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
|||
break;
|
||||
case "LossMAE":
|
||||
case "LossMSE":
|
||||
case "SDLossMAE":
|
||||
case "SDLossMSE":
|
||||
case "LossL1":
|
||||
case "LossL2":
|
||||
ret[1] = Nd4j.rand(labelsShape).muli(2).subi(1);
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.gradientcheck.sdlosscustom;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.lossfunctions.SameDiffLoss;
|
||||
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class SDLossMAE extends SameDiffLoss {
|
||||
|
||||
@Override
|
||||
public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
|
||||
return sd.math.abs(labels.sub(layerInput)).mean(1);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.gradientcheck.sdlosscustom;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.lossfunctions.*;
|
||||
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
public class SDLossMSE extends SameDiffLoss {
|
||||
|
||||
@Override
|
||||
public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
|
||||
return labels.squaredDifference(layerInput).mean(1);
|
||||
}
|
||||
}
|
|
@ -111,7 +111,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
|||
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
|
||||
assertEquals(dtype, net.params().dataType());
|
||||
boolean eq = outExp.equalsWithEps(outAct, 0.01);
|
||||
assertTrue(outExp + " vs " + outAct, eq);
|
||||
assertTrue("Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct, eq);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
|
@ -348,7 +349,9 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
|
|||
if (dimensions == null || dimensions.length == 0)
|
||||
dimensions = new int[]{Integer.MAX_VALUE};
|
||||
|
||||
this.dimensionz = Shape.ndArrayDimFromInt(dimensions);
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
this.dimensionz = Shape.ndArrayDimFromInt(dimensions);
|
||||
}
|
||||
}
|
||||
|
||||
public INDArray dimensions() {
|
||||
|
|
|
@ -0,0 +1,186 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.lossfunctions;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* SameDiff loss function.
|
||||
*
|
||||
* This class can be extended to create Deeplearning4j loss functions by defining one single method only:
|
||||
* {@link #defineLoss(SameDiff, SDVariable, SDVariable)}. This method is used to define the loss function on a
|
||||
* <i>per example</i> basis - i.e., the output should be an array with shape [minibatch].<br>
|
||||
* <br>
|
||||
* For example, the mean squared error (MSE) loss function can be defined using:<br>
|
||||
* {@code return labels.squaredDifference(layerInput).mean(1);}
|
||||
*
|
||||
*/
|
||||
public abstract class SameDiffLoss implements ILossFunction {
|
||||
protected transient SameDiff sd;
|
||||
protected transient SDVariable scoreVariable;
|
||||
|
||||
protected SameDiffLoss() {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Define the loss function.<br>
|
||||
* <b>NOTE</b>: The score on a *per example* basis - should return a SDVariable with shape [minibatch], where out[i]
|
||||
* is the score for the ith minibatch
|
||||
*
|
||||
* @param sd SameDiff instance to define the loss on
|
||||
* @param layerInput Input to the SameDiff loss function
|
||||
* @param labels Labels placeholder
|
||||
* @return The score on a per example basis (SDVariable with shape [minibatch])
|
||||
*/
|
||||
public abstract SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels);
|
||||
|
||||
protected void createSameDiffInstance(DataType dataType){
|
||||
sd = SameDiff.create();
|
||||
SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1);
|
||||
SDVariable labels = sd.placeHolder("labels", dataType, -1);
|
||||
scoreVariable = this.defineLoss(sd, layerInput, labels);
|
||||
sd.createGradFunction("layerInput");
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the score (loss function value) for the given inputs.
|
||||
*
|
||||
* @param labels Label/expected preOutput
|
||||
* @param preOutput Output of the model (neural network)
|
||||
* @param activationFn Activation function that should be applied to preOutput
|
||||
* @param mask Mask array; may be null
|
||||
* @param average Whether the score should be averaged (divided by number of rows in labels/preOutput) or not @return Loss function value
|
||||
*/
|
||||
@Override
|
||||
public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
|
||||
if(sd == null){
|
||||
createSameDiffInstance(preOutput.dataType());
|
||||
}
|
||||
|
||||
INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask);
|
||||
|
||||
double score = scoreArr.sumNumber().doubleValue();
|
||||
if (average) {
|
||||
score /= scoreArr.size(0);
|
||||
}
|
||||
return score;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the score (loss function value) for each example individually.
|
||||
* For input [numExamples,nOut] returns scores as a column vector: [numExamples,1]
|
||||
*
|
||||
* @param labels Labels/expected output
|
||||
* @param preOutput Output of the model (neural network)
|
||||
* @param activationFn Activation function that should be applied to preOutput
|
||||
* @param mask @return Loss function value for each example; column vector
|
||||
*/
|
||||
@Override
|
||||
public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
|
||||
if(sd == null){
|
||||
createSameDiffInstance(preOutput.dataType());
|
||||
}
|
||||
|
||||
Preconditions.checkArgument((labels.size(1) == preOutput.size(1)), "Labels array numColumns (size(1) = %s) does not match output layer number of outputs (nOut = %s)", labels.size(1), preOutput.size(1));
|
||||
|
||||
INDArray output = activationFn.getActivation(preOutput.dup(), true);
|
||||
|
||||
Map<String, INDArray> m = new HashMap<>();
|
||||
m.put("labels", labels);
|
||||
m.put("layerInput", output);
|
||||
|
||||
INDArray scoreArr = sd.outputSingle(m,scoreVariable.name());
|
||||
|
||||
if (mask != null) {
|
||||
LossUtil.applyMask(scoreArr, mask);
|
||||
}
|
||||
return scoreArr;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the gradient of the loss function with respect to the inputs: dL/dOutput
|
||||
*
|
||||
* @param labels Label/expected output
|
||||
* @param preOutput Output of the model (neural network), before the activation function is applied
|
||||
* @param activationFn Activation function that should be applied to preOutput
|
||||
* @param mask Mask array; may be null
|
||||
* @return Gradient dL/dPreOut
|
||||
*/
|
||||
@Override
|
||||
public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
|
||||
if(sd == null){
|
||||
createSameDiffInstance(preOutput.dataType());
|
||||
}
|
||||
|
||||
|
||||
Map<String, INDArray> m = new HashMap<>();
|
||||
INDArray output = activationFn.getActivation(preOutput.dup(), true);
|
||||
m.put("labels", labels);
|
||||
m.put("layerInput", output);
|
||||
|
||||
Map<String, INDArray> grads = sd.calculateGradients(m, "layerInput");
|
||||
|
||||
INDArray gradAtActivationOutput = grads.get("layerInput");
|
||||
INDArray gradAtInput = activationFn.backprop(preOutput.dup(), gradAtActivationOutput).getFirst();
|
||||
|
||||
if (mask != null) {
|
||||
LossUtil.applyMask(gradAtInput, mask);
|
||||
}
|
||||
return gradAtInput;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute both the score (loss function value) and gradient. This is equivalent to calling {@link #computeScore(INDArray, INDArray, IActivation, INDArray, boolean)}
|
||||
* and {@link #computeGradient(INDArray, INDArray, IActivation, INDArray)} individually
|
||||
*
|
||||
* @param labels Label/expected output
|
||||
* @param preOutput Output of the model (neural network)
|
||||
* @param activationFn Activation function that should be applied to preOutput
|
||||
* @param mask Mask array; may be null
|
||||
* @param average Whether the score should be averaged (divided by number of rows in labels/output) or not
|
||||
* @return The score (loss function value) and gradient
|
||||
*/
|
||||
@Override
|
||||
public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
|
||||
INDArray mask, boolean average) {
|
||||
|
||||
Pair<Double, INDArray> GradientAndScore = new Pair<>();
|
||||
GradientAndScore.setFirst(this.computeScore(labels, preOutput, activationFn, mask, average));
|
||||
GradientAndScore.setSecond(this.computeGradient(labels, preOutput, activationFn, mask));
|
||||
|
||||
return GradientAndScore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String name() {
|
||||
return getClass().getSimpleName();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
@ -683,7 +683,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
|
|||
|
||||
|
||||
workspace.initializeWorkspace();
|
||||
long reqMemory = 12 * Nd4j.sizeOfDataType(arrayCold.dataType());
|
||||
long reqMemory = 11 * Nd4j.sizeOfDataType(arrayCold.dataType());
|
||||
assertEquals(reqMemory + reqMemory % 8, workspace.getCurrentSize());
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue