Add SameDiff exec debugging listener + few fixes (#104)

* First pass on SameDiff op exec debug listener

Signed-off-by: Alex Black <blacka101@gmail.com>

* #7555 DL4J helpers - don't fall back on builtin for op profiler exceptions

Signed-off-by: Alex Black <blacka101@gmail.com>

* Exec debugging listener + fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix import counts for TF ops in OpValidationSuite

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix bad DL4J test configuration

Signed-off-by: Alex Black <blacka101@gmail.com>

* Exec debugging listener polish

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Another fix

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-08-07 17:18:29 +10:00 committed by GitHub
parent ba1d1b160b
commit edb71bf46f
8 changed files with 369 additions and 94 deletions
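
For context, a minimal usage sketch of the new listener, mirroring the ExecDebuggingListenerTest added in this PR (the class name and toy data here are illustrative, not part of the change):

    import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
    import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener.PrintMode;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.samediff.TrainingConfig;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.learning.config.Adam;

    public class ExecDebuggingExample {
        public static void main(String[] args) {
            //Simple softmax + log-loss graph, as in the test added by this PR
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
            SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2);
            SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2));
            SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2));
            SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
            sd.loss.logLoss("loss", label, sm);

            sd.setTrainingConfig(TrainingConfig.builder()
                    .dataSetFeatureMapping("in")
                    .dataSetLabelMapping("label")
                    .updater(new Adam(0.001))
                    .build());

            //Log every op executed during fitting: shapes and iArgs/bArgs/tArgs, all iterations, with (iter=...,epoch=...,op=...) prefix
            sd.setListeners(new ExecDebuggingListener(PrintMode.SHAPES_ONLY, -1, true));
            sd.fit(new DataSet(Nd4j.rand(DataType.FLOAT, 1, 3), Nd4j.rand(DataType.FLOAT, 1, 2)));
        }
    }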


@@ -165,35 +165,13 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(2))
.list()
.layer(0,
new DenseLayer.Builder()
.nIn(4)
.nOut(3)
.build()
)
.layer(1,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(3)
.nOut(4)
.build()
)
)
.layer(2,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(4)
.nOut(2)
.build()
)
).layer(3,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.nIn(2)
.nOut(1)
.build()
)
)
.layer(new DenseLayer.Builder().nIn(4).nOut(3).build())
.layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(3).nOut(4).build()))
.layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(4).nOut(2).build()))
.layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
.build();
MultiLayerNetwork network = new MultiLayerNetwork(conf1);
@@ -238,70 +216,18 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.seed(12345)
.graphBuilder()
.addInputs("input")
.addLayer(initialLayer,
new DenseLayer.Builder()
.nIn(4)
.nOut(4)
.build(),
"input"
)
.addLayer(frozenBranchUnfrozenLayer0,
new DenseLayer.Builder()
.nIn(4)
.nOut(3)
.build(),
initialLayer
)
.addLayer(frozenBranchFrozenLayer1,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(3)
.nOut(4)
.build()
),
frozenBranchUnfrozenLayer0
)
.addLayer(frozenBranchFrozenLayer2,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(4)
.nOut(2)
.build()
),
frozenBranchFrozenLayer1
)
.addLayer(unfrozenLayer0,
new DenseLayer.Builder()
.nIn(4)
.nOut(4)
.build(),
initialLayer
)
.addLayer(unfrozenLayer1,
new DenseLayer.Builder()
.nIn(4)
.nOut(2)
.build(),
unfrozenLayer0
)
.addLayer(unfrozenBranch2,
new DenseLayer.Builder()
.nIn(2)
.nOut(1)
.build(),
unfrozenLayer1
)
.addVertex("merge",
new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.nIn(3)
.nOut(1)
.build()
),
"merge"
)
.addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
.addLayer(frozenBranchUnfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer)
.addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
.addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
.addLayer(unfrozenLayer0, new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
.addLayer(unfrozenLayer1, new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
.addLayer(unfrozenBranch2, new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
.addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
.setOutputs(frozenBranchOutput)
.build();


@@ -35,6 +35,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@@ -170,6 +171,8 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
pad, biasGradView, weightGradView, afn,
layerConf().getCudnnAlgoMode(), layerConf().getCudnnBwdFilterAlgo(), layerConf().getCudnnBwdDataAlgo(),
convolutionMode, dilation, workspaceMgr);
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Exception e){
if(e.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
@@ -359,6 +362,8 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
try {
ret = helper.preOutput(input, weights, bias, kernel, strides, pad, layerConf().getCudnnAlgoMode(),
layerConf().getCudnnFwdAlgo(), convolutionMode, dilation, workspaceMgr);
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Exception e){
if(e.getMessage() != null && e.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
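
The same guard is repeated at each helper call site in the files below (SubsamplingLayer, BatchNormalization, LocalResponseNormalization): ND4JOpProfilerException is rethrown ahead of the generic catch so that op-profiler panics surface for debugging rather than being treated as a helper failure and silently falling back to the built-in implementation. A minimal sketch of what produces that exception, assuming the standard Nd4j op profiler API (setProfilingMode with NAN_PANIC); the class name is illustrative:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
    import org.nd4j.linalg.exception.ND4JOpProfilerException;
    import org.nd4j.linalg.factory.Nd4j;

    public class NanPanicExample {
        public static void main(String[] args) {
            //Enable NaN panic: the op profiler inspects op operands and throws if a NaN is encountered
            Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC);

            INDArray withNaN = Nd4j.create(new double[]{1.0, Double.NaN, 3.0});
            try {
                withNaN.add(1.0);   //Executes an op on data containing NaN -> profiler panics
            } catch (ND4JOpProfilerException e) {
                //This is the exception the DL4J helpers now rethrow instead of falling back
                System.out.println("NaN panic: " + e.getMessage());
            }
        }
    }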


@@ -32,6 +32,7 @@ import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.util.OneTimeLogger;
@@ -131,6 +132,8 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
try{
ret = helper.backpropGradient(input, epsilon, kernel, strides, pad,
layerConf().getPoolingType(), convolutionMode, dilation, workspaceMgr);
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Exception e){
if(e.getMessage() != null && e.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
@@ -256,6 +259,8 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.l
try {
ret = helper.activate(input, training, kernel, strides, pad, layerConf().getPoolingType(),
convolutionMode, dilation, workspaceMgr);
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Exception e){
if(layerConf().isCudnnAllowFallback()){
helperCountFail++;


@@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldSubOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
@@ -156,6 +157,8 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
try {
ret = helper.backpropGradient(in, eps, ArrayUtil.toInts(shape), gamma, dGammaView, dBetaView,
layerConf.getEps(), workspaceMgr);
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Throwable t){
if(t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
@@ -447,6 +450,8 @@ public class BatchNormalization extends BaseLayer<org.deeplearning4j.nn.conf.lay
ret = helper.preOutput(in, training == TrainingMode.TRAIN, ArrayUtil.toInts(shape), gamma, beta, globalMeanView,
globalVarView, decay, layerConf.getEps(), workspaceMgr);
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Throwable t) {
if(t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation


@@ -27,6 +27,7 @@ import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLocalResponseNormalizationHelpe
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -141,6 +142,8 @@ public class LocalResponseNormalization
Pair<Gradient, INDArray> ret = null;
try {
ret = helper.backpropGradient(input, epsilon, k, n, alpha, beta, workspaceMgr);
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Throwable t){
if(t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation
@@ -206,6 +209,8 @@ public class LocalResponseNormalization
INDArray activations = null;
try {
activations = helper.activate(input, training, k, n, alpha, beta, workspaceMgr);
} catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging
} catch (Throwable t){
if(t.getMessage().contains("Failed to allocate")){
//This is a memory exception - don't fallback to built-in implementation


@@ -0,0 +1,257 @@
package org.nd4j.autodiff.listeners.debugging;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import java.util.Arrays;
/**
* A listener that logs operation execution for debugging purposes.
* 3 modes are supported:<br><br>
* <b>OPS_ONLY</b>: Only the operation names are printed. For example:<br>
* {@code (iter=0,epoch=0,op=1) org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp}<br>
* <b>SHAPES_ONLY</b>: Print the operation class, shape info for the input/output arrays, and any arguments (iArgs, bArgs, tArgs). For example:<br>
* <pre>{@code
* (iter=1,epoch=0,op=3) org.nd4j.linalg.api.ops.impl.loss.LogLoss
* iArgs=[3]
* tArgs=[1.0E-7]
* Input[0]=Rank: 2, DataType: FLOAT, Offset: 0, Order: c, Shape: [1,2], Stride: [1,1]
* Input[1]=Rank: 0, DataType: FLOAT, Offset: 0, Order: c, Shape: [], Stride: []
* Input[2]=Rank: 2, DataType: FLOAT, Offset: 0, Order: c, Shape: [1,2], Stride: [1,1]
* Outputs[0]=Rank: 0, DataType: FLOAT, Offset: 0, Order: c, Shape: [], Stride: []
* }
* </pre>
* <b>REPRODUCE</b>: Print runnable Java code that should reproduce that op execution (other than perhaps exact input/output strides). For example:<br>
* <pre>{@code
* (iter=2,epoch=0,op=1) org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp
* DynamicCustomOp op = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp();
* INDArray[] inputs = new INDArray[2];
* inputs[0] = Nd4j.createFromArray(1.5253239f, 0.8733858f).reshape(1, 2);
* inputs[1] = Nd4j.createFromArray(0.483428f, 0.86025196f).reshape(1, 2);
* op.addInputArgument(inputs);
* INDArray[] outputs = new INDArray[1];
* outputs[0] = Nd4j.createFromArray(2.012087f, 1.7303026f).reshape(1, 2);
* op.addOutputArgument(outputs);
* Nd4j.exec(op);
* }
* </pre>
*
* @author Alex Black
*/
public class ExecDebuggingListener extends BaseListener {
public enum PrintMode {OPS_ONLY, SHAPES_ONLY, REPRODUCE}
private final PrintMode printMode;
private final int maxIterations;
private final boolean logIter;
private long printIterations = 0;
private int lastIter = -1;
private int stepThisIter = 0;
/**
* @param printMode Print mode, see {@link PrintMode}
* @param maxIterations Maximum number of iterations to print. <= 0 for "all iterations"
* @param logIter If true: prefix the output with the iteration/epoch/op counter, such as "(iter=1,epoch=0,op=3)"
*/
public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter){
this.printMode = printMode;
this.maxIterations = maxIterations;
this.logIter = logIter;
}
@Override
public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) {
if(lastIter != at.iteration()){
lastIter = at.iteration();
stepThisIter = 0;
printIterations++;
}
if(maxIterations > 0 && printIterations > maxIterations){
return;
}
StringBuilder sb = new StringBuilder();
if(logIter){
sb.append("(iter=").append(at.iteration())
.append(",epoch=").append(at.epoch())
.append(",");
}
sb.append("op=").append(stepThisIter++)
.append(logIter ? ") " : " - ");
DifferentialFunction df = op.getOp();
sb.append(op.getOp().getClass().getName());
CustomOp co = df instanceof CustomOp ? (CustomOp) df : null;
Op lOp = df instanceof Op ? (Op) df : null;
if(printMode == PrintMode.OPS_ONLY){
sb.append("\n");
} else if(printMode == PrintMode.SHAPES_ONLY){
if(co != null){
if(co.iArgs() != null && co.iArgs().length > 0) {
sb.append("\n\tiArgs=").append(Arrays.toString(co.iArgs()));
}
if(co.bArgs() != null && co.bArgs().length > 0) {
sb.append("\n\tbArgs=").append(Arrays.toString(co.bArgs()));
}
if(co.tArgs() != null && co.tArgs().length > 0) {
sb.append("\n\ttArgs=").append(Arrays.toString(co.tArgs()));
}
INDArray[] inputs = co.inputArguments();
INDArray[] outputs = co.outputArguments();
if(inputs != null ) {
for (int i = 0; i < inputs.length; i++) {
sb.append("\n\tInput[").append(i).append("]=").append(inputs[i].shapeInfoToString());
}
}
if(outputs != null ) {
for (int i = 0; i < outputs.length; i++) {
sb.append("\n\tOutputs[").append(i).append("]=").append(outputs[i].shapeInfoToString());
}
}
} else {
if(lOp.x() != null) {
sb.append("\n\tx: ").append(lOp.x().shapeInfoToString());
}
if(lOp.y() != null) {
sb.append("\n\ty: ").append(lOp.y().shapeInfoToString());
}
if(lOp.z() != null) {
sb.append("\n\tz: ").append(lOp.z().shapeInfoToString());
}
if(lOp instanceof ScalarOp){
INDArray scalar = ((ScalarOp)lOp).scalar();
if(scalar != null){
sb.append("\n\tscalar: ").append(scalar.shapeInfoToString());
}
}
}
sb.append("\n");
} else if(printMode == PrintMode.REPRODUCE){
sb.append("\n");
if(co != null){
sb.append("DynamicCustomOp op = new ").append(co.getClass().getName()).append("();\n");
if(co.iArgs() != null && co.iArgs().length > 0 ){
sb.append("op.addIArgument(").append(Arrays.toString(co.iArgs()).replaceAll("[\\[\\]]", "")).append(");\n");
}
if(co.bArgs() != null && co.bArgs().length > 0 ){
sb.append("op.addBArgument(").append(Arrays.toString(co.bArgs()).replaceAll("[\\[\\]]", "")).append(");\n");
}
if(co.tArgs() != null && co.tArgs().length > 0 ){
sb.append("op.addTArgument(").append(Arrays.toString(co.tArgs()).replaceAll("[\\[\\]]", "")).append(");\n");
}
INDArray[] inputs = co.inputArguments();
INDArray[] outputs = co.outputArguments();
if(inputs != null ) {
sb.append("INDArray[] inputs = new INDArray[").append(inputs.length).append("];\n");
for (int i = 0; i < inputs.length; i++) {
sb.append("inputs[").append(i).append("] = ");
sb.append(createString(inputs[i]))
.append(";\n");
}
sb.append("op.addInputArgument(inputs);\n");
}
if(outputs != null ) {
sb.append("INDArray[] outputs = new INDArray[").append(outputs.length).append("];\n");
for (int i = 0; i < outputs.length; i++) {
sb.append("outputs[").append(i).append("] = ");
sb.append(createString(outputs[i]))
.append(";\n");
}
sb.append("op.addOutputArgument(outputs);\n");
}
} else {
sb.append("Op op = new ").append(op.getClass().getName()).append("();\n");
if(lOp.x() != null) {
sb.append("op.setX(").append(createString(lOp.x())).append(");\n");
}
if(lOp.y() != null) {
sb.append("op.setY(").append(createString(lOp.y())).append(");\n");
}
if(lOp.z() != null) {
sb.append("op.setZ").append(createString(lOp.z())).append(");\n");
}
if(lOp instanceof ScalarOp){
INDArray scalar = ((ScalarOp)lOp).scalar();
if(scalar != null){
sb.append("((ScalarOp)op).setScalar(").append(createString(scalar)).append(");\n");
}
}
}
sb.append("Nd4j.exec(op);\n");
}
System.out.print(sb.toString());
}
private static String createString(INDArray arr){
StringBuilder sb = new StringBuilder();
if(arr.isEmpty()){
sb.append("Nd4j.empty(DataType.").append(arr.dataType()).append(");");
} else {
sb.append("Nd4j.createFromArray(");
DataType dt = arr.dataType();
switch (dt){
case DOUBLE:
double[] dArr = arr.dup().data().asDouble();
sb.append(Arrays.toString(dArr).replaceAll("[\\[\\]]", ""));
break;
case FLOAT:
case HALF:
case BFLOAT16:
float[] fArr = arr.dup().data().asFloat();
sb.append(Arrays.toString(fArr)
.replaceAll(",", "f,")
.replaceAll("]", "f")
.replaceAll("[\\[\\]]", ""));
break;
case LONG:
case UINT32:
case UINT64:
long[] lArr = arr.dup().data().asLong();
sb.append(Arrays.toString(lArr)
.replaceAll(",", "L,")
.replaceAll("]", "L")
.replaceAll("[\\[\\]]", ""));
break;
case INT:
case SHORT:
case UBYTE:
case BYTE:
case UINT16:
case BOOL:
int[] iArr = arr.dup().data().asInt();
sb.append(Arrays.toString(iArr).replaceAll("[\\[\\]]", ""));
break;
case UTF8:
break;
case COMPRESSED:
case UNKNOWN:
break;
}
sb.append(").reshape(").append(Arrays.toString(arr.shape()).replaceAll("[\\[\\]]", ""))
.append(")");
if(dt == DataType.HALF || dt == DataType.BFLOAT16 || dt == DataType.UINT32 || dt == DataType.UINT64 ||
dt == DataType.SHORT || dt == DataType.UBYTE || dt == DataType.BYTE || dt == DataType.UINT16 || dt == DataType.BOOL){
sb.append(".cast(DataType.").append(arr.dataType()).append(")");
}
}
return sb.toString();
}
}


@@ -235,7 +235,18 @@ public class DifferentialFunctionClassHolder {
//log.debug("Missing " + set.size() + " ops!");
countTotalTfOps = tensorflowOpDescriptors.size();
countTotalMappedOps = nodeConverters.size();
//Work out total number of TF ops mapped
Set<String> tfMappedOps = new HashSet<>();
for(DifferentialFunction df : nodeConverters.values()){
try{
String[] tfNames = df.tensorflowNames();
Collections.addAll(tfMappedOps, tfNames);
} catch (NoOpNameFoundException e){
//Ignore
}
}
countTotalMappedOps = tfMappedOps.size();
//Get custom ops - map from hash to class
Map<String,CustomOpDescriptor> descriptorMap = Nd4j.getExecutioner().getCustomOperations();


@@ -0,0 +1,61 @@
package org.nd4j.autodiff.samediff.listeners;
import org.junit.Test;
import org.nd4j.autodiff.listeners.debugging.ExecDebuggingListener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.Adam;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
public class ExecDebuggingListenerTest extends BaseNd4jTest {
public ExecDebuggingListenerTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void testExecDebugListener(){
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, 1, 2);
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 3, 2));
SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 2));
SDVariable sm = sd.nn.softmax("softmax", in.mmul(w).add(b));
SDVariable loss = sd.loss.logLoss("loss", label, sm);
INDArray i = Nd4j.rand(DataType.FLOAT, 1, 3);
INDArray l = Nd4j.rand(DataType.FLOAT, 1, 2);
sd.setTrainingConfig(TrainingConfig.builder()
.dataSetFeatureMapping("in")
.dataSetLabelMapping("label")
.updater(new Adam(0.001))
.build());
for(ExecDebuggingListener.PrintMode pm : ExecDebuggingListener.PrintMode.values()){
sd.setListeners(new ExecDebuggingListener(pm, -1, true));
// sd.output(m, "softmax");
sd.fit(new DataSet(i, l));
System.out.println("\n\n\n");
}
}
@Override
public char ordering() {
return 'c';
}
}