[WIP] latest update (#8145)

* [WIP] maxpool2d_bp fix (#160)

* one test for maxpool2d_bp

Signed-off-by: raver119 <raver119@gmail.com>

* - maxpool2d_bp cuda fix for NaNs
- streamSync after each custom op execution

Signed-off-by: raver119 <raver119@gmail.com>

* MLN/CG: Don't swallow exceptions if a second exception occurs during workspace closing (#161)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Upgrade protobuf version (#162)

* First steps for protobuf version upgrade

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Phase 2

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update imports to shaded protobuf

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Version fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Switch to single execution for protobuf codegen to work around plugin bug

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Automatically delete old PB generated files after name change

Signed-off-by: Alex Black <blacka101@gmail.com>

* - string NDArray flat serde impl + tests (#163)

- string NDArray equalsTo impl

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of context variable

Signed-off-by: raver119 <raver119@gmail.com>

* lup context fix (#164)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-24 16:59:30 +03:00 committed by GitHub
parent 95b2686ce5
commit d871eab2e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
93 changed files with 891 additions and 459 deletions

View File

@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null; LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length]; List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace(); MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Throwable t = null;
try { try {
for (int i = 0; i <= stopIndex; i++) { for (int i = 0; i <= stopIndex; i++) {
GraphVertex current = vertices[topologicalOrder[i]]; GraphVertex current = vertices[topologicalOrder[i]];
@ -2302,14 +2303,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.build(); .build();
if(detachedInputs){ if (detachedInputs) {
//Sometimes (like: external errors use cases) we don't want the activations/inputs to be //Sometimes (like: external errors use cases) we don't want the activations/inputs to be
// in a workspace // in a workspace
workspaceMgr.setScopedOutFor(ArrayType.INPUT); workspaceMgr.setScopedOutFor(ArrayType.INPUT);
workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS); workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
} else { } else {
//Don't leverage out of async MultiDataSetIterator workspaces //Don't leverage out of async MultiDataSetIterator workspaces
if(features[0].isAttached()){ if (features[0].isAttached()) {
workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId()); workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
} }
} }
@ -2326,7 +2327,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
if (ArrayUtils.contains(layerIndexes, vIdx)) { if (ArrayUtils.contains(layerIndexes, vIdx)) {
isRequiredOutput = true; isRequiredOutput = true;
if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){ if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
//Place activations in user-specified workspace //Place activations in user-specified workspace
origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS); origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS); origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
@ -2345,7 +2346,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Open the relevant workspace for the activations. //Open the relevant workspace for the activations.
//Note that this will be closed only once the current vertex's activations have been consumed //Note that this will be closed only once the current vertex's activations have been consumed
MemoryWorkspace wsActivations = null; MemoryWorkspace wsActivations = null;
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput ){ //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput) { //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS); wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
openActivationsWorkspaces.put(wsActivations, workspaceMgr); openActivationsWorkspaces.put(wsActivations, workspaceMgr);
} }
@ -2353,11 +2354,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Note that because we're opening activation workspaces not in any defined order (i.e., workspace //Note that because we're opening activation workspaces not in any defined order (i.e., workspace
// use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we // use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we
// close these workspaces, the "current" workspace may be set to the incorrect one // close these workspaces, the "current" workspace may be set to the incorrect one
if(wsActivations != null ) if (wsActivations != null)
wsActivations.setPreviousWorkspace(initialWorkspace); wsActivations.setPreviousWorkspace(initialWorkspace);
int closeableAt = vertexOutputsFullyConsumedByStep[vIdx]; int closeableAt = vertexOutputsFullyConsumedByStep[vIdx];
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) { if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
if (closeAtEndIteraton[closeableAt] == null) { if (closeAtEndIteraton[closeableAt] == null) {
closeAtEndIteraton[closeableAt] = new ArrayList<>(); closeAtEndIteraton[closeableAt] = new ArrayList<>();
} }
@ -2373,18 +2374,18 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
out = features[vIdx]; out = features[vIdx];
} else { } else {
if(fwdPassType == FwdPassType.STANDARD){ if (fwdPassType == FwdPassType.STANDARD) {
//Standard feed-forward case //Standard feed-forward case
out = current.doForward(train, workspaceMgr); out = current.doForward(train, workspaceMgr);
} else if(fwdPassType == FwdPassType.RNN_TIMESTEP){ } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
if (current.hasLayer()) { if (current.hasLayer()) {
//Layer //Layer
INDArray input = current.getInputs()[0]; INDArray input = current.getInputs()[0];
Layer l = current.getLayer(); Layer l = current.getLayer();
if (l instanceof RecurrentLayer) { if (l instanceof RecurrentLayer) {
out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr); out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
} else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer){ } else if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) {
RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying()); RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying());
out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr); out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
} else if (l instanceof MultiLayerNetwork) { } else if (l instanceof MultiLayerNetwork) {
out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input)); out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input));
@ -2402,7 +2403,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)"); validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
} }
if(inputsTo != null) { //Output vertices may not input to any other vertices if (inputsTo != null) { //Output vertices may not input to any other vertices
for (VertexIndices v : inputsTo) { for (VertexIndices v : inputsTo) {
//Note that we don't have to do anything special here: the activations are always detached in //Note that we don't have to do anything special here: the activations are always detached in
// this method // this method
@ -2412,13 +2413,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
} }
} }
if(clearLayerInputs) { if (clearLayerInputs) {
current.clear(); current.clear();
} }
if(isRequiredOutput){ if (isRequiredOutput) {
outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out; outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out;
if(origWSAct != null){ if (origWSAct != null) {
//Reset the configuration, as we may reuse this workspace manager... //Reset the configuration, as we may reuse this workspace manager...
workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf); workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf);
} }
@ -2428,14 +2429,16 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Close any activations workspaces that we no longer require //Close any activations workspaces that we no longer require
//Note that activations workspaces can be closed only once the corresponding output activations have //Note that activations workspaces can be closed only once the corresponding output activations have
// been fully consumed // been fully consumed
if(closeAtEndIteraton[i] != null){ if (closeAtEndIteraton[i] != null) {
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){ for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
wsAct.close(); wsAct.close();
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct); LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
freeWorkspaceManagers.add(canNowReuse); freeWorkspaceManagers.add(canNowReuse);
} }
} }
} }
} catch (Throwable t2){
t = t2;
} finally { } finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
//Though if stopIndex < numLayers, some might still be open //Though if stopIndex < numLayers, some might still be open
@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Edge case here: seems that scoping out can increase the tagScope of the current WS //Edge case here: seems that scoping out can increase the tagScope of the current WS
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient //and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
// number of times to actually close it, in all cases // number of times to actually close it, in all cases
ws.close(); try{
ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
@ -2581,28 +2592,29 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
boolean traceLog = log.isTraceEnabled(); boolean traceLog = log.isTraceEnabled();
try{ Throwable t = null;
for(int i=topologicalOrder.length-1; i>= 0; i--){ try {
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
boolean hitFrozen = false; boolean hitFrozen = false;
GraphVertex current = vertices[topologicalOrder[i]]; GraphVertex current = vertices[topologicalOrder[i]];
int vIdx = current.getVertexIndex(); int vIdx = current.getVertexIndex();
String vertexName = current.getVertexName(); String vertexName = current.getVertexName();
if(traceLog){ if (traceLog) {
log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName()); log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
} }
//FIXME: make the frozen vertex feature extraction more flexible //FIXME: make the frozen vertex feature extraction more flexible
if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex){ if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex) {
hitFrozen = true; hitFrozen = true;
} }
if (current.isInputVertex() || hitFrozen){ if (current.isInputVertex() || hitFrozen) {
//Close any activation gradient workspaces that we no longer require //Close any activation gradient workspaces that we no longer require
//Note that activation gradient workspaces can be closed only once the corresponding activations //Note that activation gradient workspaces can be closed only once the corresponding activations
// gradients have been fully consumed // gradients have been fully consumed
if(closeAtEndIteraton[i] != null){ if (closeAtEndIteraton[i] != null) {
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){ for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
wsAct.close(); wsAct.close();
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct); LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
freeWorkspaceManagers.add(canNowReuse); freeWorkspaceManagers.add(canNowReuse);
@ -2680,7 +2692,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
wsActivationGrads.setPreviousWorkspace(initialWorkspace); wsActivationGrads.setPreviousWorkspace(initialWorkspace);
int closeableAt = vertexActGradsFullyConsumedByStep[vIdx]; int closeableAt = vertexActGradsFullyConsumedByStep[vIdx];
if(closeableAt >= 0) { if (closeableAt >= 0) {
if (closeAtEndIteraton[closeableAt] == null) { if (closeAtEndIteraton[closeableAt] == null) {
closeAtEndIteraton[closeableAt] = new ArrayList<>(); closeAtEndIteraton[closeableAt] = new ArrayList<>();
} }
@ -2689,14 +2701,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
Pair<Gradient, INDArray[]> pair; Pair<Gradient, INDArray[]> pair;
INDArray[] epsilons; INDArray[] epsilons;
try(MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){ try (MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
pair = current.doBackward(truncatedBPTT, workspaceMgr); pair = current.doBackward(truncatedBPTT, workspaceMgr);
epsilons = pair.getSecond(); epsilons = pair.getSecond();
//Validate workspace location for the activation gradients: //Validate workspace location for the activation gradients:
//validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){ //validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){
for (INDArray epsilon : epsilons) { for (INDArray epsilon : epsilons) {
if(epsilon != null) { if (epsilon != null) {
//May be null for EmbeddingLayer, etc //May be null for EmbeddingLayer, etc
validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop"); validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop");
} }
@ -2732,15 +2744,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
tempList.addFirst(new Triple<>(newName, entry.getValue(), tempList.addFirst(new Triple<>(newName, entry.getValue(),
g.flatteningOrderForVariable(origName))); g.flatteningOrderForVariable(origName)));
} }
for (Triple<String, INDArray, Character> t : tempList) for (Triple<String, INDArray, Character> triple : tempList)
gradients.addFirst(t); gradients.addFirst(triple);
} }
//Close any activation gradient workspaces that we no longer require //Close any activation gradient workspaces that we no longer require
//Note that activation gradient workspaces can be closed only once the corresponding activations //Note that activation gradient workspaces can be closed only once the corresponding activations
// gradients have been fully consumed // gradients have been fully consumed
if(closeAtEndIteraton[i] != null){ if (closeAtEndIteraton[i] != null) {
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){ for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
wsAct.close(); wsAct.close();
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct); LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
freeWorkspaceManagers.add(canNowReuse); freeWorkspaceManagers.add(canNowReuse);
@ -2748,23 +2760,32 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
closeAtEndIteraton[i] = null; closeAtEndIteraton[i] = null;
} }
if(traceLog){ if (traceLog) {
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName()); log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
} }
} }
} catch (Throwable t2){
t = t2;
} finally { } finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){ for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
ws.close(); try{
ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
} }
//Now, add the gradients in the order we need them in for flattening (same as params order) //Now, add the gradients in the order we need them in for flattening (same as params order)
Gradient gradient = new DefaultGradient(flattenedGradients); Gradient gradient = new DefaultGradient(flattenedGradients);
for (Triple<String, INDArray, Character> t : gradients) { for (Triple<String, INDArray, Character> tr : gradients) {
gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird()); gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
} }
this.gradient = gradient; this.gradient = gradient;

