[WIP] latest update (#8145)
* [WIP] maxpool2d_bp fix (#160) * one test for maxpool2d_bp Signed-off-by: raver119 <raver119@gmail.com> * - maxpool2d_bp cuda fix for NaNs - streamSync after each custom op execution Signed-off-by: raver119 <raver119@gmail.com> * MLN/CG: Don't swallow exceptions if a second exception occurs during workspace closing (#161) Signed-off-by: AlexDBlack <blacka101@gmail.com> * Upgrade protobuf version (#162) * First steps for protobuf version upgrade Signed-off-by: AlexDBlack <blacka101@gmail.com> * Phase 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update imports to shaded protobuf Signed-off-by: AlexDBlack <blacka101@gmail.com> * Version fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Switch to single execution for protobuf codegen to work around plugin bug Signed-off-by: AlexDBlack <blacka101@gmail.com> * Automatically delete old PB generated files after name change Signed-off-by: Alex Black <blacka101@gmail.com> * - string NDArray flat serde impl + tests (#163) - string NDArray equalsTo impl Signed-off-by: raver119 <raver119@gmail.com> * get rid of context variable Signed-off-by: raver119 <raver119@gmail.com> * lup context fix (#164) Signed-off-by: raver119 <raver119@gmail.com>master
parent
95b2686ce5
commit
d871eab2e5
|
@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
|
||||
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
|
||||
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
|
||||
Throwable t = null;
|
||||
try {
|
||||
for (int i = 0; i <= stopIndex; i++) {
|
||||
GraphVertex current = vertices[topologicalOrder[i]];
|
||||
|
@ -2302,14 +2303,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
|
||||
.build();
|
||||
|
||||
if(detachedInputs){
|
||||
if (detachedInputs) {
|
||||
//Sometimes (like: external errors use cases) we don't want the activations/inputs to be
|
||||
// in a workspace
|
||||
workspaceMgr.setScopedOutFor(ArrayType.INPUT);
|
||||
workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
|
||||
} else {
|
||||
//Don't leverage out of async MultiDataSetIterator workspaces
|
||||
if(features[0].isAttached()){
|
||||
if (features[0].isAttached()) {
|
||||
workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
|
||||
}
|
||||
}
|
||||
|
@ -2326,7 +2327,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
if (ArrayUtils.contains(layerIndexes, vIdx)) {
|
||||
isRequiredOutput = true;
|
||||
|
||||
if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){
|
||||
if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
|
||||
//Place activations in user-specified workspace
|
||||
origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
|
||||
origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
|
||||
|
@ -2345,7 +2346,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
//Open the relevant workspace for the activations.
|
||||
//Note that this will be closed only once the current vertex's activations have been consumed
|
||||
MemoryWorkspace wsActivations = null;
|
||||
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput ){ //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
|
||||
if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput) { //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
|
||||
wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
|
||||
openActivationsWorkspaces.put(wsActivations, workspaceMgr);
|
||||
}
|
||||
|
@ -2353,11 +2354,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
//Note that because we're opening activation workspaces not in any defined order (i.e., workspace
|
||||
// use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we
|
||||
// close these workspaces, the "current" workspace may be set to the incorrect one
|
||||
if(wsActivations != null )
|
||||
if (wsActivations != null)
|
||||
wsActivations.setPreviousWorkspace(initialWorkspace);
|
||||
|
||||
int closeableAt = vertexOutputsFullyConsumedByStep[vIdx];
|
||||
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
|
||||
if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
|
||||
if (closeAtEndIteraton[closeableAt] == null) {
|
||||
closeAtEndIteraton[closeableAt] = new ArrayList<>();
|
||||
}
|
||||
|
@ -2373,18 +2374,18 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
out = features[vIdx];
|
||||
} else {
|
||||
|
||||
if(fwdPassType == FwdPassType.STANDARD){
|
||||
if (fwdPassType == FwdPassType.STANDARD) {
|
||||
//Standard feed-forward case
|
||||
out = current.doForward(train, workspaceMgr);
|
||||
} else if(fwdPassType == FwdPassType.RNN_TIMESTEP){
|
||||
} else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
|
||||
if (current.hasLayer()) {
|
||||
//Layer
|
||||
INDArray input = current.getInputs()[0];
|
||||
Layer l = current.getLayer();
|
||||
if (l instanceof RecurrentLayer) {
|
||||
out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
|
||||
} else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer){
|
||||
RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying());
|
||||
} else if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) {
|
||||
RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying());
|
||||
out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
|
||||
} else if (l instanceof MultiLayerNetwork) {
|
||||
out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input));
|
||||
|
@ -2402,7 +2403,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
|
||||
}
|
||||
|
||||
if(inputsTo != null) { //Output vertices may not input to any other vertices
|
||||
if (inputsTo != null) { //Output vertices may not input to any other vertices
|
||||
for (VertexIndices v : inputsTo) {
|
||||
//Note that we don't have to do anything special here: the activations are always detached in
|
||||
// this method
|
||||
|
@ -2412,13 +2413,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
}
|
||||
}
|
||||
|
||||
if(clearLayerInputs) {
|
||||
if (clearLayerInputs) {
|
||||
current.clear();
|
||||
}
|
||||
|
||||
if(isRequiredOutput){
|
||||
if (isRequiredOutput) {
|
||||
outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out;
|
||||
if(origWSAct != null){
|
||||
if (origWSAct != null) {
|
||||
//Reset the configuration, as we may reuse this workspace manager...
|
||||
workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf);
|
||||
}
|
||||
|
@ -2428,14 +2429,16 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
//Close any activations workspaces that we no longer require
|
||||
//Note that activations workspaces can be closed only once the corresponding output activations have
|
||||
// been fully consumed
|
||||
if(closeAtEndIteraton[i] != null){
|
||||
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){
|
||||
if (closeAtEndIteraton[i] != null) {
|
||||
for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
|
||||
wsAct.close();
|
||||
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
|
||||
freeWorkspaceManagers.add(canNowReuse);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Throwable t2){
|
||||
t = t2;
|
||||
} finally {
|
||||
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
||||
//Though if stopIndex < numLayers, some might still be open
|
||||
|
@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
//Edge case here: seems that scoping out can increase the tagScope of the current WS
|
||||
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
|
||||
// number of times to actually close it, in all cases
|
||||
ws.close();
|
||||
try{
|
||||
ws.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||
|
@ -2581,28 +2592,29 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
|
||||
boolean traceLog = log.isTraceEnabled();
|
||||
|
||||
try{
|
||||
for(int i=topologicalOrder.length-1; i>= 0; i--){
|
||||
Throwable t = null;
|
||||
try {
|
||||
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
|
||||
boolean hitFrozen = false;
|
||||
GraphVertex current = vertices[topologicalOrder[i]];
|
||||
int vIdx = current.getVertexIndex();
|
||||
String vertexName = current.getVertexName();
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
|
||||
}
|
||||
|
||||
//FIXME: make the frozen vertex feature extraction more flexible
|
||||
if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex){
|
||||
if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex) {
|
||||
hitFrozen = true;
|
||||
}
|
||||
|
||||
if (current.isInputVertex() || hitFrozen){
|
||||
if (current.isInputVertex() || hitFrozen) {
|
||||
//Close any activation gradient workspaces that we no longer require
|
||||
//Note that activation gradient workspaces can be closed only once the corresponding activations
|
||||
// gradients have been fully consumed
|
||||
if(closeAtEndIteraton[i] != null){
|
||||
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){
|
||||
if (closeAtEndIteraton[i] != null) {
|
||||
for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
|
||||
wsAct.close();
|
||||
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
|
||||
freeWorkspaceManagers.add(canNowReuse);
|
||||
|
@ -2680,7 +2692,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
wsActivationGrads.setPreviousWorkspace(initialWorkspace);
|
||||
|
||||
int closeableAt = vertexActGradsFullyConsumedByStep[vIdx];
|
||||
if(closeableAt >= 0) {
|
||||
if (closeableAt >= 0) {
|
||||
if (closeAtEndIteraton[closeableAt] == null) {
|
||||
closeAtEndIteraton[closeableAt] = new ArrayList<>();
|
||||
}
|
||||
|
@ -2689,14 +2701,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
|
||||
Pair<Gradient, INDArray[]> pair;
|
||||
INDArray[] epsilons;
|
||||
try(MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){
|
||||
try (MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
|
||||
pair = current.doBackward(truncatedBPTT, workspaceMgr);
|
||||
epsilons = pair.getSecond();
|
||||
|
||||
//Validate workspace location for the activation gradients:
|
||||
//validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){
|
||||
for (INDArray epsilon : epsilons) {
|
||||
if(epsilon != null) {
|
||||
if (epsilon != null) {
|
||||
//May be null for EmbeddingLayer, etc
|
||||
validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop");
|
||||
}
|
||||
|
@ -2732,15 +2744,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
tempList.addFirst(new Triple<>(newName, entry.getValue(),
|
||||
g.flatteningOrderForVariable(origName)));
|
||||
}
|
||||
for (Triple<String, INDArray, Character> t : tempList)
|
||||
gradients.addFirst(t);
|
||||
for (Triple<String, INDArray, Character> triple : tempList)
|
||||
gradients.addFirst(triple);
|
||||
}
|
||||
|
||||
//Close any activation gradient workspaces that we no longer require
|
||||
//Note that activation gradient workspaces can be closed only once the corresponding activations
|
||||
// gradients have been fully consumed
|
||||
if(closeAtEndIteraton[i] != null){
|
||||
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){
|
||||
if (closeAtEndIteraton[i] != null) {
|
||||
for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
|
||||
wsAct.close();
|
||||
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
|
||||
freeWorkspaceManagers.add(canNowReuse);
|
||||
|
@ -2748,23 +2760,32 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
closeAtEndIteraton[i] = null;
|
||||
}
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Throwable t2){
|
||||
t = t2;
|
||||
} finally {
|
||||
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
||||
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
|
||||
ws.close();
|
||||
try{
|
||||
ws.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||
}
|
||||
|
||||
//Now, add the gradients in the order we need them in for flattening (same as params order)
|
||||
Gradient gradient = new DefaultGradient(flattenedGradients);
|
||||
for (Triple<String, INDArray, Character> t : gradients) {
|
||||
gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
|
||||
for (Triple<String, INDArray, Character> tr : gradients) {
|
||||
gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
|
||||
}
|
||||
|
||||
this.gradient = gradient;
|
||||
|
|
|
@ -1242,17 +1242,18 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
|
||||
boolean traceLog = log.isTraceEnabled();
|
||||
|
||||
Throwable t = null;
|
||||
try {
|
||||
for (int i = 0; i <= layerIndex; i++) {
|
||||
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||
}
|
||||
|
||||
//Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
|
||||
//Hence: put inputs in working memory
|
||||
if(i == 0 && wsm != WorkspaceMode.NONE){
|
||||
if (i == 0 && wsm != WorkspaceMode.NONE) {
|
||||
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG);
|
||||
}
|
||||
|
||||
|
@ -1268,7 +1269,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
temp.setPreviousWorkspace(initialWorkspace);
|
||||
|
||||
|
||||
if(i == 0 && input.isAttached()){
|
||||
if (i == 0 && input.isAttached()) {
|
||||
//Don't leverage out of async DataSetIterator workspaces
|
||||
mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId());
|
||||
}
|
||||
|
@ -1279,8 +1280,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)");
|
||||
}
|
||||
|
||||
if ( i == layerIndex ) {
|
||||
if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){
|
||||
if (i == layerIndex) {
|
||||
if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
|
||||
//Place activations in user-specified workspace
|
||||
mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration());
|
||||
} else {
|
||||
|
@ -1289,15 +1290,15 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
}
|
||||
}
|
||||
|
||||
if(fwdPassType == FwdPassType.STANDARD){
|
||||
if (fwdPassType == FwdPassType.STANDARD) {
|
||||
//Standard feed-forward case
|
||||
input = layers[i].activate(input, train, mgr);
|
||||
} else if(fwdPassType == FwdPassType.RNN_TIMESTEP){
|
||||
} else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
|
||||
//rnnTimeStep case
|
||||
if (layers[i] instanceof RecurrentLayer) {
|
||||
input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr);
|
||||
} else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){
|
||||
RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying());
|
||||
} else if (layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) {
|
||||
RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying());
|
||||
input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr);
|
||||
} else if (layers[i] instanceof MultiLayerNetwork) {
|
||||
input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input));
|
||||
|
@ -1311,34 +1312,51 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
//Validation: Exception if invalid (bad layer implementation)
|
||||
validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)");
|
||||
|
||||
if(wsActCloseNext != null){
|
||||
if (wsActCloseNext != null) {
|
||||
wsActCloseNext.close();
|
||||
}
|
||||
wsActCloseNext = temp;
|
||||
temp = null;
|
||||
}
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||
}
|
||||
|
||||
//Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
|
||||
//Hence: put inputs in working memory -> set back to default for next use of workspace mgr
|
||||
if(i == 0 && wsm != WorkspaceMode.NONE){
|
||||
if (i == 0 && wsm != WorkspaceMode.NONE) {
|
||||
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Throwable t2){
|
||||
t = t2;
|
||||
} finally {
|
||||
if(wsActCloseNext != null){
|
||||
wsActCloseNext.close();
|
||||
try {
|
||||
wsActCloseNext.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(temp != null){
|
||||
//Should only be non-null on exception
|
||||
while(temp.isScopeActive()){
|
||||
//For safety, should never occur in theory: a single close() call may not be sufficient, if
|
||||
// workspace scope was borrowed and not properly closed when exception occurred
|
||||
temp.close();
|
||||
try{
|
||||
temp.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1871,13 +1889,14 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
|
||||
boolean traceLog = log.isTraceEnabled();
|
||||
|
||||
Throwable t = null;
|
||||
try {
|
||||
for (int i = layers.length - 1; i >= 0; i--) {
|
||||
if (layers[i] instanceof FrozenLayer) {
|
||||
break;
|
||||
}
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||
}
|
||||
|
||||
|
@ -1897,7 +1916,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
|
||||
//Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers
|
||||
wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
|
||||
try(MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){
|
||||
try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
|
||||
|
||||
//Note that because we're opening activation workspaces not in a simple nested order, we'll manually
|
||||
// override the previous workspace setting. Otherwise, when we close these workspaces, the "current"
|
||||
|
@ -1907,7 +1926,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
|
||||
INDArray eps = (i == layers.length - 1 ? epsilon : currPair.getRight()); //eps is null for OutputLayer
|
||||
|
||||
if(!tbptt){
|
||||
if (!tbptt) {
|
||||
//Standard case
|
||||
currPair = layers[i].backpropGradient(eps, workspaceMgr);
|
||||
} else {
|
||||
|
@ -1920,7 +1939,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
}
|
||||
}
|
||||
|
||||
if(currPair.getSecond() != null) {
|
||||
if (currPair.getSecond() != null) {
|
||||
//Edge case: may be null for Embedding layer, for example
|
||||
validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
|
||||
false, "Backprop");
|
||||
|
@ -1936,38 +1955,56 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
currPair = new Pair<>(currPair.getFirst(),
|
||||
this.layerWiseConfigurations.getInputPreProcess(i)
|
||||
.backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr));
|
||||
if (i > 0 && currPair.getSecond() != null){
|
||||
if (i > 0 && currPair.getSecond() != null) {
|
||||
validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
|
||||
true, "Backprop");
|
||||
}
|
||||
}
|
||||
|
||||
if(i == 0 ){
|
||||
if(returnInputActGrad && currPair.getSecond() != null){
|
||||
if (i == 0) {
|
||||
if (returnInputActGrad && currPair.getSecond() != null) {
|
||||
currPair.setSecond(currPair.getSecond().detach());
|
||||
} else {
|
||||
currPair.setSecond(null);
|
||||
}
|
||||
}
|
||||
|
||||
if(wsActGradCloseNext != null){
|
||||
if (wsActGradCloseNext != null) {
|
||||
wsActGradCloseNext.close();
|
||||
}
|
||||
wsActGradCloseNext = wsActGradTemp;
|
||||
wsActGradTemp = null;
|
||||
}
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||
}
|
||||
}
|
||||
} catch (Throwable thr ){
|
||||
t = thr;
|
||||
} finally {
|
||||
if(wsActGradCloseNext != null){
|
||||
wsActGradCloseNext.close();
|
||||
try {
|
||||
wsActGradCloseNext.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(wsActGradTemp != null){
|
||||
if(wsActGradTemp != null) {
|
||||
//Should only be non-null on exception
|
||||
wsActGradTemp.close();
|
||||
try {
|
||||
wsActGradTemp.close();
|
||||
} catch (Throwable t2) {
|
||||
if (t != null) {
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||
}
|
||||
|
|
|
@ -476,19 +476,36 @@ std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
|
|||
////////////////////////////////////////////////////////////////////////
|
||||
std::vector<int8_t> NDArray::asByteVector() {
|
||||
|
||||
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
|
||||
|
||||
if (this->isView()) {
|
||||
auto tmp = this->dup(this->ordering());
|
||||
|
||||
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||
if (isS()) {
|
||||
// string data type requires special treatment
|
||||
syncToHost();
|
||||
auto numWords = this->lengthOf();
|
||||
auto offsetsBuffer = this->bufferAsT<Nd4jLong>();
|
||||
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords);
|
||||
auto dataLength = offsetsBuffer[numWords];
|
||||
std::vector<int8_t> result(headerLength + dataLength);
|
||||
|
||||
delete tmp;
|
||||
memcpy(result.data(), getBuffer(), headerLength + dataLength);
|
||||
|
||||
return result;
|
||||
} else {
|
||||
// all other types are linear
|
||||
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
|
||||
|
||||
if (this->isView()) {
|
||||
auto tmp = this->dup(this->ordering());
|
||||
syncToHost();
|
||||
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||
|
||||
delete tmp;
|
||||
} else {
|
||||
syncToHost();
|
||||
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
else {
|
||||
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1584,9 +1601,7 @@ std::string* NDArray::bufferAsT() const {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
T* NDArray::bufferAsT() const {
|
||||
if (isS())
|
||||
throw std::runtime_error("You can't use this method on String array");
|
||||
|
||||
// FIXME: do we REALLY want sync here?
|
||||
syncToHost();
|
||||
|
||||
return reinterpret_cast<T*>(getBuffer());
|
||||
|
@ -3202,20 +3217,39 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
|||
} else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo()))
|
||||
return false;
|
||||
|
||||
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
||||
if (isS()) {
|
||||
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
|
||||
for (int e = 0; e < this->lengthOf(); e++) {
|
||||
auto s1 = this->e<std::string>(e);
|
||||
auto s2 = other->e<std::string>(e);
|
||||
|
||||
ExtraArguments extras({eps});
|
||||
if (s1 != s2)
|
||||
return false;
|
||||
}
|
||||
|
||||
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
||||
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo());
|
||||
NDArray::registerSpecialUse({&tmp}, {this, other});
|
||||
return true;
|
||||
} else {
|
||||
// regular numeric types
|
||||
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
||||
|
||||
synchronize("NDArray::equalsTo");
|
||||
ExtraArguments extras({eps});
|
||||
|
||||
if (tmp.e<int>(0) > 0)
|
||||
return false;
|
||||
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
||||
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
|
||||
getSpecialBuffer(), getSpecialShapeInfo(),
|
||||
extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(),
|
||||
other->getShapeInfo(), other->getSpecialBuffer(),
|
||||
other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(),
|
||||
tmp.specialBuffer(), tmp.specialShapeInfo());
|
||||
NDArray::registerSpecialUse({&tmp}, {this, other});
|
||||
|
||||
return true;
|
||||
synchronize("NDArray::equalsTo");
|
||||
|
||||
if (tmp.e<int>(0) > 0)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -54,6 +54,7 @@
|
|||
#include <graph/ExecutionResult.h>
|
||||
#include <exceptions/graph_execution_exception.h>
|
||||
#include <exceptions/no_results_exception.h>
|
||||
#include <graph/FlatUtils.h>
|
||||
|
||||
namespace nd4j{
|
||||
namespace graph {
|
||||
|
@ -575,15 +576,9 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
continue;
|
||||
|
||||
|
||||
NDArray* array = var->getNDArray();
|
||||
auto byteVector = array->asByteVector();
|
||||
auto array = var->getNDArray();
|
||||
|
||||
auto fBuffer = builder.CreateVector(byteVector);
|
||||
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
||||
|
||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array->dataType()), bo);
|
||||
auto fArray = FlatUtils::toFlatArray(builder, *array);
|
||||
|
||||
auto fName = builder.CreateString(*(var->getName()));
|
||||
auto id = CreateIntPair(builder, var->id(), var->index());
|
||||
|
|
|
@ -866,9 +866,10 @@ void initializeFunctions(Nd4jPointer *functions) {
|
|||
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
|
||||
Nd4jPointer pointer;
|
||||
// cudaHostAllocMapped |cudaHostAllocPortable
|
||||
cudaError_t res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault);
|
||||
auto res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault);
|
||||
if (res != 0)
|
||||
pointer = 0L;
|
||||
throw nd4j::cuda_exception::build("cudaHostAlloc(...) failed", res);
|
||||
|
||||
return pointer;
|
||||
}
|
||||
|
||||
|
@ -884,7 +885,7 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
|
|||
Nd4jPointer pointer;
|
||||
auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize);
|
||||
if (res != 0)
|
||||
pointer = 0L;
|
||||
throw nd4j::cuda_exception::build("cudaMalloc(...) failed", res);
|
||||
return pointer;
|
||||
}
|
||||
|
||||
|
@ -894,9 +895,9 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
|
|||
* @param pointer pointer that'll be freed
|
||||
*/
|
||||
int freeHost(Nd4jPointer pointer) {
|
||||
cudaError_t res = cudaFreeHost(reinterpret_cast<void *>(pointer));
|
||||
auto res = cudaFreeHost(reinterpret_cast<void *>(pointer));
|
||||
if (res != 0)
|
||||
pointer = 0L;
|
||||
throw nd4j::cuda_exception::build("cudaFreeHost(...) failed", res);
|
||||
return 1L;
|
||||
}
|
||||
|
||||
|
@ -907,9 +908,10 @@ int freeHost(Nd4jPointer pointer) {
|
|||
* @param ptrToDeviceId pointer to deviceId.
|
||||
*/
|
||||
int freeDevice(Nd4jPointer pointer, int deviceId) {
|
||||
cudaError_t res = cudaFree(reinterpret_cast<void *>(pointer));
|
||||
auto res = cudaFree(reinterpret_cast<void *>(pointer));
|
||||
if (res != 0)
|
||||
pointer = 0L;
|
||||
throw nd4j::cuda_exception::build("cudaFree(...) failed", res);
|
||||
|
||||
return 1L;
|
||||
}
|
||||
|
||||
|
@ -934,7 +936,7 @@ Nd4jPointer createStream() {
|
|||
auto stream = new cudaStream_t();
|
||||
auto dZ = cudaStreamCreate(stream);
|
||||
if (dZ != 0)
|
||||
throw std::runtime_error("cudaStreamCreate(...) failed");
|
||||
throw nd4j::cuda_exception::build("cudaStreamCreate(...) failed", dZ);
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
@ -944,23 +946,21 @@ Nd4jPointer createEvent() {
|
|||
|
||||
CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t));
|
||||
|
||||
cudaError_t dZ = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(&nativeEvent), cudaEventDisableTiming);
|
||||
checkCudaErrors(dZ);
|
||||
auto dZ = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(&nativeEvent), cudaEventDisableTiming);
|
||||
if (dZ != 0)
|
||||
throw std::runtime_error("cudaEventCreateWithFlags(...) failed");
|
||||
throw nd4j::cuda_exception::build("cudaEventCreateWithFlags(...) failed", dZ);
|
||||
|
||||
|
||||
return nativeEvent;
|
||||
}
|
||||
|
||||
int registerEvent(Nd4jPointer event, Nd4jPointer stream) {
|
||||
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
||||
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(stream);
|
||||
auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
||||
auto pStream = reinterpret_cast<cudaStream_t *>(stream);
|
||||
|
||||
cudaError_t dZ = cudaEventRecord(*pEvent, *pStream);
|
||||
checkCudaErrors(dZ);
|
||||
auto dZ = cudaEventRecord(*pEvent, *pStream);
|
||||
if (dZ != 0)
|
||||
throw std::runtime_error("cudaEventRecord(...) failed");
|
||||
throw nd4j::cuda_exception::build("cudaEventRecord(...) failed", dZ);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
@ -1065,53 +1065,48 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j
|
|||
}
|
||||
|
||||
int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
||||
cudaError_t dZ = cudaMemset(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size));
|
||||
checkCudaErrors(dZ);
|
||||
auto dZ = cudaMemset(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size));
|
||||
if (dZ != 0)
|
||||
throw std::runtime_error("cudaMemset(...) failed");
|
||||
throw nd4j::cuda_exception::build("cudaMemset(...) failed", dZ);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
||||
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved);
|
||||
auto pStream = reinterpret_cast<cudaStream_t *>(reserved);
|
||||
|
||||
cudaError_t dZ = cudaMemsetAsync(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size), *pStream);
|
||||
checkCudaErrors(dZ);
|
||||
auto dZ = cudaMemsetAsync(reinterpret_cast<void *>(dst), value, static_cast<size_t>(size), *pStream);
|
||||
if (dZ != 0)
|
||||
throw std::runtime_error("cudaMemsetAsync(...) failed");
|
||||
throw nd4j::cuda_exception::build("cudaMemsetAsync(...) failed", dZ);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int destroyEvent(Nd4jPointer event) {
|
||||
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
||||
cudaError_t dZ = cudaEventDestroy(*pEvent);
|
||||
checkCudaErrors(dZ);
|
||||
auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
||||
auto dZ = cudaEventDestroy(*pEvent);
|
||||
if (dZ != 0)
|
||||
throw std::runtime_error("cudaEvenDestroy(...) failed");
|
||||
throw nd4j::cuda_exception::build("cudaEvenDestroy(...) failed", dZ);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int streamSynchronize(Nd4jPointer stream) {
|
||||
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(stream);
|
||||
auto pStream = reinterpret_cast<cudaStream_t *>(stream);
|
||||
|
||||
cudaError_t dZ = cudaStreamSynchronize(*pStream);
|
||||
checkCudaErrors(dZ);
|
||||
auto dZ = cudaStreamSynchronize(*pStream);
|
||||
if (dZ != 0)
|
||||
throw std::runtime_error("cudaStreamSynchronize(...) failed");
|
||||
throw nd4j::cuda_exception::build("cudaStreamSynchronize(...) failed", dZ);
|
||||
|
||||
return 1L;
|
||||
}
|
||||
|
||||
int eventSynchronize(Nd4jPointer event) {
|
||||
cudaEvent_t *pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
||||
auto pEvent = reinterpret_cast<cudaEvent_t *>(&event);
|
||||
|
||||
cudaError_t dZ = cudaEventSynchronize(*pEvent);
|
||||
checkCudaErrors(dZ);
|
||||
auto dZ = cudaEventSynchronize(*pEvent);
|
||||
if (dZ != 0)
|
||||
throw std::runtime_error("cudaEventSynchronize(...) failed");
|
||||
throw nd4j::cuda_exception::build("cudaEventSynchronize(...) failed", dZ);
|
||||
|
||||
return 1L;
|
||||
}
|
||||
|
@ -2697,13 +2692,16 @@ int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opConte
|
|||
|
||||
auto result = op->execute(context);
|
||||
|
||||
// FIXME: remove once CUDA backend is 100% ready
|
||||
auto res = cudaStreamSynchronize(*context->launchContext()->getCudaStream());
|
||||
if (res != 0)
|
||||
throw nd4j::cuda_exception::build("customOp execution failed", res);
|
||||
|
||||
for (auto v:context->fastpath_in()) {
|
||||
v->makeBothActual();
|
||||
v->syncToDevice();
|
||||
}
|
||||
|
||||
for (auto v:context->fastpath_out()) {
|
||||
v->makeBothActual();
|
||||
v->syncToDevice();
|
||||
}
|
||||
|
||||
return result;
|
||||
|
|
|
@ -36,6 +36,8 @@ namespace nd4j {
|
|||
static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair);
|
||||
|
||||
static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray);
|
||||
|
||||
static flatbuffers::Offset<FlatArray> toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -102,5 +102,16 @@ namespace nd4j {
|
|||
delete[] newShape;
|
||||
return array;
|
||||
}
|
||||
|
||||
flatbuffers::Offset<FlatArray> FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) {
|
||||
auto byteVector = array.asByteVector();
|
||||
|
||||
auto fBuffer = builder.CreateVector(byteVector);
|
||||
auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector());
|
||||
|
||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||
|
||||
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -26,7 +26,6 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
|
||||
|
||||
template <typename T>
|
||||
static void swapRows_(NDArray* matrix, int theFirst, int theSecond) {
|
||||
|
@ -108,14 +107,14 @@ namespace helpers {
|
|||
|
||||
|
||||
template <typename T>
|
||||
static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) {
|
||||
static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) {
|
||||
|
||||
const int rowNum = input->rows();
|
||||
const int columnNum = input->columns();
|
||||
|
||||
NDArray determinant = NDArrayFactory::create<T>(1.f);
|
||||
NDArray compoundMatrix = *input; // copy
|
||||
NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides
|
||||
NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides
|
||||
permutationMatrix.setIdentity();
|
||||
|
||||
T pivotValue; // = T(0.0);
|
||||
|
@ -161,46 +160,43 @@ namespace helpers {
|
|||
return determinant;
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
static int determinant_(NDArray* input, NDArray* output) {
|
||||
static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
|
||||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
|
||||
|
||||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
|
||||
matrix.p(row, input->e<T>(k));
|
||||
output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr));
|
||||
output->p(e, lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||
|
||||
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
defaultContext = context;
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int logAbsDeterminant_(NDArray* input, NDArray* output) {
|
||||
int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
|
||||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
|
||||
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace());
|
||||
NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace());
|
||||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) {
|
||||
matrix.p(row, input->e<T>(k));
|
||||
}
|
||||
NDArray det = lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr);
|
||||
NDArray det = lup_<T>(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr);
|
||||
if (det.e<T>(0) != 0.f)
|
||||
output->p(e, nd4j::math::nd4j_log<T,T>(nd4j::math::nd4j_abs(det.t<T>(0))));
|
||||
}
|
||||
|
@ -208,25 +204,23 @@ template <typename T>
|
|||
return ND4J_STATUS_OK;
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||
|
||||
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int inverse_(NDArray* input, NDArray* output) {
|
||||
static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
|
||||
auto n = input->sizeAt(-1);
|
||||
auto n2 = n * n;
|
||||
auto totalCount = output->lengthOf() / n2;
|
||||
|
||||
output->assign(0.f); // fill up output tensor with zeros
|
||||
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
|
||||
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext); //, block.getWorkspace());
|
||||
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
||||
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
||||
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), defaultContext);
|
||||
auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||
auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||
auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT<T>(), context);
|
||||
|
||||
for (int e = 0; e < totalCount; e++) {
|
||||
if (e)
|
||||
|
@ -235,7 +229,7 @@ template <typename T>
|
|||
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||
matrix.p(row++, input->e<T>(k));
|
||||
}
|
||||
T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
|
||||
T det = lup_<T>(context, &matrix, &compound, &permutation).template e<T>(0);
|
||||
|
||||
// FIXME: and how this is going to work on float16?
|
||||
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
|
||||
|
@ -268,8 +262,7 @@ template <typename T>
|
|||
}
|
||||
|
||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
defaultContext = context;
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -296,14 +289,13 @@ template <typename T>
|
|||
|
||||
return true;
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES);
|
||||
|
||||
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int cholesky_(NDArray* input, NDArray* output, bool inplace) {
|
||||
int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) {
|
||||
|
||||
auto n = input->sizeAt(-1);
|
||||
auto n2 = n * n;
|
||||
|
@ -311,8 +303,8 @@ template <typename T>
|
|||
if (!inplace)
|
||||
output->assign(0.f); // fill up output tensor with zeros only inplace=false
|
||||
|
||||
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace());
|
||||
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext));
|
||||
std::unique_ptr<NDArray> matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace());
|
||||
std::unique_ptr<NDArray> lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context));
|
||||
|
||||
for (int e = 0; e < totalCount; e++) {
|
||||
|
||||
|
@ -346,14 +338,13 @@ template <typename T>
|
|||
}
|
||||
|
||||
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
|
||||
defaultContext = context;
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int logdetFunctor_(NDArray* input, NDArray* output) {
|
||||
int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) {
|
||||
std::unique_ptr<NDArray> tempOutput(input->dup());
|
||||
int res = cholesky_<T>(input, tempOutput.get(), false);
|
||||
int res = cholesky_<T>(context, input, tempOutput.get(), false);
|
||||
if (res != ND4J_STATUS_OK)
|
||||
return res;
|
||||
auto n = input->sizeAt(-1);
|
||||
|
@ -372,7 +363,7 @@ template <typename T>
|
|||
}
|
||||
|
||||
int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -907,6 +907,8 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
|
|||
|
||||
/*** max ***/
|
||||
case 0: {
|
||||
coord2 = hstart;
|
||||
coord3 = hend;
|
||||
|
||||
T max = -DataTypeUtils::max<T>();
|
||||
for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {
|
||||
|
|
|
@ -31,8 +31,6 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext();
|
||||
|
||||
// template <typename T>
|
||||
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
||||
// if (theFirst != theSecond) {
|
||||
|
@ -198,36 +196,33 @@ namespace helpers {
|
|||
}
|
||||
|
||||
template<typename T>
|
||||
static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) {
|
||||
static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
|
||||
int n = inputMatrix->rows();
|
||||
invertedMatrix->setIdentity();
|
||||
|
||||
if (inputMatrix->isIdentityMatrix()) return;
|
||||
|
||||
auto stream = defaultContext->getCudaStream();
|
||||
auto stream = context->getCudaStream();
|
||||
|
||||
// invert main diagonal
|
||||
upvertKernel<T> << < 1, n, 512, *stream >> >
|
||||
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
upvertKernel<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
// invert the second diagonal
|
||||
invertKernelLow<T> << < 1, n, 512, *stream >> >
|
||||
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
invertKernelLow<T><<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
invertLowKernel<T><<< n, n, 512, *stream >> >
|
||||
(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
invertLowKernel<T><<<n, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
}
|
||||
|
||||
void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
|
||||
void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
|
||||
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
|
||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||
NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||
static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||
int n = inputMatrix->rows();
|
||||
invertedMatrix->setIdentity();
|
||||
auto stream = defaultContext->getCudaStream();
|
||||
auto stream = context->getCudaStream();
|
||||
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
|
||||
return;
|
||||
}
|
||||
|
@ -237,13 +232,12 @@ namespace helpers {
|
|||
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
invertedMatrix->tickWriteDevice();
|
||||
invertedMatrix->printIndexedBuffer("Step1 UP inversion");
|
||||
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),
|
||||
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
invertUpKernel<T><<<n, n, 512, *stream >>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
}
|
||||
|
||||
void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) {
|
||||
void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) {
|
||||
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
|
||||
BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||
BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||
NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix});
|
||||
}
|
||||
|
||||
|
@ -392,7 +386,6 @@ namespace helpers {
|
|||
auto n = input->rows();
|
||||
cusolverDnHandle_t cusolverH = nullptr;
|
||||
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
|
||||
defaultContext = context;
|
||||
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||
throw cuda_exception::build("Cannot create cuSolver handle", status);
|
||||
}
|
||||
|
@ -528,24 +521,19 @@ namespace helpers {
|
|||
input->tickWriteDevice();
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void lup_,
|
||||
(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation),
|
||||
FLOAT_NATIVE);
|
||||
BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE);
|
||||
|
||||
template<typename T>
|
||||
static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
std::vector<int> dims();
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
||||
{input->rankOf() - 2, input->rankOf() - 1});
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
||||
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
||||
// DataType dtype = input->dataType();
|
||||
// if (dtype != DataType::DOUBLE)
|
||||
// dtype = DataType::FLOAT32;
|
||||
defaultContext = context;
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(),
|
||||
defaultContext); //, block.getWorkspace());
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT<T>(), context); //, block.getWorkspace());
|
||||
auto det = NDArrayFactory::create<T>(1);
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
@ -554,8 +542,7 @@ namespace helpers {
|
|||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
Nd4jLong pos = e * n2;
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
|
||||
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
// else
|
||||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
|
||||
|
@ -578,7 +565,6 @@ namespace helpers {
|
|||
}
|
||||
|
||||
int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
|
@ -586,19 +572,16 @@ namespace helpers {
|
|||
|
||||
template<typename T>
|
||||
int logAbsDeterminant_(LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
Nd4jLong n = input->sizeAt(-1);
|
||||
Nd4jLong n2 = n * n;
|
||||
std::vector<int> dims();
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
||||
{input->rankOf() - 2, input->rankOf() - 1});
|
||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
||||
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
||||
DataType dtype = input->dataType();
|
||||
if (dtype != DataType::DOUBLE)
|
||||
dtype = DataType::FLOAT32;
|
||||
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype,
|
||||
defaultContext); //, block.getWorkspace());
|
||||
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace());
|
||||
auto det = NDArrayFactory::create<T>(1);
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
@ -607,8 +590,7 @@ namespace helpers {
|
|||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
Nd4jLong pos = e * n2;
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
fillMatrix<T, T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
|
||||
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
// else
|
||||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
|
||||
|
@ -620,8 +602,7 @@ namespace helpers {
|
|||
auto inputBuf = reinterpret_cast<T *>(matrix.specialBuffer());
|
||||
auto outputBuf = reinterpret_cast<T *>(output->specialBuffer()) + offset;
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
determinantLogKernel<T> << < launchDims.x, launchDims.y, launchDims.z, *stream >> >
|
||||
(inputBuf, outputBuf, n);
|
||||
determinantLogKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, outputBuf, n);
|
||||
// else
|
||||
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||
}
|
||||
|
@ -633,7 +614,6 @@ namespace helpers {
|
|||
}
|
||||
|
||||
int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
|
@ -696,17 +676,16 @@ namespace helpers {
|
|||
|
||||
template<typename T>
|
||||
static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
auto n = input->sizeAt(-1);
|
||||
auto n2 = n * n;
|
||||
auto dtype = DataTypeUtils::fromT<T>(); //input->dataType();
|
||||
// if (dtype != DataType::DOUBLE)
|
||||
// dtype = DataType::FLOAT32;
|
||||
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext);
|
||||
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context);
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),
|
||||
{input->rankOf() - 2,
|
||||
input->rankOf() - 1});
|
||||
|
@ -716,20 +695,17 @@ namespace helpers {
|
|||
auto stream = context->getCudaStream();
|
||||
|
||||
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
|
||||
fillMatrix<T, T> << < 1, n2, 1024, *stream >> >
|
||||
(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(),
|
||||
i * n2, n);
|
||||
fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
|
||||
matrix.tickWriteDevice();
|
||||
compound.assign(matrix);
|
||||
lup_<T>(context, &compound, nullptr, nullptr);
|
||||
fillLowerUpperKernel<T> << < n, n, 1024, *stream >> >
|
||||
(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
|
||||
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
|
||||
matrix.assign(0);
|
||||
invertUpperMatrix(&upper, &matrix); // U^{-1}
|
||||
invertUpperMatrix(context, &upper, &matrix); // U^{-1}
|
||||
matrix.tickWriteDevice();
|
||||
// matrix.printIndexedBuffer("Upper Inverted");
|
||||
compound.assign(0);
|
||||
invertLowerMatrix(&lower, &compound); // L{-1}
|
||||
invertLowerMatrix(context, &lower, &compound); // L{-1}
|
||||
compound.tickWriteDevice();
|
||||
// compound.printIndexedBuffer("Lower Inverted");
|
||||
// matrix.tickWriteDevice();
|
||||
|
@ -737,15 +713,12 @@ namespace helpers {
|
|||
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
|
||||
upper.tickWriteDevice();
|
||||
// upper.printIndexedBuffer("Full inverted");
|
||||
returnMatrix<T> << < 1, n2, 1024, *stream >> >
|
||||
(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(),
|
||||
i * n2, n);
|
||||
returnMatrix<T> <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
|
@ -788,7 +761,6 @@ namespace helpers {
|
|||
int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||
if (!inplace)
|
||||
output->assign(input);
|
||||
defaultContext = context;
|
||||
std::unique_ptr<NDArray> tempOutput(output->dup());
|
||||
cusolverDnHandle_t handle = nullptr;
|
||||
auto n = input->sizeAt(-1);
|
||||
|
@ -868,7 +840,6 @@ namespace helpers {
|
|||
|
||||
// template <typename T>
|
||||
int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
if (input->dataType() == DataType::DOUBLE)
|
||||
cholesky__<double>(context, input, output, inplace);
|
||||
|
@ -876,8 +847,7 @@ namespace helpers {
|
|||
cholesky__<float>(context, input, output, inplace);
|
||||
else {
|
||||
std::unique_ptr<NDArray> tempOutput(
|
||||
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32,
|
||||
defaultContext));
|
||||
NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context));
|
||||
tempOutput->assign(input);
|
||||
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
|
||||
output->assign(tempOutput.get());
|
||||
|
@ -888,7 +858,6 @@ namespace helpers {
|
|||
|
||||
int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) {
|
||||
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
|
||||
defaultContext = context;
|
||||
return cholesky_(context, input, output, inplace);
|
||||
}
|
||||
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
||||
|
@ -927,7 +896,6 @@ namespace helpers {
|
|||
|
||||
template<typename T>
|
||||
int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
|
||||
auto stream = context->getCudaStream();
|
||||
|
@ -957,7 +925,6 @@ namespace helpers {
|
|||
}
|
||||
|
||||
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
defaultContext = context;
|
||||
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <NDArray.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include "testlayers.h"
|
||||
#include <graph/Stash.h>
|
||||
#include <FlatUtils.h>
|
||||
|
||||
using namespace nd4j;
|
||||
|
||||
class FlatUtilsTests : public testing::Test {
|
||||
public:
|
||||
|
||||
};
|
||||
|
||||
TEST_F(FlatUtilsTests, flat_float_serde_1) {
|
||||
auto array = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||
builder.Finish(flatArray);
|
||||
|
||||
|
||||
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||
|
||||
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||
|
||||
ASSERT_EQ(array, *restored);
|
||||
|
||||
delete restored;
|
||||
}
|
||||
|
||||
TEST_F(FlatUtilsTests, flat_int_serde_1) {
|
||||
auto array = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||
builder.Finish(flatArray);
|
||||
|
||||
|
||||
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||
|
||||
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||
|
||||
ASSERT_EQ(array, *restored);
|
||||
|
||||
delete restored;
|
||||
}
|
||||
|
||||
TEST_F(FlatUtilsTests, flat_bool_serde_1) {
|
||||
auto array = NDArrayFactory::create<bool>('c', {4}, {true, false, true, false});
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||
builder.Finish(flatArray);
|
||||
|
||||
|
||||
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||
|
||||
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||
|
||||
ASSERT_EQ(array, *restored);
|
||||
|
||||
delete restored;
|
||||
}
|
||||
|
||||
TEST_F(FlatUtilsTests, flat_string_serde_1) {
|
||||
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||
builder.Finish(flatArray);
|
||||
|
||||
|
||||
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||
|
||||
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||
|
||||
ASSERT_EQ(array, *restored);
|
||||
|
||||
delete restored;
|
||||
}
|
|
@ -24,7 +24,6 @@
|
|||
#include "testlayers.h"
|
||||
#include <graph/Stash.h>
|
||||
|
||||
using namespace nd4j;
|
||||
using namespace nd4j;
|
||||
|
||||
class StringTests : public testing::Test {
|
||||
|
@ -91,4 +90,4 @@ TEST_F(StringTests, Basic_dup_1) {
|
|||
ASSERT_EQ(f, z1);
|
||||
|
||||
delete dup;
|
||||
}
|
||||
}
|
|
@ -31,10 +31,35 @@
|
|||
|
||||
<build>
|
||||
<plugins>
|
||||
|
||||
<!-- AB 2019/08/24 This plugin is to be added TEMPORARILY due to a change in the filenames of the generated ONNX -->
|
||||
<!-- Normal "mvn clean" etc won't delete these files, and any users who have built ND4J even once before the
|
||||
change will run into a compilation error. This can be removed after a few weeks.-->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-antrun-plugin</artifactId>
|
||||
<version>1.8</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>generate-sources</phase>
|
||||
<goals>
|
||||
<goal>run</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<target>
|
||||
<delete file="${project.build.sourceDirectory}/onnx/OnnxMlProto3.java" />
|
||||
<delete file="${project.build.sourceDirectory}/onnx/OnnxOperatorsProto3.java" />
|
||||
<delete file="${project.build.sourceDirectory}/onnx/OnnxProto3.java" />
|
||||
</target>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>com.github.os72</groupId>
|
||||
<artifactId>protoc-jar-maven-plugin</artifactId>
|
||||
<version>3.5.1.1</version>
|
||||
<version>3.8.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>tensorflow</id>
|
||||
|
@ -43,30 +68,14 @@
|
|||
<goal>run</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<type>java-shaded</type>
|
||||
<protocVersion>3.5.1</protocVersion>
|
||||
<protocVersion>3.8.0</protocVersion>
|
||||
<extension>.proto</extension>
|
||||
<includeDirectories>
|
||||
<include>src/main/protobuf/tf</include>
|
||||
<include>src/main/protobuf/onnx</include>
|
||||
</includeDirectories>
|
||||
<inputDirectories>
|
||||
<include>src/main/protobuf/tf/tensorflow</include>
|
||||
</inputDirectories>
|
||||
<addSources>main</addSources>
|
||||
<cleanOutputFolder>false</cleanOutputFolder>
|
||||
<outputDirectory>src/main/java/</outputDirectory>
|
||||
</configuration>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>onnx</id>
|
||||
<phase>generate-sources</phase>
|
||||
<goals>
|
||||
<goal>run</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<type>java-shaded</type>
|
||||
<extension>.proto3</extension>
|
||||
<protocVersion>3.5.1</protocVersion>
|
||||
<inputDirectories>
|
||||
<include>src/main/protobuf/onnx</include>
|
||||
</inputDirectories>
|
||||
<addSources>main</addSources>
|
||||
|
@ -76,6 +85,32 @@
|
|||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>com.google.code.maven-replacer-plugin</groupId>
|
||||
<artifactId>replacer</artifactId>
|
||||
<version>1.5.3</version>
|
||||
<configuration>
|
||||
<includes>
|
||||
<include>${project.build.sourceDirectory}/org/tensorflow/**</include>
|
||||
<include>${project.build.sourceDirectory}/tensorflow/**</include>
|
||||
<include>${project.build.sourceDirectory}/onnx/**</include>
|
||||
</includes>
|
||||
<token>com.google.protobuf.</token>
|
||||
<value>org.nd4j.shade.protobuf.</value>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>replace-imports</id>
|
||||
<phase>generate-sources</phase>
|
||||
<goals>
|
||||
<goal>replace</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
|
@ -148,20 +183,15 @@
|
|||
<version>${flatbuffers.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Note that this is shaded flatbuffers, see the protoc declaration above
|
||||
mentioning java-shaded as the type for why we use this instead of google's (mainly due ot other systems packaging
|
||||
their own older protobuf versions-->
|
||||
<!-- Note that this is shaded protobuf. We use this instead of google's version mainly due ot other systems packaging
|
||||
their own older (incompatible) protobuf versions-->
|
||||
<dependency>
|
||||
<groupId>com.github.os72</groupId>
|
||||
<artifactId>protobuf-java-shaded-351</artifactId>
|
||||
<version>0.9</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.os72</groupId>
|
||||
<artifactId>protobuf-java-util-shaded-351</artifactId>
|
||||
<version>0.9</version>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>protobuf</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>org.objenesis</groupId>
|
||||
<artifactId>objenesis</artifactId>
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||
|
@ -101,10 +101,10 @@ public abstract class DifferentialFunction {
|
|||
|
||||
/**
|
||||
* Initialize the function from the given
|
||||
* {@link onnx.OnnxProto3.NodeProto}
|
||||
* {@link onnx.Onnx.NodeProto}
|
||||
* @param node
|
||||
*/
|
||||
public DifferentialFunction(SameDiff sameDiff,onnx.OnnxProto3.NodeProto node,Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public DifferentialFunction(SameDiff sameDiff,onnx.Onnx.NodeProto node,Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
this.sameDiff = sameDiff;
|
||||
setInstanceId();
|
||||
initFromOnnx(node, sameDiff, attributesForNode, graph);
|
||||
|
@ -731,13 +731,13 @@ public abstract class DifferentialFunction {
|
|||
|
||||
/**
|
||||
* Iniitialize the function from the given
|
||||
* {@link onnx.OnnxProto3.NodeProto}
|
||||
* {@link onnx.Onnx.NodeProto}
|
||||
* @param node
|
||||
* @param initWith
|
||||
* @param attributesForNode
|
||||
* @param graph
|
||||
*/
|
||||
public abstract void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph);
|
||||
public abstract void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph);
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.autodiff.samediff;
|
|||
import java.util.Objects;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.imports.descriptors.tensorflow;
|
||||
|
||||
import com.github.os72.protobuf351.TextFormat;
|
||||
import org.nd4j.shade.protobuf.TextFormat;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.tensorflow.framework.OpDef;
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
package org.nd4j.imports.graphmapper;
|
||||
|
||||
import com.github.os72.protobuf351.Message;
|
||||
import com.github.os72.protobuf351.TextFormat;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import org.nd4j.shade.protobuf.TextFormat;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.imports.graphmapper;
|
||||
|
||||
import com.github.os72.protobuf351.Message;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
|
|
|
@ -16,13 +16,13 @@
|
|||
|
||||
package org.nd4j.imports.graphmapper.onnx;
|
||||
|
||||
import com.github.os72.protobuf351.ByteString;
|
||||
import com.github.os72.protobuf351.Message;
|
||||
import org.nd4j.shade.protobuf.ByteString;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import com.google.common.primitives.Floats;
|
||||
import com.google.common.primitives.Ints;
|
||||
import com.google.common.primitives.Longs;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -52,7 +52,7 @@ import java.util.*;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto, onnx.OnnxProto3.TypeProto.Tensor> {
|
||||
public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto, onnx.Onnx.TypeProto.Tensor> {
|
||||
private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();
|
||||
|
||||
|
||||
|
@ -64,9 +64,9 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
@Override
|
||||
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
|
||||
try {
|
||||
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(inputFile);
|
||||
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile);
|
||||
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
||||
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
bufferedWriter.write(node.toString() + "\n");
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
* @param node
|
||||
* @param graph
|
||||
*/
|
||||
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) {
|
||||
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) {
|
||||
val properties = on.mappingsForFunction();
|
||||
val tfProperties = properties.get(mappedTfName);
|
||||
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
||||
|
@ -170,18 +170,18 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public boolean isOpIgnoreException(OnnxProto3.NodeProto node) {
|
||||
public boolean isOpIgnoreException(Onnx.NodeProto node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getTargetMappingForOp(DifferentialFunction function, OnnxProto3.NodeProto node) {
|
||||
public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) {
|
||||
return function.opName();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void mapProperty(String name, DifferentialFunction on, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
|
||||
public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
|
||||
val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
|
||||
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
||||
/**
|
||||
|
@ -263,7 +263,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
|
||||
@Override
|
||||
public OnnxProto3.NodeProto getNodeWithNameFromGraph(OnnxProto3.GraphProto graph, String name) {
|
||||
public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) {
|
||||
for(int i = 0; i < graph.getNodeCount(); i++) {
|
||||
val node = graph.getNode(i);
|
||||
if(node.getName().equals(name))
|
||||
|
@ -274,21 +274,21 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public boolean isPlaceHolderNode(OnnxProto3.TypeProto.Tensor node) {
|
||||
public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getControlDependencies(OnnxProto3.NodeProto node) {
|
||||
public List<String> getControlDependencies(Onnx.NodeProto node) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
|
||||
try {
|
||||
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
|
||||
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
|
||||
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
||||
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
bufferedWriter.write(node.toString());
|
||||
}
|
||||
|
||||
|
@ -316,12 +316,12 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
|
||||
@Override
|
||||
public Map<String,onnx.OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) {
|
||||
public Map<String,onnx.Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto) {
|
||||
/**
|
||||
* Need to figure out why
|
||||
* gpu_0/conv1_1 isn't present in VGG
|
||||
*/
|
||||
Map<String,onnx.OnnxProto3.TypeProto.Tensor> ret = new HashMap<>();
|
||||
Map<String,onnx.Onnx.TypeProto.Tensor> ret = new HashMap<>();
|
||||
for(int i = 0; i < graphProto.getInputCount(); i++) {
|
||||
ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
|
||||
}
|
||||
|
@ -356,19 +356,19 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public String translateToSameDiffName(String name, OnnxProto3.NodeProto node) {
|
||||
public String translateToSameDiffName(String name, Onnx.NodeProto node) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
protected void addDummyTensor(String name, Map<String, OnnxProto3.TypeProto.Tensor> to) {
|
||||
OnnxProto3.TensorShapeProto.Dimension dim = OnnxProto3.TensorShapeProto.Dimension.
|
||||
protected void addDummyTensor(String name, Map<String, Onnx.TypeProto.Tensor> to) {
|
||||
Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension.
|
||||
newBuilder()
|
||||
.setDimValue(-1)
|
||||
.build();
|
||||
OnnxProto3.TypeProto.Tensor typeProto = OnnxProto3.TypeProto.Tensor.newBuilder()
|
||||
Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder()
|
||||
.setShape(
|
||||
OnnxProto3.TensorShapeProto.newBuilder()
|
||||
Onnx.TensorShapeProto.newBuilder()
|
||||
.addDim(dim)
|
||||
.addDim(dim).build())
|
||||
.build();
|
||||
|
@ -377,23 +377,23 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
@Override
|
||||
public Message.Builder getNewGraphBuilder() {
|
||||
return OnnxProto3.GraphProto.newBuilder();
|
||||
return Onnx.GraphProto.newBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OnnxProto3.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
|
||||
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
|
||||
public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
|
||||
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OnnxProto3.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
|
||||
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
|
||||
public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
|
||||
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mapNodeType(OnnxProto3.NodeProto tfNode, ImportState<OnnxProto3.GraphProto, OnnxProto3.TypeProto.Tensor> importState,
|
||||
OpImportOverride<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opImportOverride,
|
||||
OpImportFilter<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opFilter) {
|
||||
public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto, Onnx.TypeProto.Tensor> importState,
|
||||
OpImportOverride<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opImportOverride,
|
||||
OpImportFilter<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opFilter) {
|
||||
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
|
||||
if(differentialFunction == null) {
|
||||
throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
|
||||
|
@ -425,13 +425,13 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
|
||||
@Override
|
||||
public DataType dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensorProto, int outputNum) {
|
||||
public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) {
|
||||
return nd4jTypeFromOnnxType(tensorProto.getElemType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isStringType(OnnxProto3.TypeProto.Tensor tensor) {
|
||||
return tensor.getElemType() == OnnxProto3.TensorProto.DataType.STRING;
|
||||
public boolean isStringType(Onnx.TypeProto.Tensor tensor) {
|
||||
return tensor.getElemType() == Onnx.TensorProto.DataType.STRING;
|
||||
}
|
||||
|
||||
|
||||
|
@ -440,7 +440,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
* @param dataType the data type to convert
|
||||
* @return the nd4j type for the onnx type
|
||||
*/
|
||||
public DataType nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType) {
|
||||
public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) {
|
||||
switch (dataType) {
|
||||
case DOUBLE: return DataType.DOUBLE;
|
||||
case FLOAT: return DataType.FLOAT;
|
||||
|
@ -452,8 +452,8 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) {
|
||||
for(OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
|
||||
public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) {
|
||||
for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) {
|
||||
if(attributeProto.getName().equals(key)) {
|
||||
return attributeProto.getS().toString();
|
||||
}
|
||||
|
@ -463,29 +463,29 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public long[] getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto) {
|
||||
public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) {
|
||||
return Longs.toArray(attributeProto.getT().getDimsList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPlaceHolder(OnnxProto3.TypeProto.Tensor nodeType) {
|
||||
public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isConstant(OnnxProto3.TypeProto.Tensor nodeType) {
|
||||
public boolean isConstant(Onnx.TypeProto.Tensor nodeType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
|
||||
public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) {
|
||||
DataType type = dataTypeForTensor(tensorProto, 0);
|
||||
if(!tensorProto.isInitialized()) {
|
||||
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
|
||||
}
|
||||
|
||||
OnnxProto3.TensorProto tensor = null;
|
||||
Onnx.TensorProto tensor = null;
|
||||
for(int i = 0; i < graph.getInitializerCount(); i++) {
|
||||
val initializer = graph.getInitializer(i);
|
||||
if(initializer.getName().equals(tensorName)) {
|
||||
|
@ -508,7 +508,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
return arr;
|
||||
}
|
||||
|
||||
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) {
|
||||
public INDArray mapTensorProto(Onnx.TensorProto tensor) {
|
||||
if(tensor == null)
|
||||
return null;
|
||||
|
||||
|
@ -527,7 +527,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public long[] getShapeFromTensor(onnx.OnnxProto3.TypeProto.Tensor tensorProto) {
|
||||
public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) {
|
||||
val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
|
||||
int dimCount = tensorProto.getShape().getDimCount();
|
||||
if(dimCount >= 2)
|
||||
|
@ -548,11 +548,11 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
/**
|
||||
* Get the shape from a tensor proto.
|
||||
* Note that this is different from {@link #getShapeFromTensor(OnnxProto3.TensorProto)}
|
||||
* Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)}
|
||||
* @param tensorProto the tensor to get the shape from
|
||||
* @return
|
||||
*/
|
||||
public long[] getShapeFromTensor(OnnxProto3.TensorProto tensorProto) {
|
||||
public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) {
|
||||
val ret = new long[Math.max(2,tensorProto.getDimsCount())];
|
||||
int dimCount = tensorProto.getDimsCount();
|
||||
if(dimCount >= 2)
|
||||
|
@ -577,74 +577,74 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
|
||||
@Override
|
||||
public String getInputFromNode(OnnxProto3.NodeProto node, int index) {
|
||||
public String getInputFromNode(Onnx.NodeProto node, int index) {
|
||||
return node.getInput(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numInputsFor(OnnxProto3.NodeProto nodeProto) {
|
||||
public int numInputsFor(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getInputCount();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public long[] getShapeFromAttr(OnnxProto3.AttributeProto attr) {
|
||||
public long[] getShapeFromAttr(Onnx.AttributeProto attr) {
|
||||
return Longs.toArray(attr.getT().getDimsList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, OnnxProto3.AttributeProto> getAttrMap(OnnxProto3.NodeProto nodeProto) {
|
||||
Map<String,OnnxProto3.AttributeProto> proto = new HashMap<>();
|
||||
public Map<String, Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto) {
|
||||
Map<String,Onnx.AttributeProto> proto = new HashMap<>();
|
||||
for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
|
||||
OnnxProto3.AttributeProto attributeProto = nodeProto.getAttribute(i);
|
||||
Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i);
|
||||
proto.put(attributeProto.getName(),attributeProto);
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName(OnnxProto3.NodeProto nodeProto) {
|
||||
public String getName(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean alreadySeen(OnnxProto3.NodeProto nodeProto) {
|
||||
public boolean alreadySeen(Onnx.NodeProto nodeProto) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isVariableNode(OnnxProto3.NodeProto nodeProto) {
|
||||
public boolean isVariableNode(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getOpType().contains("Var");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean shouldSkip(OnnxProto3.NodeProto opType) {
|
||||
public boolean shouldSkip(Onnx.NodeProto opType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasShape(OnnxProto3.NodeProto nodeProto) {
|
||||
public boolean hasShape(Onnx.NodeProto nodeProto) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] getShape(OnnxProto3.NodeProto nodeProto) {
|
||||
public long[] getShape(Onnx.NodeProto nodeProto) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graph) {
|
||||
public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) {
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOpType(OnnxProto3.NodeProto nodeProto) {
|
||||
public String getOpType(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getOpType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<OnnxProto3.NodeProto> getNodeList(OnnxProto3.GraphProto graphProto) {
|
||||
public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto) {
|
||||
return graphProto.getNodeList();
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.imports.graphmapper.tf;
|
||||
|
||||
import com.github.os72.protobuf351.Message;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import com.google.common.primitives.Floats;
|
||||
import com.google.common.primitives.Ints;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package org.nd4j.imports.graphmapper.tf.tensors;
|
||||
|
||||
import com.github.os72.protobuf351.Descriptors;
|
||||
import org.nd4j.shade.protobuf.Descriptors;
|
||||
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer;
|
||||
import org.bytedeco.javacpp.indexer.HalfIndexer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -205,7 +205,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -200,7 +200,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Data;
|
|||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -134,7 +134,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
|
||||
|
@ -218,7 +218,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
if (!attributesForNode.containsKey("axes")) {
|
||||
this.dimensions = new int[] { Integer.MAX_VALUE };
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import com.google.common.primitives.Doubles;
|
|||
import com.google.common.primitives.Longs;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -603,7 +603,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -61,7 +61,7 @@ public class NoOp extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
|
|||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -367,7 +367,7 @@ public class If extends DifferentialFunction implements CustomOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
|
|||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -468,7 +468,7 @@ public class While extends DifferentialFunction implements CustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -122,7 +122,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers;
|
|||
import lombok.Builder;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -96,7 +96,7 @@ public class Linear extends BaseModule {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -260,7 +260,7 @@ public class AvgPooling2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
||||
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
||||
val padding = !attributesForNode.containsKey("pads") ? Arrays.asList(1L) : attributesForNode.get("pads").getIntsList();
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -78,7 +78,7 @@ public class AvgPooling3D extends Pooling3D {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
|
@ -139,7 +139,7 @@ public class BatchNorm extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||
addArgs();
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -127,7 +127,7 @@ public class Conv2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||
addArgs();
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -247,7 +247,7 @@ public class DeConv2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val autoPad = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
||||
val dilations = attributesForNode.get("dilations");
|
||||
val dilationY = dilations == null ? 1 : dilations.getIntsList().get(0).intValue();
|
||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Builder;
|
|||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -151,7 +151,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||
addArgs();
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -115,7 +115,7 @@ public class LocalResponseNormalization extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val aAlpha = attributesForNode.get("alpha");
|
||||
val aBeta = attributesForNode.get("beta");
|
||||
val aBias = attributesForNode.get("bias");
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -221,7 +221,7 @@ public class MaxPooling2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
||||
val isSameNode = paddingVal.equals("SAME");
|
||||
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -78,7 +78,7 @@ public class MaxPooling3D extends Pooling3D {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Builder;
|
|||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -183,7 +183,7 @@ public class Pooling2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val isSameNode = attributesForNode.get("auto_pad").getS().equals("SAME");
|
||||
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
||||
val padding = attributesForNode.get("pads").getIntsList();
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
|
||||
|
@ -73,7 +73,7 @@ public class LSTMCell extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
@ -65,7 +65,7 @@ public class SRU extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
@ -66,7 +66,7 @@ public class SRUCell extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -204,7 +204,7 @@ public class Mmul extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
|
||||
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
|
||||
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
||||
|
|
|
@ -20,7 +20,7 @@ import com.google.common.primitives.Ints;
|
|||
import com.google.common.primitives.Longs;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||
|
@ -283,7 +283,7 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
|
||||
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
|
||||
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -163,7 +163,7 @@ public class Concat extends DynamicCustomOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -77,7 +77,7 @@ public class Diag extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -79,7 +79,7 @@ public class DiagPart extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
|
@ -78,7 +78,7 @@ public class Gather extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -65,7 +65,7 @@ public class MergeAvg extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -64,7 +64,7 @@ public class MergeMax extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -66,7 +66,7 @@ public class MergeSum extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -68,7 +68,7 @@ public class ParallelStack extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -66,7 +66,7 @@ public class Rank extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -106,7 +106,7 @@ public class Repeat extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -126,7 +126,7 @@ public class Reshape extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val shape = new OnnxGraphMapper().getShape(node);
|
||||
this.shape = shape;
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxMlProto3;
|
||||
import onnx.OnnxMl;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||
|
@ -87,7 +87,7 @@ public class Shape extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -93,7 +93,7 @@ public class Stack extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import com.google.common.primitives.Ints;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
|
@ -156,7 +156,7 @@ public class Transpose extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
if (!attributesForNode.containsKey("perm")) {
|
||||
|
||||
} else
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -127,7 +127,7 @@ public class Unstack extends DynamicCustomOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape.bp;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -71,7 +71,7 @@ public class ConcatBp extends DynamicCustomOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
//No op
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
|
@ -59,7 +59,7 @@ public class TensorArrayConcat extends BaseTensorOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
|
@ -59,7 +59,7 @@ public class TensorArrayGather extends BaseTensorOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -54,7 +54,7 @@ public class TensorArrayRead extends BaseTensorOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -52,7 +52,7 @@ public class TensorArrayScatter extends BaseTensorOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -58,7 +58,7 @@ public class TensorArraySize extends BaseTensorOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -52,7 +52,7 @@ public class TensorArraySplit extends BaseTensorOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -64,7 +64,7 @@ public class ClipByNorm extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -77,7 +77,7 @@ public class ClipByValue extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -62,7 +62,7 @@ public class Assign extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -132,7 +132,7 @@ public class CumProd extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
||||
|
@ -133,7 +133,7 @@ public class CumSum extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -80,7 +80,7 @@ public class Fill extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.strict;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -81,7 +81,7 @@ public class RectifiedTanh extends BaseTransformStrictOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.random.impl;
|
||||
|
||||
import lombok.NonNull;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -75,7 +75,7 @@ public class DropOutInverted extends BaseRandomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.random.impl;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package onnx;
|
||||
import "onnx.proto3";
|
||||
import "onnx.proto";
|
||||
|
||||
//
|
||||
// This file contains the proto definitions for OperatorSetProto and
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion;
|
||||
|
||||
import com.github.os72.protobuf351.util.JsonFormat;
|
||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Rule;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion;
|
||||
|
||||
import com.github.os72.protobuf351.util.JsonFormat;
|
||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
|
|
|
@ -732,4 +732,20 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
fail("Failed datatypes: " + failed.toString());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMaxPool2Dbp_1() {
|
||||
val x = Nd4j.create(DataType.HALF, 2,3,16,16).assign(Double.NaN);
|
||||
val y = Nd4j.create(DataType.HALF, 2,3,8,8).assign(Double.NaN);
|
||||
val z = Nd4j.create(DataType.HALF, 2,3,16,16);
|
||||
|
||||
val op = DynamicCustomOp.builder("maxpool2d_bp")
|
||||
.addInputs(x, y)
|
||||
.addOutputs(z)
|
||||
.addIntegerArguments(2, 2, 2, 2, 8,8, 1,1,1, 0,0)
|
||||
.build();
|
||||
|
||||
Nd4j.exec(op);
|
||||
Nd4j.getExecutioner().commit();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
<packaging>pom</packaging>
|
||||
<modules>
|
||||
<module>jackson</module>
|
||||
<module>protobuf</module>
|
||||
</modules>
|
||||
|
||||
<properties>
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<parent>
|
||||
<artifactId>nd4j-shade</artifactId>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>protobuf</artifactId>
|
||||
|
||||
<properties>
|
||||
<skipTestResourceEnforcement>true</skipTestResourceEnforcement>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<version>3.8.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java-util</artifactId>
|
||||
<version>3.8.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>custom-lifecycle</id>
|
||||
|
||||
<activation>
|
||||
<property><name>!skip.custom.lifecycle</name></property>
|
||||
</activation>
|
||||
<build>
|
||||
<plugins>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.apache.portals.jetspeed-2</groupId>
|
||||
<artifactId>jetspeed-mvn-maven-plugin</artifactId>
|
||||
<version>2.3.1</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>compile-and-pack</id>
|
||||
<phase>compile</phase>
|
||||
<goals>
|
||||
<goal>mvn</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.apache.maven.shared</groupId>
|
||||
<artifactId>maven-invoker</artifactId>
|
||||
<version>2.2</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<configuration>
|
||||
<targets combine.children="merge">
|
||||
|
||||
<target>
|
||||
<id>create-shaded-jars</id>
|
||||
<dir>@rootdir@/nd4j/nd4j-shade/protobuf/</dir>
|
||||
<goals>clean,compile,package</goals>
|
||||
<properties>
|
||||
<skip.custom.lifecycle>true</skip.custom.lifecycle>
|
||||
</properties>
|
||||
</target>
|
||||
|
||||
</targets>
|
||||
<defaultTarget>create-shaded-jars</defaultTarget>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
</plugins>
|
||||
</build>
|
||||
</profile>
|
||||
</profiles>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
|
||||
<!-- Disable Maven Lint plugin in this module. For some reason it chokes on this module (internal NPE) and we don't need it anyway here -->
|
||||
<plugin>
|
||||
<groupId>com.lewisd</groupId>
|
||||
<artifactId>lint-maven-plugin</artifactId>
|
||||
<version>0.0.11</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>pom-lint</id>
|
||||
<phase>none</phase>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
|
||||
<!--
|
||||
Use Maven Shade plugin to add a shaded version of the Protobuf dependencies, that can be imported by
|
||||
including this module (org.nd4j.protobuf) as a dependency.
|
||||
The standard com.google.protobuf dependencies will be provided, though are prefixed by org.nd4j.shade.protobuf
|
||||
-->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<version>${maven-shade-plugin.version}</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
|
||||
<resource>reference.conf</resource>
|
||||
</transformer>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||
</transformer>
|
||||
</transformers>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
|
||||
<configuration>
|
||||
<!--
|
||||
Important configuration options here:
|
||||
createDependencyReducedPom: remove the shaded artifacts from the module dependencies. Without this, the
|
||||
original dependencies will be shaded, AND still included as transitive deps
|
||||
in the final POM. This is not what we want.
|
||||
shadedArtifactAttached: If true, the shaded artifact will be a separate JAR file for install, with
|
||||
the original un-shaded JAR being separate. With this being set to false,
|
||||
the original JAR will be modified, and no extra jar will be produced.
|
||||
promoteTransitiveDependencies: This will promote the transitive dependencies of the shaded dependencies
|
||||
to direct dependencies. Without this, we need to manually manage the transitive
|
||||
dependencies of the shaded artifacts.
|
||||
|
||||
Note that using <optional>true</optional> in the dependencies also allows the deps to be shaded (and
|
||||
original dependencies to not be included), but does NOT work with promoteTransitiveDependencies
|
||||
-->
|
||||
<shadedArtifactAttached>false</shadedArtifactAttached>
|
||||
<createDependencyReducedPom>true</createDependencyReducedPom>
|
||||
<promoteTransitiveDependencies>true</promoteTransitiveDependencies>
|
||||
|
||||
<artifactSet>
|
||||
<includes>
|
||||
<include>com.google.protobuf:*</include>
|
||||
<include>com.google.protobuf.*:*</include>
|
||||
</includes>
|
||||
</artifactSet>
|
||||
|
||||
<relocations>
|
||||
<!-- Protobuf dependencies -->
|
||||
<relocation>
|
||||
<pattern>com.google.protobuf</pattern>
|
||||
<shadedPattern>org.nd4j.shade.protobuf</shadedPattern>
|
||||
</relocation>
|
||||
</relocations>
|
||||
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
<configuration>
|
||||
<forceCreation>true</forceCreation>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>empty-javadoc-jar</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>jar</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<classifier>javadoc</classifier>
|
||||
<classesDirectory>${basedir}/javadoc</classesDirectory>
|
||||
</configuration>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>empty-sources-jar</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>jar</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<classifier>sources</classifier>
|
||||
<classesDirectory>${basedir}/src</classesDirectory>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-dependency-plugin</artifactId>
|
||||
<version>3.0.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>unpack</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>unpack</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<artifactItems>
|
||||
<artifactItem>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>protobuf</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<type>jar</type>
|
||||
<overWrite>false</overWrite>
|
||||
<outputDirectory>${project.build.directory}/classes/</outputDirectory>
|
||||
<includes>**/*.class,**/*.xml</includes>
|
||||
</artifactItem>
|
||||
</artifactItems>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
</project>
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion;
|
||||
|
||||
import com.github.os72.protobuf351.InvalidProtocolBufferException;
|
||||
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
|
||||
import org.bytedeco.javacpp.*;
|
||||
import org.bytedeco.javacpp.indexer.*;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion.graphrunner;
|
||||
|
||||
import com.github.os72.protobuf351.ByteString;
|
||||
import com.github.os72.protobuf351.InvalidProtocolBufferException;
|
||||
import com.github.os72.protobuf351.util.JsonFormat;
|
||||
import org.nd4j.shade.protobuf.ByteString;
|
||||
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
|
||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -638,7 +638,7 @@ public class GraphRunner implements Closeable {
|
|||
|
||||
/**
|
||||
* Convert a json string written out
|
||||
* by {@link com.github.os72.protobuf351.util.JsonFormat}
|
||||
* by {@link org.nd4j.shade.protobuf.util.JsonFormat}
|
||||
* to a {@link org.bytedeco.tensorflow.ConfigProto}
|
||||
* @param json the json to read
|
||||
* @return the config proto to use
|
||||
|
|
Loading…
Reference in New Issue