DL4J SameDiff loss function (#251)

* Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6

* SameDiffLoss draft

* very very draft

* Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6

* temporary commit for clarification

Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>

* temporary commit for clarification v2

Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>

* temporary commit for clarification v3

Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>

* temporary commit for clarification v3

Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>

* Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6

* very very draft

* temporary commit for clarification

Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>

* temporary commit for clarification v2

Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>

* temporary commit for clarification v3

Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>

* temporary commit for clarification v3

Signed-off-by: atuzhykov <andrewtuzhukov@gmail.com>

* SDLoss after requested changes but with questions in comments

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* added requested changes

* small fixes

Signed-off-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Javadoc

Signed-off-by: Alex Black <blacka101@gmail.com>

* Test fix

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Andrii Tuzhykov <andrew@unrealists.com>
Co-authored-by: atuzhykov <andrewtuzhukov@gmail.com>
Co-authored-by: Andrii Tuzhykov <andrewtuzhykov@gmail.com>
master
Alex Black 2020-04-17 19:47:57 +10:00 committed by GitHub
parent 5fbb04531d
commit 18d4eaa68d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 270 additions and 4 deletions

View File

@ -19,6 +19,8 @@ package org.deeplearning4j.gradientcheck;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMAE;
import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMSE;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -83,7 +85,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(),
new LossMultiLabel(), new LossWasserstein(),
new LossSparseMCXENT()
new LossSparseMCXENT(),
new SDLossMAE(), new SDLossMSE()
};
Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent
@ -119,6 +122,12 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
Activation.TANH, // MultiLabel, doesn't require any special activation, but tanh was used in paper
Activation.IDENTITY, // Wasserstein
Activation.SOFTMAX, //sparse MCXENT
Activation.SOFTMAX, // SDLossMAE
Activation.SIGMOID, // SDLossMAE
Activation.TANH, // SDLossMAE
Activation.SOFTMAX, // SDLossMSE
Activation.SIGMOID, // SDLossMSE
Activation.TANH //SDLossMSE
};
int[] nOut = new int[] {1, //xent
@ -154,6 +163,12 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
10, // MultiLabel
2, // Wasserstein
4, //sparse MCXENT
3, // SDLossMAE
3, // SDLossMAE
3, // SDLossMAE
3, // SDLossMSE
3, // SDLossMSE
3, // SDLossMSE
};
int[] minibatchSizes = new int[] {1, 3};
@ -520,6 +535,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
break;
case "LossMAE":
case "LossMSE":
case "SDLossMAE":
case "SDLossMSE":
case "LossL1":
case "LossL2":
ret[1] = Nd4j.rand(labelsShape).muli(2).subi(1);

View File

@ -0,0 +1,30 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.gradientcheck.sdlosscustom;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.lossfunctions.SameDiffLoss;
@EqualsAndHashCode(callSuper = false)
public class SDLossMAE extends SameDiffLoss {
@Override
public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
return sd.math.abs(labels.sub(layerInput)).mean(1);
}
}

View File

@ -0,0 +1,30 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.gradientcheck.sdlosscustom;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.lossfunctions.*;
@EqualsAndHashCode(callSuper = false)
public class SDLossMSE extends SameDiffLoss {
@Override
public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
return labels.squaredDifference(layerInput).mean(1);
}
}

View File

@ -111,7 +111,7 @@ public class RegressionTest100b6 extends BaseDL4JTest {
assertEquals(dtype, net.getLayerWiseConfigurations().getDataType());
assertEquals(dtype, net.params().dataType());
boolean eq = outExp.equalsWithEps(outAct, 0.01);
assertTrue(outExp + " vs " + outAct, eq);
assertTrue("Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct, eq);
}
}

View File