View File

@ -1242,17 +1242,18 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
boolean traceLog = log.isTraceEnabled(); boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try { try {
for (int i = 0; i <= layerIndex; i++) { for (int i = 0; i <= layerIndex; i++) {
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd); LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
if(traceLog){ if (traceLog) {
log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
} }
//Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
//Hence: put inputs in working memory //Hence: put inputs in working memory
if(i == 0 && wsm != WorkspaceMode.NONE){ if (i == 0 && wsm != WorkspaceMode.NONE) {
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG); mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG);
} }
@ -1268,7 +1269,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
temp.setPreviousWorkspace(initialWorkspace); temp.setPreviousWorkspace(initialWorkspace);
if(i == 0 && input.isAttached()){ if (i == 0 && input.isAttached()) {
//Don't leverage out of async DataSetIterator workspaces //Don't leverage out of async DataSetIterator workspaces
mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId());
} }
@ -1279,8 +1280,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)"); validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)");
} }
if ( i == layerIndex ) { if (i == layerIndex) {
if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){ if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
//Place activations in user-specified workspace //Place activations in user-specified workspace
mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration()); mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration());
} else { } else {
@ -1289,15 +1290,15 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
} }
} }
if(fwdPassType == FwdPassType.STANDARD){ if (fwdPassType == FwdPassType.STANDARD) {
//Standard feed-forward case //Standard feed-forward case
input = layers[i].activate(input, train, mgr); input = layers[i].activate(input, train, mgr);
} else if(fwdPassType == FwdPassType.RNN_TIMESTEP){ } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
//rnnTimeStep case //rnnTimeStep case
if (layers[i] instanceof RecurrentLayer) { if (layers[i] instanceof RecurrentLayer) {
input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr); input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr);
} else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){ } else if (layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) {
RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()); RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying());
input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr); input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr);
} else if (layers[i] instanceof MultiLayerNetwork) { } else if (layers[i] instanceof MultiLayerNetwork) {
input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input)); input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input));
@ -1311,34 +1312,51 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
//Validation: Exception if invalid (bad layer implementation) //Validation: Exception if invalid (bad layer implementation)
validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)"); validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)");
if(wsActCloseNext != null){ if (wsActCloseNext != null) {
wsActCloseNext.close(); wsActCloseNext.close();
} }
wsActCloseNext = temp; wsActCloseNext = temp;
temp = null; temp = null;
} }
if(traceLog){ if (traceLog) {
log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
} }
//Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
//Hence: put inputs in working memory -> set back to default for next use of workspace mgr //Hence: put inputs in working memory -> set back to default for next use of workspace mgr
if(i == 0 && wsm != WorkspaceMode.NONE){ if (i == 0 && wsm != WorkspaceMode.NONE) {
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
} }
} }
} catch (Throwable t2){
t = t2;
} finally { } finally {
if(wsActCloseNext != null){ if(wsActCloseNext != null){
wsActCloseNext.close(); try {
wsActCloseNext.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
if(temp != null){ if(temp != null){
//Should only be non-null on exception //Should only be non-null on exception
while(temp.isScopeActive()){ while(temp.isScopeActive()){
//For safety, should never occur in theory: a single close() call may not be sufficient, if //For safety, should never occur in theory: a single close() call may not be sufficient, if
// workspace scope was borrowed and not properly closed when exception occurred // workspace scope was borrowed and not properly closed when exception occurred
temp.close(); try{
temp.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
} }
@ -1871,13 +1889,14 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
boolean traceLog = log.isTraceEnabled(); boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try { try {
for (int i = layers.length - 1; i >= 0; i--) { for (int i = layers.length - 1; i >= 0; i--) {
if (layers[i] instanceof FrozenLayer) { if (layers[i] instanceof FrozenLayer) {
break; break;
} }
if(traceLog){ if (traceLog) {
log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName()); log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName());
} }
@ -1897,7 +1916,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
//Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers //Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers
wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD); wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
try(MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){ try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
//Note that because we're opening activation workspaces not in a simple nested order, we'll manually //Note that because we're opening activation workspaces not in a simple nested order, we'll manually
// override the previous workspace setting. Otherwise, when we close these workspaces, the "current" // override the previous workspace setting. Otherwise, when we close these workspaces, the "current"
@ -1907,7 +1926,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
INDArray eps = (i == layers.length - 1 ? epsilon : currPair.getRight()); //eps is null for OutputLayer INDArray eps = (i == layers.length - 1 ? epsilon : currPair.getRight()); //eps is null for OutputLayer
if(!tbptt){ if (!tbptt) {
//Standard case //Standard case
currPair = layers[i].backpropGradient(eps, workspaceMgr); currPair = layers[i].backpropGradient(eps, workspaceMgr);
} else { } else {
@ -1920,7 +1939,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
} }
} }
if(currPair.getSecond() != null) { if (currPair.getSecond() != null) {
//Edge case: may be null for Embedding layer, for example //Edge case: may be null for Embedding layer, for example
validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
false, "Backprop"); false, "Backprop");
@ -1936,38 +1955,56 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
currPair = new Pair<>(currPair.getFirst(), currPair = new Pair<>(currPair.getFirst(),
this.layerWiseConfigurations.getInputPreProcess(i) this.layerWiseConfigurations.getInputPreProcess(i)
.backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr)); .backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr));
if (i > 0 && currPair.getSecond() != null){ if (i > 0 && currPair.getSecond() != null) {
validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
true, "Backprop"); true, "Backprop");
} }
} }
if(i == 0 ){ if (i == 0) {
if(returnInputActGrad && currPair.getSecond() != null){ if (returnInputActGrad && currPair.getSecond() != null) {
currPair.setSecond(currPair.getSecond().detach()); currPair.setSecond(currPair.getSecond().detach());
} else { } else {
currPair.setSecond(null); currPair.setSecond(null);
} }
} }
if(wsActGradCloseNext != null){ if (wsActGradCloseNext != null) {
wsActGradCloseNext.close(); wsActGradCloseNext.close();
} }
wsActGradCloseNext = wsActGradTemp; wsActGradCloseNext = wsActGradTemp;
wsActGradTemp = null; wsActGradTemp = null;
} }
if(traceLog){ if (traceLog) {
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName()); log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
} }
} }
} catch (Throwable thr ){
t = thr;
} finally { } finally {
if(wsActGradCloseNext != null){ if(wsActGradCloseNext != null){
wsActGradCloseNext.close(); try {
wsActGradCloseNext.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
if(wsActGradTemp != null){ if(wsActGradTemp != null) {
//Should only be non-null on exception //Should only be non-null on exception
wsActGradTemp.close(); try {
wsActGradTemp.close();
} catch (Throwable t2) {
if (t != null) {
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
} }

View File

@ -476,19 +476,36 @@ std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
std::vector<int8_t> NDArray::asByteVector() { std::vector<int8_t> NDArray::asByteVector() {
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
if (this->isView()) {
auto tmp = this->dup(this->ordering());
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); if (isS()) {
// string data type requires special treatment
syncToHost();
auto numWords = this->lengthOf();
auto offsetsBuffer = this->bufferAsT<Nd4jLong>();
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords);
auto dataLength = offsetsBuffer[numWords];
std::vector<int8_t> result(headerLength + dataLength);
delete tmp; memcpy(result.data(), getBuffer(), headerLength + dataLength);
return result;
} else {
// all other types are linear
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
if (this->isView()) {
auto tmp = this->dup(this->ordering());
syncToHost();
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
delete tmp;
} else {
syncToHost();
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
}
return result;
} }
else {
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
}
return result;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -1584,9 +1601,7 @@ std::string* NDArray::bufferAsT() const {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
T* NDArray::bufferAsT() const { T* NDArray::bufferAsT() const {
if (isS()) // FIXME: do we REALLY want sync here?
throw std::runtime_error("You can't use this method on String array");
syncToHost(); syncToHost();
return reinterpret_cast<T*>(getBuffer()); return reinterpret_cast<T*>(getBuffer());
@ -3202,20 +3217,39 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
} else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo())) } else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo()))
return false; return false;
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0 if (isS()) {
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
for (int e = 0; e < this->lengthOf(); e++) {
auto s1 = this->e<std::string>(e);
auto s2 = other->e<std::string>(e);
ExtraArguments extras({eps}); if (s1 != s2)
return false;
}
NDArray::prepareSpecialUse({&tmp}, {this, other}); return true;
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); } else {
NDArray::registerSpecialUse({&tmp}, {this, other}); // regular numeric types
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
synchronize("NDArray::equalsTo"); ExtraArguments extras({eps});
if (tmp.e<int>(0) > 0) NDArray::prepareSpecialUse({&tmp}, {this, other});
return false; NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
getSpecialBuffer(), getSpecialShapeInfo(),
extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(),
other->getShapeInfo(), other->getSpecialBuffer(),
other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(),
tmp.specialBuffer(), tmp.specialShapeInfo());
NDArray::registerSpecialUse({&tmp}, {this, other});
return true; synchronize("NDArray::equalsTo");
if (tmp.e<int>(0) > 0)
return false;
return true;
}
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -54,6 +54,7 @@
#include <graph/ExecutionResult.h> #include <graph/ExecutionResult.h>
#include <exceptions/graph_execution_exception.h> #include <exceptions/graph_execution_exception.h>
#include <exceptions/no_results_exception.h> #include <exceptions/no_results_exception.h>
#include <graph/FlatUtils.h>
namespace nd4j{ namespace nd4j{
namespace graph { namespace graph {
@ -575,15 +576,9 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
continue; continue;
NDArray* array = var->getNDArray(); auto array = var->getNDArray();
auto byteVector = array->asByteVector();
auto fBuffer = builder.CreateVector(byteVector); auto fArray = FlatUtils::toFlatArray(builder, *array);
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
auto fArray = CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array->dataType()), bo);
auto fName = builder.CreateString(*(var->getName())); auto fName = builder.CreateString(*(var->getName()));
auto id = CreateIntPair(builder, var->id(), var->index()); auto id = CreateIntPair(builder, var->id(), var->index());

View File

@ -866,9 +866,10 @@ void initializeFunctions(Nd4jPointer *functions) {
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
Nd4jPointer pointer; Nd4jPointer pointer;
// cudaHostAllocMapped |cudaHostAllocPortable // cudaHostAllocMapped |cudaHostAllocPortable
cudaError_t res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault); auto res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault);
if (res != 0) if (res != 0)
pointer = 0L; throw nd4j::cuda_exception::build("cudaHostAlloc(...) failed", res);
return pointer; return pointer;
} }
@ -884,7 +885,7 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
Nd4jPointer pointer; Nd4jPointer pointer;
auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize); auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize);
if (res != 0) if (res != 0)
pointer = 0L; throw nd4j::cuda_exception::build("cudaMalloc(...) failed", res);
return pointer; return pointer;
} }
@ -894,9 +895,9 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
* @param pointer pointer that'll be freed * @param pointer pointer that'll be freed
*/ */
int freeHost(Nd4jPointer pointer) { int freeHost(Nd4jPointer pointer) {
cudaError_t res = cudaFreeHost(reinterpret_cast<void *>(pointer)); auto res = cudaFreeHost(reinterpret_cast<void *>(pointer));
if (res != 0) if (res != 0)
pointer = 0L; throw nd4j::cuda_exception::build("cudaFreeHost(...) failed", res);
return 1L; return 1L;
} }
@ -907,9 +908,10 @@ int freeHost(Nd4jPointer pointer) {
* @param ptrToDeviceId pointer to deviceId. * @param ptrToDeviceId pointer to deviceId.
*/ */
int freeDevice(Nd4jPointer pointer, int deviceId) { int freeDevice(Nd4jPointer pointer, int deviceId) {
cudaError_t res = cudaFree(reinterpret_cast<void *>(pointer)); auto res = cudaFree(reinterpret_cast<void *>(pointer));
if (res != 0) if (res != 0)
pointer = 0L; throw nd4j::cuda_exception::build("cudaFree(...) failed", res);
return 1L; return 1L;
} }
@ -934,7 +936,7 @@ Nd4jPointer createStream() {
auto stream = new cudaStream_t(); auto stream = new cudaStream_t();
auto dZ = cudaStreamCreate(stream); auto dZ = cudaStreamCreate(stream);
if (dZ != 0) if (dZ != 0)
throw std::runtime_error("cudaStreamCreate(...) failed"); throw nd4j::cuda_exception::build("cudaStreamCreate(...) failed", dZ);
return stream; return stream;
} }
@ -944,23 +946,21 @@ Nd4jPointer createEvent() {
CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t)); CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t));
cudaError_t dZ = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(&nativeEvent), cudaEventDisableTiming); auto dZ = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(&nativeEvent), cudaEventDisableTiming);
checkCudaErrors(dZ);
if (dZ != 0) if (dZ != 0)
throw std::runtime_error("cudaEventCreateWithFlags(...) failed"); throw nd4j::cuda_exception::build("cudaEventCreateWithFlags(...) failed", dZ);
return nativeEvent; return nativeEvent;
} }
int registerEvent(Nd4jPointer event, Nd4jPointer stream) { int registerEvent(Nd4jPointer event, Nd4jPointer stream) {
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event); auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(stream); auto pStream = reinterpret_cast<cudaStream_t *>(stream);
cudaError_t dZ = cudaEventRecord(*pEvent, *pStream); auto dZ = cudaEventRecord(*pEvent, *pStream);
checkCudaErrors(dZ);
if (dZ != 0) if (dZ != 0)
throw std::runtime_error("cudaEventRecord(...) failed"); throw nd4j::cuda_exception::build("cudaEventRecord(...) failed", dZ);
return 1; return 1;
} }
@ -1065,53 +1065,48 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j
} }
int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
cudaError_t dZ = cudaMemset(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size)); auto dZ = cudaMemset(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size));
checkCudaErrors(dZ);
if (dZ != 0) if (dZ != 0)
throw std::runtime_error("cudaMemset(...) failed"); throw nd4j::cuda_exception::build("cudaMemset(...) failed", dZ);
return 1; return 1;
} }
int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved); auto pStream = reinterpret_cast<cudaStream_t *>(reserved);
cudaError_t dZ = cudaMemsetAsync(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size), *pStream); auto dZ = cudaMemsetAsync(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size), *pStream);
checkCudaErrors(dZ);
if (dZ != 0) if (dZ != 0)
throw std::runtime_error("cudaMemsetAsync(...) failed"); throw nd4j::cuda_exception::build("cudaMemsetAsync(...) failed", dZ);
return 1; return 1;
} }
int destroyEvent(Nd4jPointer event) { int destroyEvent(Nd4jPointer event) {
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event); auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
cudaError_t dZ = cudaEventDestroy(*pEvent); auto dZ = cudaEventDestroy(*pEvent);
checkCudaErrors(dZ);
if (dZ != 0) if (dZ != 0)
throw std::runtime_error("cudaEvenDestroy(...) failed"); throw nd4j::cuda_exception::build("cudaEvenDestroy(...) failed", dZ);
return 1; return 1;
} }
int streamSynchronize(Nd4jPointer stream) { int streamSynchronize(Nd4jPointer stream) {
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(stream); auto pStream = reinterpret_cast<cudaStream_t *>(stream);
cudaError_t dZ = cudaStreamSynchronize(*pStream); auto dZ = cudaStreamSynchronize(*pStream);
checkCudaErrors(dZ);
if (dZ != 0) if (dZ != 0)
throw std::runtime_error("cudaStreamSynchronize(...) failed"); throw nd4j::cuda_exception::build("cudaStreamSynchronize(...) failed", dZ);
return 1L; return 1L;
} }
int eventSynchronize(Nd4jPointer event) { int eventSynchronize(Nd4jPointer event) {
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event); auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
cudaError_t dZ = cudaEventSynchronize(*pEvent); auto dZ = cudaEventSynchronize(*pEvent);
checkCudaErrors(dZ);
if (dZ != 0) if (dZ != 0)
throw std::runtime_error("cudaEventSynchronize(...) failed"); throw nd4j::cuda_exception::build("cudaEventSynchronize(...) failed", dZ);
return 1L; return 1L;
} }
@ -2697,13 +2692,16 @@ int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opConte
auto result = op->execute(context); auto result = op->execute(context);
// FIXME: remove once CUDA backend is 100% ready auto res = cudaStreamSynchronize(*context->launchContext()->getCudaStream());
if (res != 0)
throw nd4j::cuda_exception::build("customOp execution failed", res);
for (auto v:context->fastpath_in()) { for (auto v:context->fastpath_in()) {
v->makeBothActual(); v->syncToDevice();
} }
for (auto v:context->fastpath_out()) { for (auto v:context->fastpath_out()) {
v->makeBothActual(); v->syncToDevice();
} }
return result; return result;

