DL4J SameDiff loss function (#251)

* Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6
* SameDiffLoss draft
* very very draft
* Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6
* temporary commit for clarification
  Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>
* temporary commit for clarification v2
  Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>
* temporary commit for clarification v3
  Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>
* temporary commit for clarification v3
  Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>
* Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6
* very very draft
* temporary commit for clarification
  Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>
* temporary commit for clarification v2
  Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>
* temporary commit for clarification v3
  Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>
* temporary commit for clarification v3
  Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>
* SDLoss after requested changes but with questions in comments
  Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* added requested changes
* small fixes
  Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
* Fixes
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Javadoc
  Signed-off-by: Alex Black <blacka101@gmail.com>
* Test fix
  Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Andrii Tuzhykov <andrew@unrealists.com>
Co-authored-by: atuzhykov <andrewtuzhukov@gmail.com>
Co-authored-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
parent 5fbb04531d
commit 18d4eaa68d
LossFunctionGradientCheck.java
@@ -19,6 +19,8 @@ package org.deeplearning4j.gradientcheck;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
+import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMAE;
+import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMSE;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@@ -83,7 +85,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                 LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
                 new LossMultiLabel(), new LossWasserstein(),
-                new LossSparseMCXENT()
+                new LossSparseMCXENT(),
+                new SDLossMAE(), new SDLossMSE()
         };

         Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@@ -119,6 +122,12 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
                 Activation.IDENTITY, // Wasserstein
                 Activation.SOFTMAX, //sparse MCXENT
+                Activation.SOFTMAX, // SDLossMAE
+                Activation.SIGMOID, // SDLossMAE
+                Activation.TANH, // SDLossMAE
+                Activation.SOFTMAX, // SDLossMSE
+                Activation.SIGMOID, // SDLossMSE
+                Activation.TANH // SDLossMSE
         };

         int[] nOut = new int[] {1, //xent
@@ -154,6 +163,12 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 10, // MultiLabel
                 2, // Wasserstein
                 4, //sparse MCXENT
+                3, // SDLossMAE
+                3, // SDLossMAE
+                3, // SDLossMAE
+                3, // SDLossMSE
+                3, // SDLossMSE
+                3, // SDLossMSE
         };

         int[] minibatchSizes = new int[] {1, 3};
@@ -520,6 +535,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
                 break;
             case "LossMAE":
             case "LossMSE":
+            case "SDLossMAE":
+            case "SDLossMSE":
             case "LossL1":
             case "LossL2":
                 ret[1] = Nd4j.rand(labelsShape).muli(2).subi(1);
SDLossMAE.java (new file)
@@ -0,0 +1,30 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.gradientcheck.sdlosscustom;
+
+import lombok.EqualsAndHashCode;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.lossfunctions.SameDiffLoss;
+
+@EqualsAndHashCode(callSuper = false)
+public class SDLossMAE extends SameDiffLoss {
+
+    @Override
+    public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
+        return sd.math.abs(labels.sub(layerInput)).mean(1);
+    }
+}
SDLossMSE.java (new file)
@@ -0,0 +1,30 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.gradientcheck.sdlosscustom;
+
+import lombok.EqualsAndHashCode;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.lossfunctions.*;
+
+@EqualsAndHashCode(callSuper = false)
+public class SDLossMSE extends SameDiffLoss {
+
+    @Override
+    public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
+        return labels.squaredDifference(layerInput).mean(1);
+    }
+}
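Usage note (illustrative sketch, not part of the commit): because SameDiffLoss implements ILossFunction, the new SDLossMAE/SDLossMSE classes can be passed anywhere DL4J accepts a loss function. The class name SameDiffLossUsageSketch, the layer sizes, and the activation choices below are arbitrary assumptions.

    import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMSE;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.activations.Activation;

    public class SameDiffLossUsageSketch {
        public static void main(String[] args) {
            // Plug the SameDiff-based loss into an ordinary DL4J output layer;
            // lossFunction(...) accepts any ILossFunction implementation.
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .list()
                    .layer(new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(new OutputLayer.Builder()
                            .nIn(4).nOut(3)
                            .activation(Activation.IDENTITY)
                            .lossFunction(new SDLossMSE())
                            .build())
                    .build();

            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
        }
    }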
RegressionTest100b6.java
@@ -111,7 +111,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
             assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
             assertEquals(dtype, net.params().dataType());
             boolean eq = outExp.equalsWithEps(outAct, 0.01);
-            assertTrue(outExp + " vs " + outAct, eq);
+            assertTrue("Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct, eq);
         }
     }
BaseOp.java
@@ -28,6 +28,7 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.shape.Shape;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
@@ -348,8 +349,10 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
         if (dimensions == null || dimensions.length == 0)
             dimensions = new int[]{Integer.MAX_VALUE};

+        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
             this.dimensionz = Shape.ndArrayDimFromInt(dimensions);
         }
+    }

     public INDArray dimensions() {
         return dimensionz;
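The BaseOp change above allocates the cached dimensions array inside a scopeOutOfWorkspaces() block, so the array is detached from any active memory workspace and stays valid after that workspace is closed or rotated. A minimal standalone sketch of the same pattern (illustrative only; the class name is hypothetical):

    import org.nd4j.linalg.api.memory.MemoryWorkspace;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class DetachedAllocationSketch {
        public static void main(String[] args) {
            // Arrays created while this scope is open are allocated off-workspace
            // (detached), matching the pattern BaseOp now uses for this.dimensionz.
            try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                INDArray detachedDims = Nd4j.createFromArray(0, 1);
                System.out.println(detachedDims);
            }
        }
    }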
SameDiffLoss.java (new file)
@@ -0,0 +1,186 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.lossfunctions;
+
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.activations.IActivation;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.primitives.Pair;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * SameDiff loss function.
+ *
+ * This class can be extended to create Deeplearning4j loss functions by defining a single method only:
+ * {@link #defineLoss(SameDiff, SDVariable, SDVariable)}. This method is used to define the loss function on a
+ * <i>per example</i> basis - i.e., the output should be an array with shape [minibatch].<br>
+ * <br>
+ * For example, the mean squared error (MSE) loss function can be defined using:<br>
+ * {@code return labels.squaredDifference(layerInput).mean(1);}
+ */
+public abstract class SameDiffLoss implements ILossFunction {
+    protected transient SameDiff sd;
+    protected transient SDVariable scoreVariable;
+
+    protected SameDiffLoss() {
+
+    }
+
+    /**
+     * Define the loss function.<br>
+     * <b>NOTE</b>: The score is defined on a <i>per example</i> basis - the method should return an SDVariable
+     * with shape [minibatch], where out[i] is the score for the ith example in the minibatch
+     *
+     * @param sd         SameDiff instance to define the loss on
+     * @param layerInput Input to the SameDiff loss function
+     * @param labels     Labels placeholder
+     * @return The score on a per example basis (SDVariable with shape [minibatch])
+     */
+    public abstract SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels);
+
+    protected void createSameDiffInstance(DataType dataType) {
+        sd = SameDiff.create();
+        SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1);
+        SDVariable labels = sd.placeHolder("labels", dataType, -1);
+        scoreVariable = this.defineLoss(sd, layerInput, labels);
+        sd.createGradFunction("layerInput");
+    }
+
+    /**
+     * Compute the score (loss function value) for the given inputs.
+     *
+     * @param labels       Label/expected output
+     * @param preOutput    Output of the model (neural network)
+     * @param activationFn Activation function that should be applied to preOutput
+     * @param mask         Mask array; may be null
+     * @param average      Whether the score should be averaged (divided by number of rows in labels/preOutput) or not
+     * @return Loss function value
+     */
+    @Override
+    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
+        if (sd == null) {
+            createSameDiffInstance(preOutput.dataType());
+        }
+
+        INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask);
+
+        double score = scoreArr.sumNumber().doubleValue();
+        if (average) {
+            score /= scoreArr.size(0);
+        }
+        return score;
+    }
+
+    /**
+     * Compute the score (loss function value) for each example individually.
+     * For input [numExamples, nOut] returns scores as a column vector: [numExamples, 1]
+     *
+     * @param labels       Labels/expected output
+     * @param preOutput    Output of the model (neural network)
+     * @param activationFn Activation function that should be applied to preOutput
+     * @param mask         Mask array; may be null
+     * @return Loss function value for each example; column vector
+     */
+    @Override
+    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        if (sd == null) {
+            createSameDiffInstance(preOutput.dataType());
+        }
+
+        Preconditions.checkArgument((labels.size(1) == preOutput.size(1)),
+                "Labels array numColumns (size(1) = %s) does not match output layer number of outputs (nOut = %s)",
+                labels.size(1), preOutput.size(1));
+
+        INDArray output = activationFn.getActivation(preOutput.dup(), true);
+
+        Map<String, INDArray> m = new HashMap<>();
+        m.put("labels", labels);
+        m.put("layerInput", output);
+
+        INDArray scoreArr = sd.outputSingle(m, scoreVariable.name());
+
+        if (mask != null) {
+            LossUtil.applyMask(scoreArr, mask);
+        }
+        return scoreArr;
+    }
+
+    /**
+     * Compute the gradient of the loss function with respect to the inputs: dL/dOutput
+     *
+     * @param labels       Label/expected output
+     * @param preOutput    Output of the model (neural network), before the activation function is applied
+     * @param activationFn Activation function that should be applied to preOutput
+     * @param mask         Mask array; may be null
+     * @return Gradient dL/dPreOut
+     */
+    @Override
+    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
+        if (sd == null) {
+            createSameDiffInstance(preOutput.dataType());
+        }
+
+        Map<String, INDArray> m = new HashMap<>();
+        INDArray output = activationFn.getActivation(preOutput.dup(), true);
+        m.put("labels", labels);
+        m.put("layerInput", output);
+
+        Map<String, INDArray> grads = sd.calculateGradients(m, "layerInput");
+
+        INDArray gradAtActivationOutput = grads.get("layerInput");
+        INDArray gradAtInput = activationFn.backprop(preOutput.dup(), gradAtActivationOutput).getFirst();
+
+        if (mask != null) {
+            LossUtil.applyMask(gradAtInput, mask);
+        }
+        return gradAtInput;
+    }
+
+    /**
+     * Compute both the score (loss function value) and gradient. This is equivalent to calling
+     * {@link #computeScore(INDArray, INDArray, IActivation, INDArray, boolean)} and
+     * {@link #computeGradient(INDArray, INDArray, IActivation, INDArray)} individually
+     *
+     * @param labels       Label/expected output
+     * @param preOutput    Output of the model (neural network)
+     * @param activationFn Activation function that should be applied to preOutput
+     * @param mask         Mask array; may be null
+     * @param average      Whether the score should be averaged (divided by number of rows in labels/output) or not
+     * @return The score (loss function value) and gradient
+     */
+    @Override
+    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
+                                                          INDArray mask, boolean average) {
+        Pair<Double, INDArray> gradientAndScore = new Pair<>();
+        gradientAndScore.setFirst(this.computeScore(labels, preOutput, activationFn, mask, average));
+        gradientAndScore.setSecond(this.computeGradient(labels, preOutput, activationFn, mask));
+
+        return gradientAndScore;
+    }
+
+    @Override
+    public String name() {
+        return getClass().getSimpleName();
+    }
+}
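For illustration (not part of the commit): defining another loss on top of the new SameDiffLoss base class only requires implementing defineLoss and returning a per-example score of shape [minibatch]. The class name SDLossL1 below is hypothetical; it sums, rather than averages, the absolute errors over each example's outputs.

    import lombok.EqualsAndHashCode;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.lossfunctions.SameDiffLoss;

    @EqualsAndHashCode(callSuper = false)
    public class SDLossL1 extends SameDiffLoss {

        @Override
        public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
            // Per-example score: shape [minibatch], as SameDiffLoss requires
            return sd.math.abs(labels.sub(layerInput)).sum(1);
        }
    }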
BasicWorkspaceTests.java
@@ -683,7 +683,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {

             workspace.initializeWorkspace();
-            long reqMemory = 12 * Nd4j.sizeOfDataType(arrayCold.dataType());
+            long reqMemory = 11 * Nd4j.sizeOfDataType(arrayCold.dataType());
             assertEquals(reqMemory + reqMemory % 8, workspace.getCurrentSize());

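As a worked example of the updated expectation (assuming arrayCold uses a 4-byte data type such as FLOAT, which the diff does not state):

    long reqMemory = 11 * 4;                       // 44 bytes for 11 elements
    long expectedSize = reqMemory + reqMemory % 8; // 44 + 4 = 48 bytes, the value compared against workspace.getCurrentSize()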