@ -28,6 +28,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -348,7 +349,9 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
if (dimensions == null || dimensions.length == 0)
dimensions = new int[]{Integer.MAX_VALUE};
this.dimensionz = Shape.ndArrayDimFromInt(dimensions);
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
this.dimensionz = Shape.ndArrayDimFromInt(dimensions);
}
}
public INDArray dimensions() {

View File

@ -0,0 +1,186 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.lossfunctions;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.HashMap;
import java.util.Map;
/**
* SameDiff loss function.
*
* This class can be extended to create Deeplearning4j loss functions by defining one single method only:
* {@link #defineLoss(SameDiff, SDVariable, SDVariable)}. This method is used to define the loss function on a
* <i>per example</i> basis - i.e., the output should be an array with shape [minibatch].<br>
* <br>
* For example, the mean squared error (MSE) loss function can be defined using:<br>
* {@code return labels.squaredDifference(layerInput).mean(1);}
*
*/
public abstract class SameDiffLoss implements ILossFunction {
protected transient SameDiff sd;
protected transient SDVariable scoreVariable;
protected SameDiffLoss() {
}
/**
* Define the loss function.<br>
* <b>NOTE</b>: The score on a *per example* basis - should return a SDVariable with shape [minibatch], where out[i]
* is the score for the ith minibatch
*
* @param sd SameDiff instance to define the loss on
* @param layerInput Input to the SameDiff loss function
* @param labels Labels placeholder
* @return The score on a per example basis (SDVariable with shape [minibatch])
*/
public abstract SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels);
protected void createSameDiffInstance(DataType dataType){
sd = SameDiff.create();
SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1);
SDVariable labels = sd.placeHolder("labels", dataType, -1);
scoreVariable = this.defineLoss(sd, layerInput, labels);
sd.createGradFunction("layerInput");
}
/**
* Compute the score (loss function value) for the given inputs.
*
* @param labels Label/expected preOutput
* @param preOutput Output of the model (neural network)
* @param activationFn Activation function that should be applied to preOutput
* @param mask Mask array; may be null
* @param average Whether the score should be averaged (divided by number of rows in labels/preOutput) or not @return Loss function value
*/
@Override
public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
if(sd == null){
createSameDiffInstance(preOutput.dataType());
}
INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask);
double score = scoreArr.sumNumber().doubleValue();
if (average) {
score /= scoreArr.size(0);
}
return score;
}
/**
* Compute the score (loss function value) for each example individually.
* For input [numExamples,nOut] returns scores as a column vector: [numExamples,1]
*
* @param labels Labels/expected output
* @param preOutput Output of the model (neural network)
* @param activationFn Activation function that should be applied to preOutput
* @param mask @return Loss function value for each example; column vector
*/
@Override
public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
if(sd == null){
createSameDiffInstance(preOutput.dataType());
}
Preconditions.checkArgument((labels.size(1) == preOutput.size(1)), "Labels array numColumns (size(1) = %s) does not match output layer number of outputs (nOut = %s)", labels.size(1), preOutput.size(1));
INDArray output = activationFn.getActivation(preOutput.dup(), true);
Map<String, INDArray> m = new HashMap<>();
m.put("labels", labels);
m.put("layerInput", output);
INDArray scoreArr = sd.outputSingle(m,scoreVariable.name());
if (mask != null) {
LossUtil.applyMask(scoreArr, mask);
}
return scoreArr;
}
/**
* Compute the gradient of the loss function with respect to the inputs: dL/dOutput
*
* @param labels Label/expected output
* @param preOutput Output of the model (neural network), before the activation function is applied
* @param activationFn Activation function that should be applied to preOutput
* @param mask Mask array; may be null
* @return Gradient dL/dPreOut
*/
@Override
public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
if(sd == null){
createSameDiffInstance(preOutput.dataType());
}
Map<String, INDArray> m = new HashMap<>();
INDArray output = activationFn.getActivation(preOutput.dup(), true);
m.put("labels", labels);
m.put("layerInput", output);
Map<String, INDArray> grads = sd.calculateGradients(m, "layerInput");
INDArray gradAtActivationOutput = grads.get("layerInput");
INDArray gradAtInput = activationFn.backprop(preOutput.dup(), gradAtActivationOutput).getFirst();
if (mask != null) {
LossUtil.applyMask(gradAtInput, mask);
}
return gradAtInput;
}
/**
* Compute both the score (loss function value) and gradient. This is equivalent to calling {@link #computeScore(INDArray, INDArray, IActivation, INDArray, boolean)}
* and {@link #computeGradient(INDArray, INDArray, IActivation, INDArray)} individually
*
* @param labels Label/expected output
* @param preOutput Output of the model (neural network)
* @param activationFn Activation function that should be applied to preOutput
* @param mask Mask array; may be null
* @param average Whether the score should be averaged (divided by number of rows in labels/output) or not
* @return The score (loss function value) and gradient
*/
@Override
public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
INDArray mask, boolean average) {
Pair<Double, INDArray> GradientAndScore = new Pair<>();
GradientAndScore.setFirst(this.computeScore(labels, preOutput, activationFn, mask, average));
GradientAndScore.setSecond(this.computeGradient(labels, preOutput, activationFn, mask));
return GradientAndScore;
}
@Override
public String name() {
return getClass().getSimpleName();
}
}

View File

@ -683,7 +683,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
workspace.initializeWorkspace();
long reqMemory = 12 * Nd4j.sizeOfDataType(arrayCold.dataType());
long reqMemory = 11 * Nd4j.sizeOfDataType(arrayCold.dataType());
assertEquals(reqMemory + reqMemory % 8, workspace.getCurrentSize());