View File

@ -36,6 +36,8 @@ namespace nd4j {
static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair); static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair);
static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray); static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray);
static flatbuffers::Offset<FlatArray> toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array);
}; };
} }
} }

View File

@ -102,5 +102,16 @@ namespace nd4j {
delete[] newShape; delete[] newShape;
return array; return array;
} }
flatbuffers::Offset<FlatArray> FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) {
auto byteVector = array.asByteVector();
auto fBuffer = builder.CreateVector(byteVector);
auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector());
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
}
} }
} }

View File

@ -26,7 +26,6 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
template <typename T> template <typename T>
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
@ -108,14 +107,14 @@ namespace helpers {
template <typename T> template <typename T>
static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) { static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
const int rowNum = input->rows(); const int rowNum = input->rows();
const int columnNum = input->columns(); const int columnNum = input->columns();
NDArray determinant = NDArrayFactory::create<T>(1.f); NDArray determinant = NDArrayFactory::create<T>(1.f);
NDArray compoundMatrix = *input; // copy NDArray compoundMatrix = *input; // copy
NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
permutationMatrix.setIdentity(); permutationMatrix.setIdentity();
T pivotValue; // = T(0.0); T pivotValue; // = T(0.0);
@ -161,46 +160,43 @@ namespace helpers {
return determinant; return determinant;
} }
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
template <typename T> template <typename T>
static int determinant_(NDArray* input, NDArray* output) { static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace()); auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
matrix.p(row, input->e<T>(k)); matrix.p(row, input->e<T>(k));
output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr)); output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
} }
return Status::OK(); return Status::OK();
} }
BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
defaultContext = context; BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
} }
template <typename T> template <typename T>
int logAbsDeterminant_(NDArray* input, NDArray* output) { int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) {
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace()); NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
matrix.p(row, input->e<T>(k)); matrix.p(row, input->e<T>(k));
} }
NDArray det = lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr); NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
if (det.e<T>(0) != 0.f) if (det.e<T>(0) != 0.f)
output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0)))); output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
} }
@ -208,25 +204,23 @@ template <typename T>
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
} }
template <typename T> template <typename T>
static int inverse_(NDArray* input, NDArray* output) { static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) {
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
auto n2 = n * n; auto n2 = n * n;
auto totalCount = output->lengthOf() / n2; auto totalCount = output->lengthOf() / n2;
output->assign(0.f); // fill up output tensor with zeros output->assign(0.f); // fill up output tensor with zeros
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace()); auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace()); auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
for (int e = 0; e < totalCount; e++) { for (int e = 0; e < totalCount; e++) {
if (e) if (e)
@ -235,7 +229,7 @@ template <typename T>
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
matrix.p(row++, input->e<T>(k)); matrix.p(row++, input->e<T>(k));
} }
T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0); T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0);
// FIXME: and how this is going to work on float16? // FIXME: and how this is going to work on float16?
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) { if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
@ -268,8 +262,7 @@ template <typename T>
} }
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
defaultContext = context; BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
} }
template <typename T> template <typename T>
@ -296,14 +289,13 @@ template <typename T>
return true; return true;
} }
BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES);
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) { bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES);
} }
template <typename T> template <typename T>
int cholesky_(NDArray* input, NDArray* output, bool inplace) { int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) {
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
auto n2 = n * n; auto n2 = n * n;
@ -311,8 +303,8 @@ template <typename T>
if (!inplace) if (!inplace)
output->assign(0.f); // fill up output tensor with zeros only inplace=false output->assign(0.f); // fill up output tensor with zeros only inplace=false
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace()); std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace());
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext)); std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context));
for (int e = 0; e < totalCount; e++) { for (int e = 0; e < totalCount; e++) {
@ -346,14 +338,13 @@ template <typename T>
} }
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
defaultContext = context; BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
} }
template <typename T> template <typename T>
int logdetFunctor_(NDArray* input, NDArray* output) { int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) {
std::unique_ptr<NDArray> tempOutput(input->dup()); std::unique_ptr<NDArray> tempOutput(input->dup());
int res = cholesky_<T>(input, tempOutput.get(), false); int res = cholesky_<T>(context, input, tempOutput.get(), false);
if (res != ND4J_STATUS_OK) if (res != ND4J_STATUS_OK)
return res; return res;
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
@ -372,7 +363,7 @@ template <typename T>
} }
int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
} }
} }

View File

@ -907,6 +907,8 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
/*** max ***/ /*** max ***/
case 0: { case 0: {
coord2 = hstart;
coord3 = hend;
T max = -DataTypeUtils::max<T>(); T max = -DataTypeUtils::max<T>();
for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {

View File

@ -31,8 +31,6 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
// template <typename T> // template <typename T>
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { // static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
// if (theFirst != theSecond) { // if (theFirst != theSecond) {
@ -198,36 +196,33 @@ namespace helpers {
} }
template<typename T> template<typename T>
static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) { static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
int n = inputMatrix->rows(); int n = inputMatrix->rows();
invertedMatrix->setIdentity(); invertedMatrix->setIdentity();
if (inputMatrix->isIdentityMatrix()) return; if (inputMatrix->isIdentityMatrix()) return;
auto stream = defaultContext->getCudaStream(); auto stream = context->getCudaStream();
// invert main diagonal // invert main diagonal
upvertKernel<T> << < 1, n, 512, *stream >> > upvertKernel<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invert the second diagonal // invert the second diagonal
invertKernelLow<T> << < 1, n, 512, *stream >> > invertKernelLow<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); // invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertLowKernel<T><<< n, n, 512, *stream >> > invertLowKernel<T><<<n, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
} }
void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix}); NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
} }
template<typename T> template<typename T>
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) {
int n = inputMatrix->rows(); int n = inputMatrix->rows();
invertedMatrix->setIdentity(); invertedMatrix->setIdentity();
auto stream = defaultContext->getCudaStream(); auto stream = context->getCudaStream();
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
return; return;
} }
@ -237,13 +232,12 @@ namespace helpers {
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertedMatrix->tickWriteDevice(); invertedMatrix->tickWriteDevice();
invertedMatrix->printIndexedBuffer("Step1 UP inversion"); invertedMatrix->printIndexedBuffer("Step1 UP inversion");
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
} }
void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
} }
@ -392,7 +386,6 @@ namespace helpers {
auto n = input->rows(); auto n = input->rows();
cusolverDnHandle_t cusolverH = nullptr; cusolverDnHandle_t cusolverH = nullptr;
cusolverStatus_t status = cusolverDnCreate(&cusolverH); cusolverStatus_t status = cusolverDnCreate(&cusolverH);
defaultContext = context;
if (CUSOLVER_STATUS_SUCCESS != status) { if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("Cannot create cuSolver handle", status); throw cuda_exception::build("Cannot create cuSolver handle", status);
} }
@ -528,24 +521,19 @@ namespace helpers {
input->tickWriteDevice(); input->tickWriteDevice();
} }
BUILD_SINGLE_TEMPLATE(template void lup_, BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE);
(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation),
FLOAT_NATIVE);
template<typename T> template<typename T>
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
std::vector<int> dims(); std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
{input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
// DataType dtype = input->dataType(); // DataType dtype = input->dataType();
// if (dtype != DataType::DOUBLE) // if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32; // dtype = DataType::FLOAT32;
defaultContext = context; auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
defaultContext); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
@ -554,8 +542,7 @@ namespace helpers {
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2; Nd4jLong pos = e * n2;
// if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> > fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else // else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@ -578,7 +565,6 @@ namespace helpers {
} }
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
@ -586,19 +572,16 @@ namespace helpers {
template<typename T> template<typename T>
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) { int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
std::vector<int> dims(); std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
{input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
DataType dtype = input->dataType(); DataType dtype = input->dataType();
if (dtype != DataType::DOUBLE) if (dtype != DataType::DOUBLE)
dtype = DataType::FLOAT32; dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
defaultContext); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1); auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
@ -607,8 +590,7 @@ namespace helpers {
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2; Nd4jLong pos = e * n2;
// if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> > fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else // else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
@ -620,8 +602,7 @@ namespace helpers {
auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer()); auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset; auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
// if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> > determinantLogKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n);
(inputBuf, outputBuf, n);
// else // else
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); // determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
} }
@ -633,7 +614,6 @@ namespace helpers {
} }
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
@ -696,17 +676,16 @@ namespace helpers {
template<typename T> template<typename T>
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
auto n2 = n * n; auto n2 = n * n;
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType(); auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
// if (dtype != DataType::DOUBLE) // if (dtype != DataType::DOUBLE)
// dtype = DataType::FLOAT32; // dtype = DataType::FLOAT32;
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context);
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context);
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
{input->rankOf() - 2, {input->rankOf() - 2,
input->rankOf() - 1}); input->rankOf() - 1});
@ -716,20 +695,17 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
for (auto i = 0LL; i < packX.numberOfTads(); i++) { for (auto i = 0LL; i < packX.numberOfTads(); i++) {
fillMatrix<T, T> << < 1, n2, 1024, *stream >> > fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(),
i * n2, n);
matrix.tickWriteDevice(); matrix.tickWriteDevice();
compound.assign(matrix); compound.assign(matrix);
lup_<T>(context, &compound, nullptr, nullptr); lup_<T>(context, &compound, nullptr, nullptr);
fillLowerUpperKernel<T> << < n, n, 1024, *stream >> > fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
matrix.assign(0); matrix.assign(0);
invertUpperMatrix(&upper, &matrix); // U^{-1} invertUpperMatrix(context, &upper, &matrix); // U^{-1}
matrix.tickWriteDevice(); matrix.tickWriteDevice();
// matrix.printIndexedBuffer("Upper Inverted"); // matrix.printIndexedBuffer("Upper Inverted");
compound.assign(0); compound.assign(0);
invertLowerMatrix(&lower, &compound); // L{-1} invertLowerMatrix(context, &lower, &compound); // L{-1}
compound.tickWriteDevice(); compound.tickWriteDevice();
// compound.printIndexedBuffer("Lower Inverted"); // compound.printIndexedBuffer("Lower Inverted");
// matrix.tickWriteDevice(); // matrix.tickWriteDevice();
@ -737,15 +713,12 @@ namespace helpers {
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
upper.tickWriteDevice(); upper.tickWriteDevice();
// upper.printIndexedBuffer("Full inverted"); // upper.printIndexedBuffer("Full inverted");
returnMatrix<T> << < 1, n2, 1024, *stream >> > returnMatrix<T> <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(),
i * n2, n);
} }
return Status::OK(); return Status::OK();
} }
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
@ -788,7 +761,6 @@ namespace helpers {
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
if (!inplace) if (!inplace)
output->assign(input); output->assign(input);
defaultContext = context;
std::unique_ptr<NDArray> tempOutput(output->dup()); std::unique_ptr<NDArray> tempOutput(output->dup());
cusolverDnHandle_t handle = nullptr; cusolverDnHandle_t handle = nullptr;
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
@ -868,7 +840,6 @@ namespace helpers {
// template <typename T> // template <typename T>
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
if (input->dataType() == DataType::DOUBLE) if (input->dataType() == DataType::DOUBLE)
cholesky__<double>(context, input, output, inplace); cholesky__<double>(context, input, output, inplace);
@ -876,8 +847,7 @@ namespace helpers {
cholesky__<float>(context, input, output, inplace); cholesky__<float>(context, input, output, inplace);
else { else {
std::unique_ptr<NDArray> tempOutput( std::unique_ptr<NDArray> tempOutput(
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context));
defaultContext));
tempOutput->assign(input); tempOutput->assign(input);
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true); cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
output->assign(tempOutput.get()); output->assign(tempOutput.get());
@ -888,7 +858,6 @@ namespace helpers {
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); // BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
defaultContext = context;
return cholesky_(context, input, output, inplace); return cholesky_(context, input, output, inplace);
} }
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); // BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
@ -927,7 +896,6 @@ namespace helpers {
template<typename T> template<typename T>
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
auto n2 = input->sizeAt(-1) * input->sizeAt(-2); auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
@ -957,7 +925,6 @@ namespace helpers {
} }
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
defaultContext = context;
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
} }

View File

@ -0,0 +1,100 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <NDArray.h>
#include <NDArrayFactory.h>
#include "testlayers.h"
#include <graph/Stash.h>
#include <FlatUtils.h>
using namespace nd4j;
class FlatUtilsTests : public testing::Test {
public:
};
TEST_F(FlatUtilsTests, flat_float_serde_1) {
auto array = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}
TEST_F(FlatUtilsTests, flat_int_serde_1) {
auto array = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}
TEST_F(FlatUtilsTests, flat_bool_serde_1) {
auto array = NDArrayFactory::create<bool>('c', {4}, {true, false, true, false});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}
TEST_F(FlatUtilsTests, flat_string_serde_1) {
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}

View File

@ -24,7 +24,6 @@
#include "testlayers.h" #include "testlayers.h"
#include <graph/Stash.h> #include <graph/Stash.h>
using namespace nd4j;
using namespace nd4j; using namespace nd4j;
class StringTests : public testing::Test { class StringTests : public testing::Test {

View File

@ -31,10 +31,35 @@
<build> <build>
<plugins> <plugins>
<!-- AB 2019/08/24 This plugin is to be added TEMPORARILY due to a change in the filenames of the generated ONNX -->
<!-- Normal "mvn clean" etc won't delete these files, and any users who have built ND4J even once before the
change will run into a compilation error. This can be removed after a few weeks.-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-antrun-plugin</artifactId>
<version>1.8</version>
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>run</goal>
</goals>
<configuration>
<target>
<delete file="${project.build.sourceDirectory}/onnx/OnnxMlProto3.java" />
<delete file="${project.build.sourceDirectory}/onnx/OnnxOperatorsProto3.java" />
<delete file="${project.build.sourceDirectory}/onnx/OnnxProto3.java" />
</target>
</configuration>
</execution>
</executions>
</plugin>
<plugin> <plugin>
<groupId>com.github.os72</groupId> <groupId>com.github.os72</groupId>
<artifactId>protoc-jar-maven-plugin</artifactId> <artifactId>protoc-jar-maven-plugin</artifactId>
<version>3.5.1.1</version> <version>3.8.0</version>
<executions> <executions>
<execution> <execution>
<id>tensorflow</id> <id>tensorflow</id>
@ -43,30 +68,14 @@
<goal>run</goal> <goal>run</goal>
</goals> </goals>
<configuration> <configuration>
<type>java-shaded</type> <protocVersion>3.8.0</protocVersion>
<protocVersion>3.5.1</protocVersion> <extension>.proto</extension>
<includeDirectories> <includeDirectories>
<include>src/main/protobuf/tf</include> <include>src/main/protobuf/tf</include>
<include>src/main/protobuf/onnx</include>
</includeDirectories> </includeDirectories>
<inputDirectories> <inputDirectories>
<include>src/main/protobuf/tf/tensorflow</include> <include>src/main/protobuf/tf/tensorflow</include>
</inputDirectories>
<addSources>main</addSources>
<cleanOutputFolder>false</cleanOutputFolder>
<outputDirectory>src/main/java/</outputDirectory>
</configuration>
</execution>
<execution>
<id>onnx</id>
<phase>generate-sources</phase>
<goals>
<goal>run</goal>
</goals>
<configuration>
<type>java-shaded</type>
<extension>.proto3</extension>
<protocVersion>3.5.1</protocVersion>
<inputDirectories>
<include>src/main/protobuf/onnx</include> <include>src/main/protobuf/onnx</include>
</inputDirectories> </inputDirectories>
<addSources>main</addSources> <addSources>main</addSources>
@ -76,6 +85,32 @@
</execution> </execution>
</executions> </executions>
</plugin> </plugin>
<plugin>
<groupId>com.google.code.maven-replacer-plugin</groupId>
<artifactId>replacer</artifactId>
<version>1.5.3</version>
<configuration>
<includes>
<include>${project.build.sourceDirectory}/org/tensorflow/**</include>
<include>${project.build.sourceDirectory}/tensorflow/**</include>
<include>${project.build.sourceDirectory}/onnx/**</include>
</includes>
<token>com.google.protobuf.</token>
<value>org.nd4j.shade.protobuf.</value>
</configuration>
<executions>
<execution>
<id>replace-imports</id>
<phase>generate-sources</phase>
<goals>
<goal>replace</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin> <plugin>
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId> <artifactId>maven-compiler-plugin</artifactId>
@ -148,20 +183,15 @@
<version>${flatbuffers.version}</version> <version>${flatbuffers.version}</version>
</dependency> </dependency>
<!-- Note that this is shaded flatbuffers, see the protoc declaration above <!-- Note that this is shaded protobuf. We use this instead of google's version mainly due ot other systems packaging
mentioning java-shaded as the type for why we use this instead of google's (mainly due ot other systems packaging their own older (incompatible) protobuf versions-->
their own older protobuf versions-->
<dependency> <dependency>
<groupId>com.github.os72</groupId> <groupId>org.nd4j</groupId>
<artifactId>protobuf-java-shaded-351</artifactId> <artifactId>protobuf</artifactId>
<version>0.9</version> <version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.github.os72</groupId>
<artifactId>protobuf-java-util-shaded-351</artifactId>
<version>0.9</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.objenesis</groupId> <groupId>org.objenesis</groupId>
<artifactId>objenesis</artifactId> <artifactId>objenesis</artifactId>

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
@ -101,10 +101,10 @@ public abstract class DifferentialFunction {
/** /**
* Initialize the function from the given * Initialize the function from the given
* {@link onnx.OnnxProto3.NodeProto} * {@link onnx.Onnx.NodeProto}
* @param node * @param node
*/ */
public DifferentialFunction(SameDiff sameDiff,onnx.OnnxProto3.NodeProto node,Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public DifferentialFunction(SameDiff sameDiff,onnx.Onnx.NodeProto node,Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
this.sameDiff = sameDiff; this.sameDiff = sameDiff;
setInstanceId(); setInstanceId();
initFromOnnx(node, sameDiff, attributesForNode, graph); initFromOnnx(node, sameDiff, attributesForNode, graph);
@ -731,13 +731,13 @@ public abstract class DifferentialFunction {
/** /**
* Iniitialize the function from the given * Iniitialize the function from the given
* {@link onnx.OnnxProto3.NodeProto} * {@link onnx.Onnx.NodeProto}
* @param node * @param node
* @param initWith * @param initWith
* @param attributesForNode * @param attributesForNode
* @param graph * @param graph
*/ */
public abstract void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph); public abstract void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph);

View File

@ -19,7 +19,7 @@ package org.nd4j.autodiff.samediff;
import java.util.Objects; import java.util.Objects;
import lombok.*; import lombok.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;

View File

@ -16,7 +16,7 @@
package org.nd4j.imports.descriptors.tensorflow; package org.nd4j.imports.descriptors.tensorflow;
import com.github.os72.protobuf351.TextFormat; import org.nd4j.shade.protobuf.TextFormat;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
import org.tensorflow.framework.OpDef; import org.tensorflow.framework.OpDef;

View File

@ -16,8 +16,8 @@
package org.nd4j.imports.graphmapper; package org.nd4j.imports.graphmapper;
import com.github.os72.protobuf351.Message; import org.nd4j.shade.protobuf.Message;
import com.github.os72.protobuf351.TextFormat; import org.nd4j.shade.protobuf.TextFormat;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;

View File

@ -16,7 +16,7 @@
package org.nd4j.imports.graphmapper; package org.nd4j.imports.graphmapper;
import com.github.os72.protobuf351.Message; import org.nd4j.shade.protobuf.Message;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.PropertyMapping;

View File

@ -16,13 +16,13 @@
package org.nd4j.imports.graphmapper.onnx; package org.nd4j.imports.graphmapper.onnx;
import com.github.os72.protobuf351.ByteString; import org.nd4j.shade.protobuf.ByteString;
import com.github.os72.protobuf351.Message; import org.nd4j.shade.protobuf.Message;
import com.google.common.primitives.Floats; import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints; import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs; import com.google.common.primitives.Longs;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -52,7 +52,7 @@ import java.util.*;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto, onnx.OnnxProto3.TypeProto.Tensor> { public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto, onnx.Onnx.TypeProto.Tensor> {
private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper(); private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();
@ -64,9 +64,9 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override @Override
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) { public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
try { try {
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(inputFile); Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile);
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) { for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
bufferedWriter.write(node.toString() + "\n"); bufferedWriter.write(node.toString() + "\n");
} }
@ -88,7 +88,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
* @param node * @param node
* @param graph * @param graph
*/ */
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) { public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) {
val properties = on.mappingsForFunction(); val properties = on.mappingsForFunction();
val tfProperties = properties.get(mappedTfName); val tfProperties = properties.get(mappedTfName);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
@ -170,18 +170,18 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
} }
@Override @Override
public boolean isOpIgnoreException(OnnxProto3.NodeProto node) { public boolean isOpIgnoreException(Onnx.NodeProto node) {
return false; return false;
} }
@Override @Override
public String getTargetMappingForOp(DifferentialFunction function, OnnxProto3.NodeProto node) { public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) {
return function.opName(); return function.opName();
} }
@Override @Override
public void mapProperty(String name, DifferentialFunction on, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) { public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node)); val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
/** /**
@ -263,7 +263,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override @Override
public OnnxProto3.NodeProto getNodeWithNameFromGraph(OnnxProto3.GraphProto graph, String name) { public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) {
for(int i = 0; i < graph.getNodeCount(); i++) { for(int i = 0; i < graph.getNodeCount(); i++) {
val node = graph.getNode(i); val node = graph.getNode(i);
if(node.getName().equals(name)) if(node.getName().equals(name))
@ -274,21 +274,21 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
} }
@Override @Override
public boolean isPlaceHolderNode(OnnxProto3.TypeProto.Tensor node) { public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) {
return false; return false;
} }
@Override @Override
public List<String> getControlDependencies(OnnxProto3.NodeProto node) { public List<String> getControlDependencies(Onnx.NodeProto node) {
throw new UnsupportedOperationException("Not yet implemented"); throw new UnsupportedOperationException("Not yet implemented");
} }
@Override @Override
public void dumpBinaryProtoAsText(File inputFile, File outputFile) { public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
try { try {
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile))); Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) { for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
bufferedWriter.write(node.toString()); bufferedWriter.write(node.toString());
} }
@ -316,12 +316,12 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override @Override
public Map<String,onnx.OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) { public Map<String,onnx.Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto) {
/** /**
* Need to figure out why * Need to figure out why
* gpu_0/conv1_1 isn't present in VGG * gpu_0/conv1_1 isn't present in VGG
*/ */
Map<String,onnx.OnnxProto3.TypeProto.Tensor> ret = new HashMap<>(); Map<String,onnx.Onnx.TypeProto.Tensor> ret = new HashMap<>();
for(int i = 0; i < graphProto.getInputCount(); i++) { for(int i = 0; i < graphProto.getInputCount(); i++) {
ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType()); ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
} }
@ -356,19 +356,19 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
} }
@Override @Override
public String translateToSameDiffName(String name, OnnxProto3.NodeProto node) { public String translateToSameDiffName(String name, Onnx.NodeProto node) {
return null; return null;
} }
protected void addDummyTensor(String name, Map<String, OnnxProto3.TypeProto.Tensor> to) { protected void addDummyTensor(String name, Map<String, Onnx.TypeProto.Tensor> to) {
OnnxProto3.TensorShapeProto.Dimension dim = OnnxProto3.TensorShapeProto.Dimension. Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension.
newBuilder() newBuilder()
.setDimValue(-1) .setDimValue(-1)
.build(); .build();
OnnxProto3.TypeProto.Tensor typeProto = OnnxProto3.TypeProto.Tensor.newBuilder() Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder()
.setShape( .setShape(
OnnxProto3.TensorShapeProto.newBuilder() Onnx.TensorShapeProto.newBuilder()
.addDim(dim) .addDim(dim)
.addDim(dim).build()) .addDim(dim).build())
.build(); .build();
@ -377,23 +377,23 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override @Override
public Message.Builder getNewGraphBuilder() { public Message.Builder getNewGraphBuilder() {
return OnnxProto3.GraphProto.newBuilder(); return Onnx.GraphProto.newBuilder();
} }
@Override @Override
public OnnxProto3.GraphProto parseGraphFrom(byte[] inputStream) throws IOException { public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph(); return Onnx.ModelProto.parseFrom(inputStream).getGraph();
} }
@Override @Override
public OnnxProto3.GraphProto parseGraphFrom(InputStream inputStream) throws IOException { public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph(); return Onnx.ModelProto.parseFrom(inputStream).getGraph();
} }
@Override @Override
public void mapNodeType(OnnxProto3.NodeProto tfNode, ImportState<OnnxProto3.GraphProto, OnnxProto3.TypeProto.Tensor> importState, public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto, Onnx.TypeProto.Tensor> importState,
OpImportOverride<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opImportOverride, OpImportOverride<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opImportOverride,
OpImportFilter<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opFilter) { OpImportFilter<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opFilter) {
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType()); val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
if(differentialFunction == null) { if(differentialFunction == null) {
throw new NoOpNameFoundException("No op name found " + tfNode.getOpType()); throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
@ -425,13 +425,13 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override @Override
public DataType dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensorProto, int outputNum) { public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) {
return nd4jTypeFromOnnxType(tensorProto.getElemType()); return nd4jTypeFromOnnxType(tensorProto.getElemType());
} }
@Override @Override
public boolean isStringType(OnnxProto3.TypeProto.Tensor tensor) { public boolean isStringType(Onnx.TypeProto.Tensor tensor) {
return tensor.getElemType() == OnnxProto3.TensorProto.DataType.STRING; return tensor.getElemType() == Onnx.TensorProto.DataType.STRING;
} }
@ -440,7 +440,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
* @param dataType the data type to convert * @param dataType the data type to convert
* @return the nd4j type for the onnx type * @return the nd4j type for the onnx type
*/ */
public DataType nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType) { public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) {
switch (dataType) { switch (dataType) {
case DOUBLE: return DataType.DOUBLE; case DOUBLE: return DataType.DOUBLE;
case FLOAT: return DataType.FLOAT; case FLOAT: return DataType.FLOAT;
@ -452,8 +452,8 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
} }
@Override @Override
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) { public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) {
for(OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) { for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) {
if(attributeProto.getName().equals(key)) { if(attributeProto.getName().equals(key)) {
return attributeProto.getS().toString(); return attributeProto.getS().toString();
} }
@ -463,29 +463,29 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
} }
@Override @Override
public long[] getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto) { public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) {
return Longs.toArray(attributeProto.getT().getDimsList()); return Longs.toArray(attributeProto.getT().getDimsList());
} }
@Override @Override
public boolean isPlaceHolder(OnnxProto3.TypeProto.Tensor nodeType) { public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) {
return false; return false;
} }
@Override @Override
public boolean isConstant(OnnxProto3.TypeProto.Tensor nodeType) { public boolean isConstant(Onnx.TypeProto.Tensor nodeType) {
return false; return false;
} }
@Override @Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) { public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) {
DataType type = dataTypeForTensor(tensorProto, 0); DataType type = dataTypeForTensor(tensorProto, 0);
if(!tensorProto.isInitialized()) { if(!tensorProto.isInitialized()) {
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized"); throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
} }
OnnxProto3.TensorProto tensor = null; Onnx.TensorProto tensor = null;
for(int i = 0; i < graph.getInitializerCount(); i++) { for(int i = 0; i < graph.getInitializerCount(); i++) {
val initializer = graph.getInitializer(i); val initializer = graph.getInitializer(i);
if(initializer.getName().equals(tensorName)) { if(initializer.getName().equals(tensorName)) {
@ -508,7 +508,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
return arr; return arr;
} }
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) { public INDArray mapTensorProto(Onnx.TensorProto tensor) {
if(tensor == null) if(tensor == null)
return null; return null;
@ -527,7 +527,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
} }
@Override @Override
public long[] getShapeFromTensor(onnx.OnnxProto3.TypeProto.Tensor tensorProto) { public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) {
val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())]; val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
int dimCount = tensorProto.getShape().getDimCount(); int dimCount = tensorProto.getShape().getDimCount();
if(dimCount >= 2) if(dimCount >= 2)
@ -548,11 +548,11 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
/** /**
* Get the shape from a tensor proto. * Get the shape from a tensor proto.
* Note that this is different from {@link #getShapeFromTensor(OnnxProto3.TensorProto)} * Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)}
* @param tensorProto the tensor to get the shape from * @param tensorProto the tensor to get the shape from
* @return * @return
*/ */
public long[] getShapeFromTensor(OnnxProto3.TensorProto tensorProto) { public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) {
val ret = new long[Math.max(2,tensorProto.getDimsCount())]; val ret = new long[Math.max(2,tensorProto.getDimsCount())];
int dimCount = tensorProto.getDimsCount(); int dimCount = tensorProto.getDimsCount();
if(dimCount >= 2) if(dimCount >= 2)
@ -577,74 +577,74 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override @Override
public String getInputFromNode(OnnxProto3.NodeProto node, int index) { public String getInputFromNode(Onnx.NodeProto node, int index) {
return node.getInput(index); return node.getInput(index);
} }
@Override @Override
public int numInputsFor(OnnxProto3.NodeProto nodeProto) { public int numInputsFor(Onnx.NodeProto nodeProto) {
return nodeProto.getInputCount(); return nodeProto.getInputCount();
} }
@Override @Override
public long[] getShapeFromAttr(OnnxProto3.AttributeProto attr) { public long[] getShapeFromAttr(Onnx.AttributeProto attr) {
return Longs.toArray(attr.getT().getDimsList()); return Longs.toArray(attr.getT().getDimsList());
} }
@Override @Override
public Map<String, OnnxProto3.AttributeProto> getAttrMap(OnnxProto3.NodeProto nodeProto) { public Map<String, Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto) {
Map<String,OnnxProto3.AttributeProto> proto = new HashMap<>(); Map<String,Onnx.AttributeProto> proto = new HashMap<>();
for(int i = 0; i < nodeProto.getAttributeCount(); i++) { for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
OnnxProto3.AttributeProto attributeProto = nodeProto.getAttribute(i); Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i);
proto.put(attributeProto.getName(),attributeProto); proto.put(attributeProto.getName(),attributeProto);
} }
return proto; return proto;
} }
@Override @Override
public String getName(OnnxProto3.NodeProto nodeProto) { public String getName(Onnx.NodeProto nodeProto) {
return nodeProto.getName(); return nodeProto.getName();
} }
@Override @Override
public boolean alreadySeen(OnnxProto3.NodeProto nodeProto) { public boolean alreadySeen(Onnx.NodeProto nodeProto) {
return false; return false;
} }
@Override @Override
public boolean isVariableNode(OnnxProto3.NodeProto nodeProto) { public boolean isVariableNode(Onnx.NodeProto nodeProto) {
return nodeProto.getOpType().contains("Var"); return nodeProto.getOpType().contains("Var");
} }
@Override @Override
public boolean shouldSkip(OnnxProto3.NodeProto opType) { public boolean shouldSkip(Onnx.NodeProto opType) {
return false; return false;
} }
@Override @Override
public boolean hasShape(OnnxProto3.NodeProto nodeProto) { public boolean hasShape(Onnx.NodeProto nodeProto) {
return false; return false;
} }
@Override @Override
public long[] getShape(OnnxProto3.NodeProto nodeProto) { public long[] getShape(Onnx.NodeProto nodeProto) {
return null; return null;
} }
@Override @Override
public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graph) { public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) {
return null; return null;
} }
@Override @Override
public String getOpType(OnnxProto3.NodeProto nodeProto) { public String getOpType(Onnx.NodeProto nodeProto) {
return nodeProto.getOpType(); return nodeProto.getOpType();
} }
@Override @Override
public List<OnnxProto3.NodeProto> getNodeList(OnnxProto3.GraphProto graphProto) { public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto) {
return graphProto.getNodeList(); return graphProto.getNodeList();
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.imports.graphmapper.tf; package org.nd4j.imports.graphmapper.tf;
import com.github.os72.protobuf351.Message; import org.nd4j.shade.protobuf.Message;
import com.google.common.primitives.Floats; import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints; import com.google.common.primitives.Ints;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;

View File

@ -1,6 +1,6 @@
package org.nd4j.imports.graphmapper.tf.tensors; package org.nd4j.imports.graphmapper.tf.tensors;
import com.github.os72.protobuf351.Descriptors; import org.nd4j.shade.protobuf.Descriptors;
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer; import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer; import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -205,7 +205,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -200,7 +200,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp {
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -20,7 +20,7 @@ import lombok.Data;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -134,7 +134,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }
@Override @Override

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
@ -218,7 +218,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
if (!attributesForNode.containsKey("axes")) { if (!attributesForNode.containsKey("axes")) {
this.dimensions = new int[] { Integer.MAX_VALUE }; this.dimensions = new int[] { Integer.MAX_VALUE };
} }

View File

@ -21,7 +21,7 @@ import com.google.common.primitives.Doubles;
import com.google.common.primitives.Longs; import com.google.common.primitives.Longs;
import lombok.*; import lombok.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -603,7 +603,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops; package org.nd4j.linalg.api.ops;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -61,7 +61,7 @@ public class NoOp extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.*; import lombok.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -367,7 +367,7 @@ public class If extends DifferentialFunction implements CustomOp {
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.*; import lombok.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -468,7 +468,7 @@ public class While extends DifferentialFunction implements CustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers; package org.nd4j.linalg.api.ops.impl.layers;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -122,7 +122,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers;
import lombok.Builder; import lombok.Builder;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -96,7 +96,7 @@ public class Linear extends BaseModule {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -260,7 +260,7 @@ public class AvgPooling2D extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8(); val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
val kernelShape = attributesForNode.get("kernel_shape").getIntsList(); val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
val padding = !attributesForNode.containsKey("pads") ? Arrays.asList(1L) : attributesForNode.get("pads").getIntsList(); val padding = !attributesForNode.containsKey("pads") ? Arrays.asList(1L) : attributesForNode.get("pads").getIntsList();

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -78,7 +78,7 @@ public class AvgPooling3D extends Pooling3D {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented"); throw new UnsupportedOperationException("Not yet implemented");
} }

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.SameDiffOp;
@ -139,7 +139,7 @@ public class BatchNorm extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs(); addArgs();
} }

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -127,7 +127,7 @@ public class Conv2D extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs(); addArgs();
} }

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -247,7 +247,7 @@ public class DeConv2D extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val autoPad = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8(); val autoPad = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
val dilations = attributesForNode.get("dilations"); val dilations = attributesForNode.get("dilations");
val dilationY = dilations == null ? 1 : dilations.getIntsList().get(0).intValue(); val dilationY = dilations == null ? 1 : dilations.getIntsList().get(0).intValue();

View File

@ -20,7 +20,7 @@ import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -151,7 +151,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs(); addArgs();
} }

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -115,7 +115,7 @@ public class LocalResponseNormalization extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val aAlpha = attributesForNode.get("alpha"); val aAlpha = attributesForNode.get("alpha");
val aBeta = attributesForNode.get("beta"); val aBeta = attributesForNode.get("beta");
val aBias = attributesForNode.get("bias"); val aBias = attributesForNode.get("bias");

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -221,7 +221,7 @@ public class MaxPooling2D extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8(); val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
val isSameNode = paddingVal.equals("SAME"); val isSameNode = paddingVal.equals("SAME");
val kernelShape = attributesForNode.get("kernel_shape").getIntsList(); val kernelShape = attributesForNode.get("kernel_shape").getIntsList();

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -78,7 +78,7 @@ public class MaxPooling3D extends Pooling3D {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented"); throw new UnsupportedOperationException("Not yet implemented");
} }

View File

@ -20,7 +20,7 @@ import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -183,7 +183,7 @@ public class Pooling2D extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val isSameNode = attributesForNode.get("auto_pad").getS().equals("SAME"); val isSameNode = attributesForNode.get("auto_pad").getS().equals("SAME");
val kernelShape = attributesForNode.get("kernel_shape").getIntsList(); val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
val padding = attributesForNode.get("pads").getIntsList(); val padding = attributesForNode.get("pads").getIntsList();

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
@ -73,7 +73,7 @@ public class LSTMCell extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -65,7 +65,7 @@ public class SRU extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent; package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -66,7 +66,7 @@ public class SRUCell extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -204,7 +204,7 @@ public class Mmul extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0; val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0; val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
MMulTranspose mMulTranspose = MMulTranspose.builder() MMulTranspose mMulTranspose = MMulTranspose.builder()

View File

@ -20,7 +20,7 @@ import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs; import com.google.common.primitives.Longs;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.blas.params.MMulTranspose;
@ -283,7 +283,7 @@ public class TensorMmul extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0; val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0; val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
MMulTranspose mMulTranspose = MMulTranspose.builder() MMulTranspose mMulTranspose = MMulTranspose.builder()

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -163,7 +163,7 @@ public class Concat extends DynamicCustomOp {
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -77,7 +77,7 @@ public class Diag extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -79,7 +79,7 @@ public class DiagPart extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.PropertyMapping;
@ -78,7 +78,7 @@ public class Gather extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -65,7 +65,7 @@ public class MergeAvg extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -64,7 +64,7 @@ public class MergeMax extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -66,7 +66,7 @@ public class MergeSum extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -68,7 +68,7 @@ public class ParallelStack extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("No analog found for onnx for " + opName()); throw new UnsupportedOperationException("No analog found for onnx for " + opName());
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -66,7 +66,7 @@ public class Rank extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -106,7 +106,7 @@ public class Repeat extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -126,7 +126,7 @@ public class Reshape extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val shape = new OnnxGraphMapper().getShape(node); val shape = new OnnxGraphMapper().getShape(node);
this.shape = shape; this.shape = shape;
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.val; import lombok.val;
import onnx.OnnxMlProto3; import onnx.OnnxMl;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
@ -87,7 +87,7 @@ public class Shape extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new NoOpNameFoundException("No onnx name found for shape " + opName()); throw new NoOpNameFoundException("No onnx name found for shape " + opName());
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -93,7 +93,7 @@ public class Stack extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("No analog found for onnx for " + opName()); throw new UnsupportedOperationException("No analog found for onnx for " + opName());
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import com.google.common.primitives.Ints; import com.google.common.primitives.Ints;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.VariableType;
@ -156,7 +156,7 @@ public class Transpose extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
if (!attributesForNode.containsKey("perm")) { if (!attributesForNode.containsKey("perm")) {
} else } else

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -127,7 +127,7 @@ public class Unstack extends DynamicCustomOp {
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("No analog found for onnx for " + opName()); throw new UnsupportedOperationException("No analog found for onnx for " + opName());
} }

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape.bp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -71,7 +71,7 @@ public class ConcatBp extends DynamicCustomOp {
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
//No op //No op
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops; package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
@ -59,7 +59,7 @@ public class TensorArrayConcat extends BaseTensorOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops; package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
@ -59,7 +59,7 @@ public class TensorArrayGather extends BaseTensorOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops; package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -54,7 +54,7 @@ public class TensorArrayRead extends BaseTensorOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }
@Override @Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops; package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -52,7 +52,7 @@ public class TensorArrayScatter extends BaseTensorOp {
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }
@Override @Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops; package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -58,7 +58,7 @@ public class TensorArraySize extends BaseTensorOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }
@Override @Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops; package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -52,7 +52,7 @@ public class TensorArraySplit extends BaseTensorOp {
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
} }
@Override @Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.clip; package org.nd4j.linalg.api.ops.impl.transforms.clip;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -64,7 +64,7 @@ public class ClipByNorm extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented"); throw new UnsupportedOperationException("Not yet implemented");
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.clip; package org.nd4j.linalg.api.ops.impl.transforms.clip;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -77,7 +77,7 @@ public class ClipByValue extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented"); throw new UnsupportedOperationException("Not yet implemented");
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom; package org.nd4j.linalg.api.ops.impl.transforms.custom;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -62,7 +62,7 @@ public class Assign extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom; package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -132,7 +132,7 @@ public class CumProd extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom; package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -133,7 +133,7 @@ public class CumSum extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom; package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
@ -80,7 +80,7 @@ public class Fill extends DynamicCustomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict; package org.nd4j.linalg.api.ops.impl.transforms.strict;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -81,7 +81,7 @@ public class RectifiedTanh extends BaseTransformStrictOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.random.impl; package org.nd4j.linalg.api.ops.random.impl;
import lombok.NonNull; import lombok.NonNull;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -75,7 +75,7 @@ public class DropOutInverted extends BaseRandomOp {
} }
@Override @Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph); super.initFromOnnx(node, initWith, attributesForNode, graph);
} }

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.random.impl; package org.nd4j.linalg.api.ops.random.impl;
import lombok.val; import lombok.val;
import onnx.OnnxProto3; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;

View File

@ -9,7 +9,7 @@
syntax = "proto3"; syntax = "proto3";
package onnx; package onnx;
import "onnx.proto3"; import "onnx.proto";
// //
// This file contains the proto definitions for OperatorSetProto and // This file contains the proto definitions for OperatorSetProto and

View File

@ -16,7 +16,7 @@
package org.nd4j.tensorflow.conversion; package org.nd4j.tensorflow.conversion;
import com.github.os72.protobuf351.util.JsonFormat; import org.nd4j.shade.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Rule; import org.junit.Rule;

View File

@ -16,7 +16,7 @@
package org.nd4j.tensorflow.conversion; package org.nd4j.tensorflow.conversion;
import com.github.os72.protobuf351.util.JsonFormat; import org.nd4j.shade.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;

View File

@ -732,4 +732,20 @@ public class CustomOpsTests extends BaseNd4jTest {
fail("Failed datatypes: " + failed.toString()); fail("Failed datatypes: " + failed.toString());
} }
} }
@Test
public void testMaxPool2Dbp_1() {
val x = Nd4j.create(DataType.HALF, 2,3,16,16).assign(Double.NaN);
val y = Nd4j.create(DataType.HALF, 2,3,8,8).assign(Double.NaN);
val z = Nd4j.create(DataType.HALF, 2,3,16,16);
val op = DynamicCustomOp.builder("maxpool2d_bp")
.addInputs(x, y)
.addOutputs(z)
.addIntegerArguments(2, 2, 2, 2, 8,8, 1,1,1, 0,0)
.build();
Nd4j.exec(op);
Nd4j.getExecutioner().commit();
}
} }

View File

@ -29,6 +29,7 @@
<packaging>pom</packaging> <packaging>pom</packaging>
<modules> <modules>
<module>jackson</module> <module>jackson</module>
<module>protobuf</module>
</modules> </modules>
<properties> <properties>

View File

@ -0,0 +1,228 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>nd4j-shade</artifactId>
<groupId>org.nd4j</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>protobuf</artifactId>
<properties>
<skipTestResourceEnforcement>true</skipTestResourceEnforcement>
</properties>
<dependencies>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>3.8.0</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
<version>3.8.0</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>custom-lifecycle</id>
<activation>
<property><name>!skip.custom.lifecycle</name></property>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.apache.portals.jetspeed-2</groupId>
<artifactId>jetspeed-mvn-maven-plugin</artifactId>
<version>2.3.1</version>
<executions>
<execution>
<id>compile-and-pack</id>
<phase>compile</phase>
<goals>
<goal>mvn</goal>
</goals>
</execution>
</executions>
<dependencies>
<dependency>
<groupId>org.apache.maven.shared</groupId>
<artifactId>maven-invoker</artifactId>
<version>2.2</version>
</dependency>
</dependencies>
<configuration>
<targets combine.children="merge">
<target>
<id>create-shaded-jars</id>
<dir>@rootdir@/nd4j/nd4j-shade/protobuf/</dir>
<goals>clean,compile,package</goals>
<properties>
<skip.custom.lifecycle>true</skip.custom.lifecycle>
</properties>
</target>
</targets>
<defaultTarget>create-shaded-jars</defaultTarget>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<build>
<plugins>
<!-- Disable Maven Lint plugin in this module. For some reason it chokes on this module (internal NPE) and we don't need it anyway here -->
<plugin>
<groupId>com.lewisd</groupId>
<artifactId>lint-maven-plugin</artifactId>
<version>0.0.11</version>
<executions>
<execution>
<id>pom-lint</id>
<phase>none</phase>
</execution>
</executions>
</plugin>
<!--
Use Maven Shade plugin to add a shaded version of the Protobuf dependencies, that can be imported by
including this module (org.nd4j.protobuf) as a dependency.
The standard com.google.protobuf dependencies will be provided, though are prefixed by org.nd4j.shade.protobuf
-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>${maven-shade-plugin.version}</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
</transformer>
</transformers>
</configuration>
</execution>
</executions>
<configuration>
<!--
Important configuration options here:
createDependencyReducedPom: remove the shaded artifacts from the module dependencies. Without this, the
original dependencies will be shaded, AND still included as transitive deps
in the final POM. This is not what we want.
shadedArtifactAttached: If true, the shaded artifact will be a separate JAR file for install, with
the original un-shaded JAR being separate. With this being set to false,
the original JAR will be modified, and no extra jar will be produced.
promoteTransitiveDependencies: This will promote the transitive dependencies of the shaded dependencies
to direct dependencies. Without this, we need to manually manage the transitive
dependencies of the shaded artifacts.
Note that using <optional>true</optional> in the dependencies also allows the deps to be shaded (and
original dependencies to not be included), but does NOT work with promoteTransitiveDependencies
-->
<shadedArtifactAttached>false</shadedArtifactAttached>
<createDependencyReducedPom>true</createDependencyReducedPom>
<promoteTransitiveDependencies>true</promoteTransitiveDependencies>
<artifactSet>
<includes>
<include>com.google.protobuf:*</include>
<include>com.google.protobuf.*:*</include>
</includes>
</artifactSet>
<relocations>
<!-- Protobuf dependencies -->
<relocation>
<pattern>com.google.protobuf</pattern>
<shadedPattern>org.nd4j.shade.protobuf</shadedPattern>
</relocation>
</relocations>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<forceCreation>true</forceCreation>
</configuration>
<executions>
<execution>
<id>empty-javadoc-jar</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<classifier>javadoc</classifier>
<classesDirectory>${basedir}/javadoc</classesDirectory>
</configuration>
</execution>
<execution>
<id>empty-sources-jar</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<classifier>sources</classifier>
<classesDirectory>${basedir}/src</classesDirectory>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<version>3.0.0</version>
<executions>
<execution>
<id>unpack</id>
<phase>package</phase>
<goals>
<goal>unpack</goal>
</goals>
<configuration>
<artifactItems>
<artifactItem>
<groupId>org.nd4j</groupId>
<artifactId>protobuf</artifactId>
<version>${project.version}</version>
<type>jar</type>
<overWrite>false</overWrite>
<outputDirectory>${project.build.directory}/classes/</outputDirectory>
<includes>**/*.class,**/*.xml</includes>
</artifactItem>
</artifactItems>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

View File

@ -16,7 +16,7 @@
package org.nd4j.tensorflow.conversion; package org.nd4j.tensorflow.conversion;
import com.github.os72.protobuf351.InvalidProtocolBufferException; import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*; import org.bytedeco.javacpp.indexer.*;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;

View File

@ -16,9 +16,9 @@
package org.nd4j.tensorflow.conversion.graphrunner; package org.nd4j.tensorflow.conversion.graphrunner;
import com.github.os72.protobuf351.ByteString; import org.nd4j.shade.protobuf.ByteString;
import com.github.os72.protobuf351.InvalidProtocolBufferException; import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import com.github.os72.protobuf351.util.JsonFormat; import org.nd4j.shade.protobuf.util.JsonFormat;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -638,7 +638,7 @@ public class GraphRunner implements Closeable {
/** /**
* Convert a json string written out * Convert a json string written out
* by {@link com.github.os72.protobuf351.util.JsonFormat} * by {@link org.nd4j.shade.protobuf.util.JsonFormat}
* to a {@link org.bytedeco.tensorflow.ConfigProto} * to a {@link org.bytedeco.tensorflow.ConfigProto}
* @param json the json to read * @param json the json to read
* @return the config proto to use * @return the config proto to use