diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 99f8aeff0..00c0cf7d6 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null; List[] closeAtEndIteraton = (List[])new List[topologicalOrder.length]; MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace(); + Throwable t = null; try { for (int i = 0; i <= stopIndex; i++) { GraphVertex current = vertices[topologicalOrder[i]]; @@ -2302,14 +2303,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) .build(); - if(detachedInputs){ + if (detachedInputs) { //Sometimes (like: external errors use cases) we don't want the activations/inputs to be // in a workspace workspaceMgr.setScopedOutFor(ArrayType.INPUT); workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS); } else { //Don't leverage out of async MultiDataSetIterator workspaces - if(features[0].isAttached()){ + if (features[0].isAttached()) { workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId()); } } @@ -2326,7 +2327,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { if (ArrayUtils.contains(layerIndexes, vIdx)) { isRequiredOutput = true; - if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){ + if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) { //Place activations in user-specified workspace origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS); origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS); @@ -2345,7 +2346,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { //Open the relevant workspace for the activations. //Note that this will be closed only once the current vertex's activations have been consumed MemoryWorkspace wsActivations = null; - if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput ){ //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS + if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput) { //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS); openActivationsWorkspaces.put(wsActivations, workspaceMgr); } @@ -2353,11 +2354,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { //Note that because we're opening activation workspaces not in any defined order (i.e., workspace // use isn't simply nested), we'll manually override the previous workspace setting. 
Otherwise, when we // close these workspaces, the "current" workspace may be set to the incorrect one - if(wsActivations != null ) + if (wsActivations != null) wsActivations.setPreviousWorkspace(initialWorkspace); int closeableAt = vertexOutputsFullyConsumedByStep[vIdx]; - if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) { + if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) { if (closeAtEndIteraton[closeableAt] == null) { closeAtEndIteraton[closeableAt] = new ArrayList<>(); } @@ -2373,18 +2374,18 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { out = features[vIdx]; } else { - if(fwdPassType == FwdPassType.STANDARD){ + if (fwdPassType == FwdPassType.STANDARD) { //Standard feed-forward case out = current.doForward(train, workspaceMgr); - } else if(fwdPassType == FwdPassType.RNN_TIMESTEP){ + } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) { if (current.hasLayer()) { //Layer INDArray input = current.getInputs()[0]; Layer l = current.getLayer(); if (l instanceof RecurrentLayer) { out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr); - } else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer){ - RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying()); + } else if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) { + RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying()); out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr); } else if (l instanceof MultiLayerNetwork) { out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input)); @@ -2402,7 +2403,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)"); } - if(inputsTo != null) { //Output vertices may not input to any other vertices + if (inputsTo != null) { //Output vertices may not input to any other vertices for (VertexIndices v : inputsTo) { //Note that we don't have to do anything special here: the activations are always detached in // this method @@ -2412,13 +2413,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } } - if(clearLayerInputs) { + if (clearLayerInputs) { current.clear(); } - if(isRequiredOutput){ + if (isRequiredOutput) { outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out; - if(origWSAct != null){ + if (origWSAct != null) { //Reset the configuration, as we may reuse this workspace manager... 
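A note on the setPreviousWorkspace(initialWorkspace) override these methods rely on: workspace scopes normally behave like a stack, and close() restores whichever workspace was current when the scope was entered. The activation workspaces here are opened and closed out of stack order, so without the override a close() could leave the "current" workspace pointing at a scope that has already been closed. A minimal sketch of the failure mode, using a hypothetical stack-like Workspace class (illustrative only, not the Nd4j API):

    // Hypothetical stack-like workspace scopes, for illustration only
    final class Workspace implements AutoCloseable {
        private static Workspace current = null;
        private Workspace previous;
        private boolean closed;

        static Workspace open() {
            Workspace ws = new Workspace();
            ws.previous = current;      // default: remember the scope we were in
            current = ws;
            return ws;
        }

        void setPreviousWorkspace(Workspace ws) { this.previous = ws; }

        @Override public void close() {
            closed = true;
            current = previous;         // restores "previous" -- which may itself be closed
        }

        static void demo() {
            Workspace initial = current;
            Workspace a = open();       // previous == initial
            Workspace b = open();       // previous == a (stack-like default)
            a.close();                  // out-of-stack-order close
            b.close();                  // restores current = a, which is already closed
            // the patch avoids this by forcing b.setPreviousWorkspace(initial)
        }
    }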
workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf); } @@ -2428,14 +2429,16 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { //Close any activations workspaces that we no longer require //Note that activations workspaces can be closed only once the corresponding output activations have // been fully consumed - if(closeAtEndIteraton[i] != null){ - for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){ + if (closeAtEndIteraton[i] != null) { + for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) { wsAct.close(); LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct); freeWorkspaceManagers.add(canNowReuse); } } } + } catch (Throwable t2){ + t = t2; } finally { //Close all open workspaces... usually this list will be empty, but not if an exception is thrown //Though if stopIndex < numLayers, some might still be open @@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { //Edge case here: seems that scoping out can increase the tagScope of the current WS //and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient // number of times to actually close it, in all cases - ws.close(); + try{ + ws.close(); + } catch (Throwable t2){ + if(t != null){ + log.error("Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; + } + } } } Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); @@ -2581,28 +2592,29 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { boolean traceLog = log.isTraceEnabled(); - try{ - for(int i=topologicalOrder.length-1; i>= 0; i--){ + Throwable t = null; + try { + for (int i = topologicalOrder.length - 1; i >= 0; i--) { boolean hitFrozen = false; GraphVertex current = vertices[topologicalOrder[i]]; int vIdx = current.getVertexIndex(); String vertexName = current.getVertexName(); - if(traceLog){ + if (traceLog) { log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName()); } //FIXME: make the frozen vertex feature extraction more flexible - if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex){ + if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex) { hitFrozen = true; } - if (current.isInputVertex() || hitFrozen){ + if (current.isInputVertex() || hitFrozen) { //Close any activation gradient workspaces that we no longer require //Note that activation gradient workspaces can be closed only once the corresponding activations // gradients have been fully consumed - if(closeAtEndIteraton[i] != null){ - for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){ + if (closeAtEndIteraton[i] != null) { + for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) { wsAct.close(); LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct); freeWorkspaceManagers.add(canNowReuse); @@ -2680,7 +2692,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { wsActivationGrads.setPreviousWorkspace(initialWorkspace); int closeableAt = vertexActGradsFullyConsumedByStep[vIdx]; - if(closeableAt >= 0) { + if (closeableAt >= 0) { if (closeAtEndIteraton[closeableAt] == null) { closeAtEndIteraton[closeableAt] = new ArrayList<>(); } @@ -2689,14 +2701,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { Pair pair; INDArray[] epsilons; - 
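All of the methods touched in this patch gain the same error-handling shape: remember the first Throwable thrown by the loop body, still attempt to close every open workspace in the finally block, and if a close itself fails after an earlier failure, log the original exception before propagating the close failure, so the root cause is not silently lost. (The rethrow of the remembered exception happens after the cleanup, outside the hunks shown here.) Distilled into a standalone Java sketch, with AutoCloseable standing in for MemoryWorkspace and the surrounding method purely illustrative:

    import java.util.List;

    final class CloseAfterFailure {
        // Mirrors the pattern in the patch: t records the first failure; the
        // finally block still attempts every close.
        static void runAndClose(Runnable body, List<AutoCloseable> open) throws Exception {
            Throwable t = null;
            try {
                body.run();
            } catch (Throwable t2) {
                t = t2;                              // remember the first exception
            } finally {
                for (AutoCloseable ws : open) {
                    try {
                        ws.close();
                    } catch (Exception t2) {
                        if (t != null) {
                            // second failure: keep a record of the first one,
                            // then surface the close failure (as the patch does)
                            System.err.println("Second exception while closing; original was:");
                            t.printStackTrace();
                            throw t2;
                        }
                        // as written in the patch, a close failure with no
                        // earlier exception is swallowed at this point
                    }
                }
            }
            if (t != null)                           // rethrown after cleanup (not shown in the hunks)
                throw (t instanceof Exception) ? (Exception) t : new RuntimeException(t);
        }
    }

Note that, exactly as in the patch, a close failure with no earlier exception is swallowed by the inner catch; whether that is intended is worth a second look.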
try(MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){ + try (MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) { pair = current.doBackward(truncatedBPTT, workspaceMgr); epsilons = pair.getSecond(); //Validate workspace location for the activation gradients: //validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){ for (INDArray epsilon : epsilons) { - if(epsilon != null) { + if (epsilon != null) { //May be null for EmbeddingLayer, etc validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop"); } @@ -2732,15 +2744,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { tempList.addFirst(new Triple<>(newName, entry.getValue(), g.flatteningOrderForVariable(origName))); } - for (Triple t : tempList) - gradients.addFirst(t); + for (Triple triple : tempList) + gradients.addFirst(triple); } //Close any activation gradient workspaces that we no longer require //Note that activation gradient workspaces can be closed only once the corresponding activations // gradients have been fully consumed - if(closeAtEndIteraton[i] != null){ - for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){ + if (closeAtEndIteraton[i] != null) { + for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) { wsAct.close(); LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct); freeWorkspaceManagers.add(canNowReuse); @@ -2748,23 +2760,32 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { closeAtEndIteraton[i] = null; } - if(traceLog){ + if (traceLog) { log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName()); } } - + } catch (Throwable t2){ + t = t2; } finally { //Close all open workspaces... usually this list will be empty, but not if an exception is thrown for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){ - ws.close(); + try{ + ws.close(); + } catch (Throwable t2){ + if(t != null){ + log.error("Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; + } + } } Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); } //Now, add the gradients in the order we need them in for flattening (same as params order) Gradient gradient = new DefaultGradient(flattenedGradients); - for (Triple t : gradients) { - gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird()); + for (Triple tr : gradients) { + gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird()); } this.gradient = gradient; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 731ca398b..dd495a620 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -1242,17 +1242,18 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura boolean traceLog = log.isTraceEnabled(); + Throwable t = null; try { for (int i = 0; i <= layerIndex; i++) { LayerWorkspaceMgr mgr = (i % 2 == 0 ? 
mgrEven : mgrOdd); - if(traceLog){ + if (traceLog) { log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); } //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) //Hence: put inputs in working memory - if(i == 0 && wsm != WorkspaceMode.NONE){ + if (i == 0 && wsm != WorkspaceMode.NONE) { mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG); } @@ -1268,7 +1269,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura temp.setPreviousWorkspace(initialWorkspace); - if(i == 0 && input.isAttached()){ + if (i == 0 && input.isAttached()) { //Don't leverage out of async DataSetIterator workspaces mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); } @@ -1279,8 +1280,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)"); } - if ( i == layerIndex ) { - if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){ + if (i == layerIndex) { + if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) { //Place activations in user-specified workspace mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration()); } else { @@ -1289,15 +1290,15 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura } } - if(fwdPassType == FwdPassType.STANDARD){ + if (fwdPassType == FwdPassType.STANDARD) { //Standard feed-forward case input = layers[i].activate(input, train, mgr); - } else if(fwdPassType == FwdPassType.RNN_TIMESTEP){ + } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) { //rnnTimeStep case if (layers[i] instanceof RecurrentLayer) { input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr); - } else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){ - RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()); + } else if (layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) { + RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying()); input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr); } else if (layers[i] instanceof MultiLayerNetwork) { input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input)); @@ -1311,34 +1312,51 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura //Validation: Exception if invalid (bad layer implementation) validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)"); - if(wsActCloseNext != null){ + if (wsActCloseNext != null) { wsActCloseNext.close(); } wsActCloseNext = temp; temp = null; } - if(traceLog){ + if (traceLog) { log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); } //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) //Hence: put inputs in working memory -> set back to default for next use of workspace mgr - if(i == 0 && wsm != WorkspaceMode.NONE){ + if (i == 0 && wsm != WorkspaceMode.NONE) { mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS } } - + } catch (Throwable t2){ + t = t2; } finally { 
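The wsActCloseNext/temp handover in the forward pass here is a one-iteration-delayed close: layer i's activations live in a workspace that layer i+1 still needs to read, so each iteration closes the previous iteration's workspace and hands its own over, and the finally block sweeps up whatever is left when the loop exits or throws. Reduced to a sketch, with openWorkspace/forward as stand-ins rather than the DL4J API:

    final class DeferredClose {
        static AutoCloseable openWorkspace(int i) { return () -> {}; }  // stub workspace
        static void forward(int i) {}                                   // stub layer step

        static void forwardAll(int numLayers) throws Exception {
            AutoCloseable wsActCloseNext = null;         // holds layer i-1's activations
            AutoCloseable temp = null;
            try {
                for (int i = 0; i < numLayers; i++) {
                    temp = openWorkspace(i);
                    forward(i);                          // reads activations still held by wsActCloseNext
                    if (wsActCloseNext != null)
                        wsActCloseNext.close();          // previous activations now fully consumed
                    wsActCloseNext = temp;               // defer this one's close to the next iteration
                    temp = null;
                }
            } finally {
                if (wsActCloseNext != null)
                    wsActCloseNext.close();              // last layer's activations
                if (temp != null)
                    temp.close();                        // only non-null if an exception hit mid-iteration
            }
        }
    }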
if(wsActCloseNext != null){ - wsActCloseNext.close(); + try { + wsActCloseNext.close(); + } catch (Throwable t2){ + if(t != null){ + log.error("Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; + } + } } if(temp != null){ //Should only be non-null on exception while(temp.isScopeActive()){ //For safety, should never occur in theory: a single close() call may not be sufficient, if // workspace scope was borrowed and not properly closed when exception occurred - temp.close(); + try{ + temp.close(); + } catch (Throwable t2){ + if(t != null){ + log.error("Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; + } + } } } @@ -1871,13 +1889,14 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura boolean traceLog = log.isTraceEnabled(); + Throwable t = null; try { for (int i = layers.length - 1; i >= 0; i--) { if (layers[i] instanceof FrozenLayer) { break; } - if(traceLog){ + if (traceLog) { log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName()); } @@ -1897,7 +1916,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura //Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD); - try(MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){ + try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) { //Note that because we're opening activation workspaces not in a simple nested order, we'll manually // override the previous workspace setting. Otherwise, when we close these workspaces, the "current" @@ -1907,7 +1926,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura INDArray eps = (i == layers.length - 1 ? 
epsilon : currPair.getRight()); //eps is null for OutputLayer - if(!tbptt){ + if (!tbptt) { //Standard case currPair = layers[i].backpropGradient(eps, workspaceMgr); } else { @@ -1920,7 +1939,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura } } - if(currPair.getSecond() != null) { + if (currPair.getSecond() != null) { //Edge case: may be null for Embedding layer, for example validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, false, "Backprop"); @@ -1936,38 +1955,56 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura currPair = new Pair<>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(i) .backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr)); - if (i > 0 && currPair.getSecond() != null){ + if (i > 0 && currPair.getSecond() != null) { validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, true, "Backprop"); } } - if(i == 0 ){ - if(returnInputActGrad && currPair.getSecond() != null){ + if (i == 0) { + if (returnInputActGrad && currPair.getSecond() != null) { currPair.setSecond(currPair.getSecond().detach()); } else { currPair.setSecond(null); } } - if(wsActGradCloseNext != null){ + if (wsActGradCloseNext != null) { wsActGradCloseNext.close(); } wsActGradCloseNext = wsActGradTemp; wsActGradTemp = null; } - if(traceLog){ + if (traceLog) { log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName()); } } + } catch (Throwable thr ){ + t = thr; } finally { if(wsActGradCloseNext != null){ - wsActGradCloseNext.close(); + try { + wsActGradCloseNext.close(); + } catch (Throwable t2){ + if(t != null){ + log.error("Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; + } + } } - if(wsActGradTemp != null){ + if(wsActGradTemp != null) { //Should only be non-null on exception - wsActGradTemp.close(); + try { + wsActGradTemp.close(); + } catch (Throwable t2) { + if (t != null) { + log.error("Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; + } + } } Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); } diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 1404afc96..fdbcae49f 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -476,19 +476,36 @@ std::vector NDArray::getShapeInfoAsVector() { //////////////////////////////////////////////////////////////////////// std::vector NDArray::asByteVector() { - std::vector result((unsigned long long) this->lengthOf() * sizeOfT()); - if (this->isView()) { - auto tmp = this->dup(this->ordering()); - memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); + if (isS()) { + // string data type requires special treatment + syncToHost(); + auto numWords = this->lengthOf(); + auto offsetsBuffer = this->bufferAsT(); + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords); + auto dataLength = offsetsBuffer[numWords]; + std::vector result(headerLength + dataLength); - delete tmp; + memcpy(result.data(), getBuffer(), headerLength + dataLength); + + return result; + } else { + // all other types are linear + std::vector result((unsigned long long) this->lengthOf() * sizeOfT()); + + if (this->isView()) { + auto tmp = this->dup(this->ordering()); + syncToHost(); + memcpy(result.data(), tmp->getBuffer(), (unsigned 
long long) lengthOf() * sizeOfT()); + + delete tmp; + } else { + syncToHost(); + memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); + } + return result; } - else { - memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); - } - return result; } ////////////////////////////////////////////////////////////////////////// @@ -1584,9 +1601,7 @@ std::string* NDArray::bufferAsT() const { ////////////////////////////////////////////////////////////////////////// template T* NDArray::bufferAsT() const { - if (isS()) - throw std::runtime_error("You can't use this method on String array"); - + // FIXME: do we REALLY want sync here? syncToHost(); return reinterpret_cast(getBuffer()); @@ -3202,20 +3217,39 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const { } else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo())) return false; - NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0 + if (isS()) { + // string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length + for (int e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); - ExtraArguments extras({eps}); + if (s1 != s2) + return false; + } - NDArray::prepareSpecialUse({&tmp}, {this, other}); - NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); - NDArray::registerSpecialUse({&tmp}, {this, other}); + return true; + } else { + // regular numeric types + NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0 - synchronize("NDArray::equalsTo"); + ExtraArguments extras({eps}); - if (tmp.e(0) > 0) - return false; + NDArray::prepareSpecialUse({&tmp}, {this, other}); + NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), + getSpecialBuffer(), getSpecialShapeInfo(), + extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), + other->getShapeInfo(), other->getSpecialBuffer(), + other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo()); + NDArray::registerSpecialUse({&tmp}, {this, other}); - return true; + synchronize("NDArray::equalsTo"); + + if (tmp.e(0) > 0) + return false; + + return true; + } } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/blas/cpu/GraphExecutioner.cpp b/libnd4j/blas/cpu/GraphExecutioner.cpp index b5e7d9bf2..6f97bc024 100644 --- a/libnd4j/blas/cpu/GraphExecutioner.cpp +++ b/libnd4j/blas/cpu/GraphExecutioner.cpp @@ -54,6 +54,7 @@ #include #include #include +#include namespace nd4j{ namespace graph { @@ -575,15 +576,9 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) continue; - NDArray* array = var->getNDArray(); - auto byteVector = array->asByteVector(); + auto array = var->getNDArray(); - auto fBuffer = builder.CreateVector(byteVector); - auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); - - auto bo = static_cast(BitwiseUtils::asByteOrder()); - - auto fArray = CreateFlatArray(builder, fShape, fBuffer, static_cast(array->dataType()), bo); + auto fArray = FlatUtils::toFlatArray(builder, *array); auto fName = 
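For the string branch of asByteVector above, the arithmetic implies the following buffer layout: a header of (numWords + 1) 64-bit offsets, where entry k is the byte offset of string k's data and the final entry equals the total data length, followed by the UTF-8 payload. That is why a single memcpy of headerLength + dataLength bytes suffices. A back-of-envelope illustration in Java (the values and the header formula are inferred from this patch, not quoted from ShapeUtils):

    final class StringBufferLayout {
        static void demo() {
            // [ offsets: (numWords + 1) x int64 ][ utf8 data: dataLength bytes ]
            long numWords = 3;                               // e.g. {"alpha", "beta", "gamma"}
            long[] offsets = {0, 5, 9, 14};                  // offsets[k] = start of string k; last entry = total data bytes
            long headerLength = (numWords + 1) * Long.BYTES; // what stringBufferHeaderRequirements(numWords) must return here
            long dataLength = offsets[(int) numWords];       // 14 bytes: "alphabetagamma"
            long totalBytes = headerLength + dataLength;     // exactly what the single memcpy copies
            System.out.println(totalBytes);                  // 32 header bytes + 14 data bytes = 46
        }
    }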
builder.CreateString(*(var->getName())); auto id = CreateIntPair(builder, var->id(), var->index()); diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 5c6dadbaf..e75aa422c 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -866,9 +866,10 @@ void initializeFunctions(Nd4jPointer *functions) { Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { Nd4jPointer pointer; // cudaHostAllocMapped |cudaHostAllocPortable - cudaError_t res = cudaHostAlloc(reinterpret_cast(&pointer), memorySize, cudaHostAllocDefault); + auto res = cudaHostAlloc(reinterpret_cast(&pointer), memorySize, cudaHostAllocDefault); if (res != 0) - pointer = 0L; + throw nd4j::cuda_exception::build("cudaHostAlloc(...) failed", res); + return pointer; } @@ -884,7 +885,7 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { Nd4jPointer pointer; auto res = cudaMalloc(reinterpret_cast(&pointer), memorySize); if (res != 0) - pointer = 0L; + throw nd4j::cuda_exception::build("cudaMalloc(...) failed", res); return pointer; } @@ -894,9 +895,9 @@ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { * @param pointer pointer that'll be freed */ int freeHost(Nd4jPointer pointer) { - cudaError_t res = cudaFreeHost(reinterpret_cast(pointer)); + auto res = cudaFreeHost(reinterpret_cast(pointer)); if (res != 0) - pointer = 0L; + throw nd4j::cuda_exception::build("cudaFreeHost(...) failed", res); return 1L; } @@ -907,9 +908,10 @@ int freeHost(Nd4jPointer pointer) { * @param ptrToDeviceId pointer to deviceId. */ int freeDevice(Nd4jPointer pointer, int deviceId) { - cudaError_t res = cudaFree(reinterpret_cast(pointer)); + auto res = cudaFree(reinterpret_cast(pointer)); if (res != 0) - pointer = 0L; + throw nd4j::cuda_exception::build("cudaFree(...) failed", res); + return 1L; } @@ -934,7 +936,7 @@ Nd4jPointer createStream() { auto stream = new cudaStream_t(); auto dZ = cudaStreamCreate(stream); if (dZ != 0) - throw std::runtime_error("cudaStreamCreate(...) failed"); + throw nd4j::cuda_exception::build("cudaStreamCreate(...) failed", dZ); return stream; } @@ -944,23 +946,21 @@ Nd4jPointer createEvent() { CHECK_ALLOC(nativeEvent, "Failed to allocate new CUDA event buffer", sizeof(cudaEvent_t)); - cudaError_t dZ = cudaEventCreateWithFlags(reinterpret_cast(&nativeEvent), cudaEventDisableTiming); - checkCudaErrors(dZ); + auto dZ = cudaEventCreateWithFlags(reinterpret_cast(&nativeEvent), cudaEventDisableTiming); if (dZ != 0) - throw std::runtime_error("cudaEventCreateWithFlags(...) failed"); + throw nd4j::cuda_exception::build("cudaEventCreateWithFlags(...) failed", dZ); return nativeEvent; } int registerEvent(Nd4jPointer event, Nd4jPointer stream) { - cudaEvent_t *pEvent = reinterpret_cast(&event); - cudaStream_t *pStream = reinterpret_cast(stream); + auto pEvent = reinterpret_cast(&event); + auto pStream = reinterpret_cast(stream); - cudaError_t dZ = cudaEventRecord(*pEvent, *pStream); - checkCudaErrors(dZ); + auto dZ = cudaEventRecord(*pEvent, *pStream); if (dZ != 0) - throw std::runtime_error("cudaEventRecord(...) failed"); + throw nd4j::cuda_exception::build("cudaEventRecord(...) 
failed", dZ); return 1; } @@ -1065,53 +1065,48 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j } int memsetSync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - cudaError_t dZ = cudaMemset(reinterpret_cast(dst), value, static_cast(size)); - checkCudaErrors(dZ); + auto dZ = cudaMemset(reinterpret_cast(dst), value, static_cast(size)); if (dZ != 0) - throw std::runtime_error("cudaMemset(...) failed"); + throw nd4j::cuda_exception::build("cudaMemset(...) failed", dZ); return 1; } int memsetAsync(Nd4jPointer dst, int value, Nd4jLong size, int flags, Nd4jPointer reserved) { - cudaStream_t *pStream = reinterpret_cast(reserved); + auto pStream = reinterpret_cast(reserved); - cudaError_t dZ = cudaMemsetAsync(reinterpret_cast(dst), value, static_cast(size), *pStream); - checkCudaErrors(dZ); + auto dZ = cudaMemsetAsync(reinterpret_cast(dst), value, static_cast(size), *pStream); if (dZ != 0) - throw std::runtime_error("cudaMemsetAsync(...) failed"); + throw nd4j::cuda_exception::build("cudaMemsetAsync(...) failed", dZ); return 1; } int destroyEvent(Nd4jPointer event) { - cudaEvent_t *pEvent = reinterpret_cast(&event); - cudaError_t dZ = cudaEventDestroy(*pEvent); - checkCudaErrors(dZ); + auto pEvent = reinterpret_cast(&event); + auto dZ = cudaEventDestroy(*pEvent); if (dZ != 0) - throw std::runtime_error("cudaEvenDestroy(...) failed"); + throw nd4j::cuda_exception::build("cudaEvenDestroy(...) failed", dZ); return 1; } int streamSynchronize(Nd4jPointer stream) { - cudaStream_t *pStream = reinterpret_cast(stream); + auto pStream = reinterpret_cast(stream); - cudaError_t dZ = cudaStreamSynchronize(*pStream); - checkCudaErrors(dZ); + auto dZ = cudaStreamSynchronize(*pStream); if (dZ != 0) - throw std::runtime_error("cudaStreamSynchronize(...) failed"); + throw nd4j::cuda_exception::build("cudaStreamSynchronize(...) failed", dZ); return 1L; } int eventSynchronize(Nd4jPointer event) { - cudaEvent_t *pEvent = reinterpret_cast(&event); + auto pEvent = reinterpret_cast(&event); - cudaError_t dZ = cudaEventSynchronize(*pEvent); - checkCudaErrors(dZ); + auto dZ = cudaEventSynchronize(*pEvent); if (dZ != 0) - throw std::runtime_error("cudaEventSynchronize(...) failed"); + throw nd4j::cuda_exception::build("cudaEventSynchronize(...) 
failed", dZ); return 1L; } @@ -2697,13 +2692,16 @@ int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opConte auto result = op->execute(context); - // FIXME: remove once CUDA backend is 100% ready + auto res = cudaStreamSynchronize(*context->launchContext()->getCudaStream()); + if (res != 0) + throw nd4j::cuda_exception::build("customOp execution failed", res); + for (auto v:context->fastpath_in()) { - v->makeBothActual(); + v->syncToDevice(); } for (auto v:context->fastpath_out()) { - v->makeBothActual(); + v->syncToDevice(); } return result; diff --git a/libnd4j/include/graph/FlatUtils.h b/libnd4j/include/graph/FlatUtils.h index abfff5915..939db1fb7 100644 --- a/libnd4j/include/graph/FlatUtils.h +++ b/libnd4j/include/graph/FlatUtils.h @@ -36,6 +36,8 @@ namespace nd4j { static std::pair fromLongPair(LongPair* pair); static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray); + + static flatbuffers::Offset toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array); }; } } diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index ad0c5112d..bc8ff7e33 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp @@ -102,5 +102,16 @@ namespace nd4j { delete[] newShape; return array; } + + flatbuffers::Offset FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) { + auto byteVector = array.asByteVector(); + + auto fBuffer = builder.CreateVector(byteVector); + auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector()); + + auto bo = static_cast(BitwiseUtils::asByteOrder()); + + return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); + } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index ee9a78cee..1e3c798e2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -26,7 +26,6 @@ namespace nd4j { namespace ops { namespace helpers { - nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext(); template static void swapRows_(NDArray* matrix, int theFirst, int theSecond) { @@ -108,14 +107,14 @@ namespace helpers { template - static NDArray lup_(NDArray* input, NDArray* compound, NDArray* permutation) { + static NDArray lup_(LaunchContext *context, NDArray* input, NDArray* compound, NDArray* permutation) { const int rowNum = input->rows(); const int columnNum = input->columns(); NDArray determinant = NDArrayFactory::create(1.f); NDArray compoundMatrix = *input; // copy - NDArray permutationMatrix(input, false, defaultContext); // has same shape as input and contiguous strides + NDArray permutationMatrix(input, false, context); // has same shape as input and contiguous strides permutationMatrix.setIdentity(); T pivotValue; // = T(0.0); @@ -161,46 +160,43 @@ namespace helpers { return determinant; } - BUILD_SINGLE_TEMPLATE(template NDArray lup_, (NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); + BUILD_SINGLE_TEMPLATE(template NDArray lup_, (LaunchContext *context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); template - static int determinant_(NDArray* input, NDArray* output) { + static int determinant_(LaunchContext *context, NDArray* input, NDArray* output) { Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, 
input->dataType(), defaultContext); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace()); for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) matrix.p(row, input->e(k)); - output->p(e, lup_(&matrix, (NDArray*)nullptr, (NDArray*)nullptr)); + output->p(e, lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr)); } return Status::OK(); } - BUILD_SINGLE_TEMPLATE(template int determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES); - int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - defaultContext = context; - BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (input, output), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES); } template - int logAbsDeterminant_(NDArray* input, NDArray* output) { + int logAbsDeterminant_(LaunchContext *context, NDArray* input, NDArray* output) { Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; - NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), defaultContext); //, block.getWorkspace()); + NDArray matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), context); //, block.getWorkspace()); for (int e = 0; e < output->lengthOf(); e++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) { matrix.p(row, input->e(k)); } - NDArray det = lup_(&matrix, (NDArray*)nullptr, (NDArray*)nullptr); + NDArray det = lup_(context, &matrix, (NDArray*)nullptr, (NDArray*)nullptr); if (det.e(0) != 0.f) output->p(e, nd4j::math::nd4j_log(nd4j::math::nd4j_abs(det.t(0)))); } @@ -208,25 +204,23 @@ template return ND4J_STATUS_OK; } - BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES); - int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES); } template - static int inverse_(NDArray* input, NDArray* output) { + static int inverse_(LaunchContext *context, NDArray* input, NDArray* output) { auto n = input->sizeAt(-1); auto n2 = n * n; auto totalCount = output->lengthOf() / n2; output->assign(0.f); // fill up output tensor with zeros - auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); //, block.getWorkspace()); - auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); //, block.getWorkspace()); - auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); - auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); - auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), defaultContext); + auto matrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); + auto compound = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); + auto permutation = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto lowerMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); + auto upperMatrix = NDArrayFactory::create('c', {n, n}, DataTypeUtils::fromT(), context); for (int e = 0; 
e < totalCount; e++) { if (e) @@ -235,7 +229,7 @@ template for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { matrix.p(row++, input->e(k)); } - T det = lup_(&matrix, &compound, &permutation).template e(0); + T det = lup_(context, &matrix, &compound, &permutation).template e(0); // FIXME: and how this is going to work on float16? if (nd4j::math::nd4j_abs(det) < T(0.000001)) { @@ -268,8 +262,7 @@ template } int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - defaultContext = context; - BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES); } template @@ -296,14 +289,13 @@ template return true; } - BUILD_SINGLE_TEMPLATE(template bool checkCholeskyInput_, (nd4j::LaunchContext * context, NDArray const* input), FLOAT_TYPES); bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) { BUILD_SINGLE_SELECTOR(input->dataType(), return checkCholeskyInput_, (context, input), FLOAT_TYPES); } template - int cholesky_(NDArray* input, NDArray* output, bool inplace) { + int cholesky_(LaunchContext *context, NDArray* input, NDArray* output, bool inplace) { auto n = input->sizeAt(-1); auto n2 = n * n; @@ -311,8 +303,8 @@ template if (!inplace) output->assign(0.f); // fill up output tensor with zeros only inplace=false - std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), defaultContext)); //, block.getWorkspace()); - std::unique_ptr lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), defaultContext)); + std::unique_ptr matrix(NDArrayFactory::create_('c', {n, n}, input->dataType(), context)); //, block.getWorkspace()); + std::unique_ptr lowerMatrix(NDArrayFactory::create_('c',{n, n}, input->dataType(), context)); for (int e = 0; e < totalCount; e++) { @@ -346,14 +338,13 @@ template } int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { - defaultContext = context; - BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES); } template - int logdetFunctor_(NDArray* input, NDArray* output) { + int logdetFunctor_(LaunchContext *context, NDArray* input, NDArray* output) { std::unique_ptr tempOutput(input->dup()); - int res = cholesky_(input, tempOutput.get(), false); + int res = cholesky_(context, input, tempOutput.get(), false); if (res != ND4J_STATUS_OK) return res; auto n = input->sizeAt(-1); @@ -372,7 +363,7 @@ template } int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (input, output), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), return logdetFunctor_, (context, input, output), FLOAT_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index e224329f0..98ab86dec 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -907,6 +907,8 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf /*** max ***/ case 0: { + coord2 = hstart; + coord3 = hend; T max = -DataTypeUtils::max(); for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { diff --git 
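Both this lup.cpp and the CUDA lup.cu below delete a file-scope nd4j::LaunchContext* defaultContext that every entry point mutated before dispatching into the templated helpers. A shared mutable default like that makes concurrent calls with different contexts race on the same global; the patch instead threads the context through the call chain as an explicit parameter. The shape of the refactor in Java terms (Context and determinantImpl are illustrative stand-ins):

    final class Context { /* stand-in for LaunchContext */ }

    final class BeforeRefactor {
        static Context defaultContext = new Context();      // shared mutable state
        static int determinant(Context c, double[] in, double[] out) {
            defaultContext = c;                             // two concurrent callers stomp each other here
            return determinantImpl(in, out);                // impl reads defaultContext
        }
        static int determinantImpl(double[] in, double[] out) { return 0; }
    }

    final class AfterRefactor {
        static int determinant(Context c, double[] in, double[] out) {
            return determinantImpl(c, in, out);             // context is a plain argument, no shared state
        }
        static int determinantImpl(Context c, double[] in, double[] out) { return 0; }
    }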
a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index bf9c73e7c..f11b56745 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -31,8 +31,6 @@ namespace nd4j { namespace ops { namespace helpers { - nd4j::LaunchContext* defaultContext = nd4j::LaunchContext::defaultContext(); - // template // static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { // if (theFirst != theSecond) { @@ -198,36 +196,33 @@ namespace helpers { } template - static void invertLowerMatrix_(NDArray *inputMatrix, NDArray *invertedMatrix) { + static void invertLowerMatrix_(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { int n = inputMatrix->rows(); invertedMatrix->setIdentity(); if (inputMatrix->isIdentityMatrix()) return; - auto stream = defaultContext->getCudaStream(); + auto stream = context->getCudaStream(); // invert main diagonal - upvertKernel << < 1, n, 512, *stream >> > - (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + upvertKernel<<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); // invert the second diagonal - invertKernelLow << < 1, n, 512, *stream >> > - (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + invertKernelLow<<<1, n, 512, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); // invertKernelLow<<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); - invertLowKernel<<< n, n, 512, *stream >> > - (invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + invertLowKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - void invertLowerMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { + void invertLowerMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); + BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); NDArray::registerSpecialUse({invertedMatrix}, {inputMatrix}); } template - static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) { + static void invertUpperMatrix_(LaunchContext *context, NDArray* inputMatrix, NDArray* invertedMatrix) { int n = inputMatrix->rows(); invertedMatrix->setIdentity(); - auto stream = defaultContext->getCudaStream(); + auto stream = context->getCudaStream(); if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I return; } @@ -237,13 +232,12 @@ namespace helpers { inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); invertedMatrix->tickWriteDevice(); invertedMatrix->printIndexedBuffer("Step1 UP inversion"); - invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), - 
inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); + invertUpKernel<<>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(),inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); } - void invertUpperMatrix(NDArray *inputMatrix, NDArray *invertedMatrix) { + void invertUpperMatrix(LaunchContext *context, NDArray *inputMatrix, NDArray *invertedMatrix) { NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); - BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); + BUILD_SINGLE_SELECTOR(invertedMatrix->dataType(), invertUpperMatrix_, (context, inputMatrix, invertedMatrix), FLOAT_NATIVE); NDArray::prepareSpecialUse({invertedMatrix}, {inputMatrix}); } @@ -392,7 +386,6 @@ namespace helpers { auto n = input->rows(); cusolverDnHandle_t cusolverH = nullptr; cusolverStatus_t status = cusolverDnCreate(&cusolverH); - defaultContext = context; if (CUSOLVER_STATUS_SUCCESS != status) { throw cuda_exception::build("Cannot create cuSolver handle", status); } @@ -528,24 +521,19 @@ namespace helpers { input->tickWriteDevice(); } - BUILD_SINGLE_TEMPLATE(template void lup_, - (LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), - FLOAT_NATIVE); + BUILD_SINGLE_TEMPLATE(template void lup_,(LaunchContext * context, NDArray * input, NDArray * output, NDArray * permutation), FLOAT_NATIVE); template static int determinant_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; std::vector dims(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), - {input->rankOf() - 2, input->rankOf() - 1}); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); // DataType dtype = input->dataType(); // if (dtype != DataType::DOUBLE) // dtype = DataType::FLOAT32; - defaultContext = context; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), - defaultContext); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, DataTypeUtils::fromT(), context); //, block.getWorkspace()); auto det = NDArrayFactory::create(1); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); @@ -554,8 +542,7 @@ namespace helpers { for (int e = 0; e < output->lengthOf(); e++) { Nd4jLong pos = e * n2; // if (matrix.dataType() == input->dataType()) - fillMatrix << < launchDims.x, launchDims.y, launchDims.z, *stream >> > - (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); + fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // else // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); @@ -578,7 +565,6 @@ namespace helpers { } int determinant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); NDArray::registerSpecialUse({output}, {input}); @@ -586,19 +572,16 @@ namespace helpers { template int logAbsDeterminant_(LaunchContext 
*context, NDArray *input, NDArray *output) { - defaultContext = context; Nd4jLong n = input->sizeAt(-1); Nd4jLong n2 = n * n; std::vector dims(); - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), - {input->rankOf() - 2, input->rankOf() - 1}); + auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); DataType dtype = input->dataType(); if (dtype != DataType::DOUBLE) dtype = DataType::FLOAT32; - auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, - defaultContext); //, block.getWorkspace()); + auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, context); //, block.getWorkspace()); auto det = NDArrayFactory::create(1); auto stream = context->getCudaStream(); NDArray::prepareSpecialUse({output}, {input}); @@ -607,8 +590,7 @@ namespace helpers { for (int e = 0; e < output->lengthOf(); e++) { Nd4jLong pos = e * n2; // if (matrix.dataType() == input->dataType()) - fillMatrix << < launchDims.x, launchDims.y, launchDims.z, *stream >> > - (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); + fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // else // fillMatrix<<>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); @@ -620,8 +602,7 @@ namespace helpers { auto inputBuf = reinterpret_cast(matrix.specialBuffer()); auto outputBuf = reinterpret_cast(output->specialBuffer()) + offset; // if (matrix.dataType() == input->dataType()) - determinantLogKernel << < launchDims.x, launchDims.y, launchDims.z, *stream >> > - (inputBuf, outputBuf, n); + determinantLogKernel<<>>(inputBuf, outputBuf, n); // else // determinantLogKernel<<>> (inputBuf, outputBuf, n); } @@ -633,7 +614,6 @@ namespace helpers { } int logAbsDeterminant(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); NDArray::registerSpecialUse({output}, {input}); @@ -696,17 +676,16 @@ namespace helpers { template static int inverse_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; auto n = input->sizeAt(-1); auto n2 = n * n; auto dtype = DataTypeUtils::fromT(); //input->dataType(); // if (dtype != DataType::DOUBLE) // dtype = DataType::FLOAT32; - NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); - NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); - NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); - NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); - NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, defaultContext); + NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, context); + NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, context); auto packX = 
nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); @@ -716,20 +695,17 @@ namespace helpers { auto stream = context->getCudaStream(); for (auto i = 0LL; i < packX.numberOfTads(); i++) { - fillMatrix << < 1, n2, 1024, *stream >> > - (matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), - i * n2, n); + fillMatrix<<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); matrix.tickWriteDevice(); compound.assign(matrix); lup_(context, &compound, nullptr, nullptr); - fillLowerUpperKernel << < n, n, 1024, *stream >> > - (lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); + fillLowerUpperKernel<<>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); matrix.assign(0); - invertUpperMatrix(&upper, &matrix); // U^{-1} + invertUpperMatrix(context, &upper, &matrix); // U^{-1} matrix.tickWriteDevice(); // matrix.printIndexedBuffer("Upper Inverted"); compound.assign(0); - invertLowerMatrix(&lower, &compound); // L{-1} + invertLowerMatrix(context, &lower, &compound); // L{-1} compound.tickWriteDevice(); // compound.printIndexedBuffer("Lower Inverted"); // matrix.tickWriteDevice(); @@ -737,15 +713,12 @@ namespace helpers { nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); upper.tickWriteDevice(); // upper.printIndexedBuffer("Full inverted"); - returnMatrix << < 1, n2, 1024, *stream >> > - (output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), - i * n2, n); + returnMatrix <<<1, n2, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); } return Status::OK(); } int inverse(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); NDArray::registerSpecialUse({output}, {input}); @@ -788,7 +761,6 @@ namespace helpers { int cholesky__(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { if (!inplace) output->assign(input); - defaultContext = context; std::unique_ptr tempOutput(output->dup()); cusolverDnHandle_t handle = nullptr; auto n = input->sizeAt(-1); @@ -868,7 +840,6 @@ namespace helpers { // template int cholesky_(LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); if (input->dataType() == DataType::DOUBLE) cholesky__(context, input, output, inplace); @@ -876,8 +847,7 @@ namespace helpers { cholesky__(context, input, output, inplace); else { std::unique_ptr tempOutput( - NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, - defaultContext)); + NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, context)); tempOutput->assign(input); cholesky__(context, tempOutput.get(), tempOutput.get(), true); output->assign(tempOutput.get()); @@ -888,7 +858,6 @@ namespace helpers { int cholesky(nd4j::LaunchContext *context, NDArray *input, NDArray *output, bool inplace) { // BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, 
input, output, inplace), FLOAT_TYPES); - defaultContext = context; return cholesky_(context, input, output, inplace); } // BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); @@ -927,7 +896,6 @@ namespace helpers { template int logdetFunctor_(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; NDArray::prepareSpecialUse({output}, {input}); auto n2 = input->sizeAt(-1) * input->sizeAt(-2); auto stream = context->getCudaStream(); @@ -957,7 +925,6 @@ namespace helpers { } int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { - defaultContext = context; BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE); } diff --git a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp new file mode 100644 index 000000000..bf428b833 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp @@ -0,0 +1,100 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include "testlayers.h" +#include +#include + +using namespace nd4j; + +class FlatUtilsTests : public testing::Test { +public: + +}; + +TEST_F(FlatUtilsTests, flat_float_serde_1) { + auto array = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} + +TEST_F(FlatUtilsTests, flat_int_serde_1) { + auto array = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} + +TEST_F(FlatUtilsTests, flat_bool_serde_1) { + auto array = NDArrayFactory::create('c', {4}, {true, false, true, false}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} + +TEST_F(FlatUtilsTests, flat_string_serde_1) { + auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, 
array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/StringTests.cpp b/libnd4j/tests_cpu/layers_tests/StringTests.cpp index a023dcdd3..2ae236210 100644 --- a/libnd4j/tests_cpu/layers_tests/StringTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StringTests.cpp @@ -24,7 +24,6 @@ #include "testlayers.h" #include -using namespace nd4j; using namespace nd4j; class StringTests : public testing::Test { @@ -91,4 +90,4 @@ TEST_F(StringTests, Basic_dup_1) { ASSERT_EQ(f, z1); delete dup; -} +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml index 18680e699..21924f80a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml @@ -31,10 +31,35 @@ + + + + + org.apache.maven.plugins + maven-antrun-plugin + 1.8 + + + generate-sources + + run + + + + + + + + + + + + com.github.os72 protoc-jar-maven-plugin - 3.5.1.1 + 3.8.0 tensorflow @@ -43,30 +68,14 @@ run - java-shaded - 3.5.1 + 3.8.0 + .proto src/main/protobuf/tf + src/main/protobuf/onnx src/main/protobuf/tf/tensorflow - - main - false - src/main/java/ - - - - onnx - generate-sources - - run - - - java-shaded - .proto3 - 3.5.1 - src/main/protobuf/onnx main @@ -76,6 +85,32 @@ + + + com.google.code.maven-replacer-plugin + replacer + 1.5.3 + + + ${project.build.sourceDirectory}/org/tensorflow/** + ${project.build.sourceDirectory}/tensorflow/** + ${project.build.sourceDirectory}/onnx/** + + com.google.protobuf. + org.nd4j.shade.protobuf. 
+ + + + replace-imports + generate-sources + + replace + + + + + + org.apache.maven.plugins maven-compiler-plugin @@ -148,20 +183,15 @@ ${flatbuffers.version} - + - com.github.os72 - protobuf-java-shaded-351 - 0.9 - - - com.github.os72 - protobuf-java-util-shaded-351 - 0.9 + org.nd4j + protobuf + ${project.version} + org.objenesis objenesis diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 2d49ce56f..71bbd26ee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -21,7 +21,7 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; @@ -101,10 +101,10 @@ public abstract class DifferentialFunction { /** * Initialize the function from the given - * {@link onnx.OnnxProto3.NodeProto} + * {@link onnx.Onnx.NodeProto} * @param node */ - public DifferentialFunction(SameDiff sameDiff,onnx.OnnxProto3.NodeProto node,Map attributesForNode, OnnxProto3.GraphProto graph) { + public DifferentialFunction(SameDiff sameDiff,onnx.Onnx.NodeProto node,Map attributesForNode, Onnx.GraphProto graph) { this.sameDiff = sameDiff; setInstanceId(); initFromOnnx(node, sameDiff, attributesForNode, graph); @@ -731,13 +731,13 @@ public abstract class DifferentialFunction { /** * Iniitialize the function from the given - * {@link onnx.OnnxProto3.NodeProto} + * {@link onnx.Onnx.NodeProto} * @param node * @param initWith * @param attributesForNode * @param graph */ - public abstract void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph); + public abstract void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index a7fb35520..430b4d83a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -19,7 +19,7 @@ package org.nd4j.autodiff.samediff; import java.util.Objects; import lombok.*; import lombok.extern.slf4j.Slf4j; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/descriptors/tensorflow/TensorflowDescriptorParser.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/descriptors/tensorflow/TensorflowDescriptorParser.java index fad55d101..3d0464782 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/descriptors/tensorflow/TensorflowDescriptorParser.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/descriptors/tensorflow/TensorflowDescriptorParser.java @@ -16,7 +16,7 @@ package org.nd4j.imports.descriptors.tensorflow; -import com.github.os72.protobuf351.TextFormat; +import org.nd4j.shade.protobuf.TextFormat; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.io.ClassPathResource; import org.tensorflow.framework.OpDef; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java index 92c888e0c..fe252aeeb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/BaseGraphMapper.java @@ -16,8 +16,8 @@ package org.nd4j.imports.graphmapper; -import com.github.os72.protobuf351.Message; -import com.github.os72.protobuf351.TextFormat; +import org.nd4j.shade.protobuf.Message; +import org.nd4j.shade.protobuf.TextFormat; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.io.IOUtils; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java index 8aad0f4d9..2d89a2b07 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/GraphMapper.java @@ -16,7 +16,7 @@ package org.nd4j.imports.graphmapper; -import com.github.os72.protobuf351.Message; +import org.nd4j.shade.protobuf.Message; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.descriptors.properties.PropertyMapping; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java index 0bbfece6f..719ac792d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/onnx/OnnxGraphMapper.java @@ -16,13 +16,13 @@ package org.nd4j.imports.graphmapper.onnx; -import com.github.os72.protobuf351.ByteString; -import com.github.os72.protobuf351.Message; +import org.nd4j.shade.protobuf.ByteString; +import org.nd4j.shade.protobuf.Message; import com.google.common.primitives.Floats; import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -52,7 +52,7 @@ import java.util.*; * * @author Adam Gibson */ -public class OnnxGraphMapper extends BaseGraphMapper { +public class OnnxGraphMapper extends BaseGraphMapper { private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper(); @@ -64,9 +64,9 @@ public class OnnxGraphMapper extends BaseGraphMapper attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) { + public void 
initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) { val properties = on.mappingsForFunction(); val tfProperties = properties.get(mappedTfName); val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); @@ -170,18 +170,18 @@ public class OnnxGraphMapper extends BaseGraphMapper> propertyMappingsForFunction) { + public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map> propertyMappingsForFunction) { val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node)); val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on); /** @@ -263,7 +263,7 @@ public class OnnxGraphMapper extends BaseGraphMapper getControlDependencies(OnnxProto3.NodeProto node) { + public List getControlDependencies(Onnx.NodeProto node) { throw new UnsupportedOperationException("Not yet implemented"); } @Override public void dumpBinaryProtoAsText(File inputFile, File outputFile) { try { - OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile))); + Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile))); BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true)); - for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) { + for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) { bufferedWriter.write(node.toString()); } @@ -316,12 +316,12 @@ public class OnnxGraphMapper extends BaseGraphMapper variablesForGraph(OnnxProto3.GraphProto graphProto) { + public Map variablesForGraph(Onnx.GraphProto graphProto) { /** * Need to figure out why * gpu_0/conv1_1 isn't present in VGG */ - Map ret = new HashMap<>(); + Map ret = new HashMap<>(); for(int i = 0; i < graphProto.getInputCount(); i++) { ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType()); } @@ -356,19 +356,19 @@ public class OnnxGraphMapper extends BaseGraphMapper to) { - OnnxProto3.TensorShapeProto.Dimension dim = OnnxProto3.TensorShapeProto.Dimension. + protected void addDummyTensor(String name, Map to) { + Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension. 
newBuilder() .setDimValue(-1) .build(); - OnnxProto3.TypeProto.Tensor typeProto = OnnxProto3.TypeProto.Tensor.newBuilder() + Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder() .setShape( - OnnxProto3.TensorShapeProto.newBuilder() + Onnx.TensorShapeProto.newBuilder() .addDim(dim) .addDim(dim).build()) .build(); @@ -377,23 +377,23 @@ public class OnnxGraphMapper extends BaseGraphMapper importState, - OpImportOverride opImportOverride, - OpImportFilter opFilter) { + public void mapNodeType(Onnx.NodeProto tfNode, ImportState importState, + OpImportOverride opImportOverride, + OpImportFilter opFilter) { val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType()); if(differentialFunction == null) { throw new NoOpNameFoundException("No op name found " + tfNode.getOpType()); @@ -425,13 +425,13 @@ public class OnnxGraphMapper extends BaseGraphMapper= 2) @@ -548,11 +548,11 @@ public class OnnxGraphMapper extends BaseGraphMapper= 2) @@ -577,74 +577,74 @@ public class OnnxGraphMapper extends BaseGraphMapper getAttrMap(OnnxProto3.NodeProto nodeProto) { - Map proto = new HashMap<>(); + public Map getAttrMap(Onnx.NodeProto nodeProto) { + Map proto = new HashMap<>(); for(int i = 0; i < nodeProto.getAttributeCount(); i++) { - OnnxProto3.AttributeProto attributeProto = nodeProto.getAttribute(i); + Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i); proto.put(attributeProto.getName(),attributeProto); } return proto; } @Override - public String getName(OnnxProto3.NodeProto nodeProto) { + public String getName(Onnx.NodeProto nodeProto) { return nodeProto.getName(); } @Override - public boolean alreadySeen(OnnxProto3.NodeProto nodeProto) { + public boolean alreadySeen(Onnx.NodeProto nodeProto) { return false; } @Override - public boolean isVariableNode(OnnxProto3.NodeProto nodeProto) { + public boolean isVariableNode(Onnx.NodeProto nodeProto) { return nodeProto.getOpType().contains("Var"); } @Override - public boolean shouldSkip(OnnxProto3.NodeProto opType) { + public boolean shouldSkip(Onnx.NodeProto opType) { return false; } @Override - public boolean hasShape(OnnxProto3.NodeProto nodeProto) { + public boolean hasShape(Onnx.NodeProto nodeProto) { return false; } @Override - public long[] getShape(OnnxProto3.NodeProto nodeProto) { + public long[] getShape(Onnx.NodeProto nodeProto) { return null; } @Override - public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graph) { + public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) { return null; } @Override - public String getOpType(OnnxProto3.NodeProto nodeProto) { + public String getOpType(Onnx.NodeProto nodeProto) { return nodeProto.getOpType(); } @Override - public List getNodeList(OnnxProto3.GraphProto graphProto) { + public List getNodeList(Onnx.GraphProto graphProto) { return graphProto.getNodeList(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java index 5579569c3..f57fef4c7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java @@ -16,7 +16,7 @@ package org.nd4j.imports.graphmapper.tf; -import com.github.os72.protobuf351.Message; +import 
org.nd4j.shade.protobuf.Message; import com.google.common.primitives.Floats; import com.google.common.primitives.Ints; import lombok.extern.slf4j.Slf4j; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers.java index 722168541..e9a99c6c0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers.java @@ -1,6 +1,6 @@ package org.nd4j.imports.graphmapper.tf.tensors; -import com.github.os72.protobuf351.Descriptors; +import org.nd4j.shade.protobuf.Descriptors; import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer; import org.bytedeco.javacpp.indexer.HalfIndexer; import org.nd4j.linalg.api.buffer.DataType; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java index a844b04c7..a41dc8790 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java @@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -205,7 +205,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java index d65ff377e..7f0d7e40c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java @@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -200,7 +200,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java index 925a5924f..8c9cdf4e0 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java @@ -20,7 +20,7 @@ import lombok.Data; import lombok.Getter; import lombok.Setter; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -134,7 +134,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index ebf9b9c18..7fc0679db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -21,7 +21,7 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper; @@ -218,7 +218,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { if (!attributesForNode.containsKey("axes")) { this.dimensions = new int[] { Integer.MAX_VALUE }; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 9b5b190c1..f52450eee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -21,7 +21,7 @@ import com.google.common.primitives.Doubles; import com.google.common.primitives.Longs; import lombok.*; import lombok.extern.slf4j.Slf4j; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -603,7 +603,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java index 19d8fe987..6b174bd07 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; @@ -61,7 +61,7 @@ public class NoOp extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java index 6e0db97a5..03dc26313 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/If.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow; import lombok.*; import lombok.extern.slf4j.Slf4j; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -367,7 +367,7 @@ public class If extends DifferentialFunction implements CustomOp { @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java index eba0e1145..e26b0ea5f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/While.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow; import lombok.*; import lombok.extern.slf4j.Slf4j; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -468,7 +468,7 @@ public class While extends DifferentialFunction implements CustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java index 378fbb06b..fd2134aad 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/ExternalErrorsFunction.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -122,7 +122,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java index da2c26f54..27f357b4b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/Linear.java @@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers; import lombok.Builder; import lombok.NoArgsConstructor; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -96,7 +96,7 @@ public class Linear extends BaseModule { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java index 3198a6a56..ac13c6224 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java @@ -21,7 +21,7 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -260,7 +260,7 @@ public class AvgPooling2D extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8(); val kernelShape = attributesForNode.get("kernel_shape").getIntsList(); val padding = !attributesForNode.containsKey("pads") ? 
Arrays.asList(1L) : attributesForNode.get("pads").getIntsList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java index 2c57c68de..6f58884f0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Getter; import lombok.extern.slf4j.Slf4j; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -78,7 +78,7 @@ public class AvgPooling3D extends Pooling3D { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { throw new UnsupportedOperationException("Not yet implemented"); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java index 67fc9f3a5..bad975cb5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/BatchNorm.java @@ -21,7 +21,7 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; @@ -139,7 +139,7 @@ public class BatchNorm extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java index 3d61de716..5ae2ac144 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv1D.java @@ -21,7 +21,7 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 4335f4561..04db5874c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -21,7 +21,7 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -127,7 +127,7 @@ public class Conv2D extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java index 6cba853d0..65c0fccc3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2D.java @@ -21,7 +21,7 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -247,7 +247,7 @@ public class DeConv2D extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { val autoPad = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8(); val dilations = attributesForNode.get("dilations"); val dilationY = dilations == null ? 
1 : dilations.getIntsList().get(0).intValue(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index 0ea84e081..92a39f188 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -20,7 +20,7 @@ import lombok.Builder; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -151,7 +151,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); addArgs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java index de4e763bc..421598d13 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java @@ -21,7 +21,7 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -115,7 +115,7 @@ public class LocalResponseNormalization extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { val aAlpha = attributesForNode.get("alpha"); val aBeta = attributesForNode.get("beta"); val aBias = attributesForNode.get("bias"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java index f996fc29f..b321334a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java @@ -21,7 +21,7 @@ import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import 
org.nd4j.base.Preconditions; @@ -221,7 +221,7 @@ public class MaxPooling2D extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8(); val isSameNode = paddingVal.equals("SAME"); val kernelShape = attributesForNode.get("kernel_shape").getIntsList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java index a243dec9b..99d73d2af 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution; import lombok.Getter; import lombok.extern.slf4j.Slf4j; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -78,7 +78,7 @@ public class MaxPooling3D extends Pooling3D { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { throw new UnsupportedOperationException("Not yet implemented"); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java index f7f21e78d..c45d106e7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Pooling2D.java @@ -20,7 +20,7 @@ import lombok.Builder; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -183,7 +183,7 @@ public class Pooling2D extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { val isSameNode = attributesForNode.get("auto_pad").getS().equals("SAME"); val kernelShape = attributesForNode.get("kernel_shape").getIntsList(); val padding = attributesForNode.get("pads").getIntsList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java index 678a4afef..6c7daca69 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUCell.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMCell.java index 1fdd6b191..e9d2ffd3b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMCell.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration; @@ -73,7 +73,7 @@ public class LSTMCell extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java index aaac14131..b916d4961 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRU.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -65,7 +65,7 @@ public class SRU extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java index 625e09e91..4880b90fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/SRUCell.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import 
org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -66,7 +66,7 @@ public class SRUCell extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index 00cad1f88..7d711ca58 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce; import lombok.EqualsAndHashCode; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -204,7 +204,7 @@ public class Mmul extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0; val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0; MMulTranspose mMulTranspose = MMulTranspose.builder() diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index 62e373832..3de44537a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -20,7 +20,7 @@ import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; import lombok.NoArgsConstructor; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.blas.params.MMulTranspose; @@ -283,7 +283,7 @@ public class TensorMmul extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0; val isTransposeB = !attributesForNode.containsKey("transB") ? 
false : attributesForNode.get("transB").getI() > 0; MMulTranspose mMulTranspose = MMulTranspose.builder() diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index e0b0450d3..5c6beb945 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -163,7 +163,7 @@ public class Concat extends DynamicCustomOp { @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java index 90aed14bf..b6d08784b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Diag.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -77,7 +77,7 @@ public class Diag extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java index d2807e36d..6b1688602 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/DiagPart.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -79,7 +79,7 @@ public class DiagPart extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java index 1782f75df..31718d337 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.descriptors.properties.PropertyMapping; @@ -78,7 +78,7 @@ public class Gather extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index 1be7c56bf..cfe4fe8be 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java index d8319cab2..ec86c6553 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.extern.slf4j.Slf4j; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -65,7 +65,7 @@ public class MergeAvg extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java index c9118990c..046f06c3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -64,7 +64,7 @@ public class MergeMax extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeSum.java index b7c370615..6b87ca5c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeSum.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -66,7 +66,7 @@ public class MergeSum extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java index 8d7dcf6a6..1856e6804 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ParallelStack.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -68,7 +68,7 @@ public class ParallelStack extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { throw new UnsupportedOperationException("No analog found for onnx for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java index 96f28dbf1..aacfa19e1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Rank.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -66,7 +66,7 @@ public class Rank extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java index 14d67d912..02f8f9445 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -106,7 +106,7 @@ public class Repeat extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 42e401859..b30bacc22 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -126,7 +126,7 @@ public class Reshape extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { val shape = new OnnxGraphMapper().getShape(node); this.shape = shape; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index 5faa82609..a2f6bd208 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -18,7 +18,7 @@ package 
org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; import lombok.val; -import onnx.OnnxMlProto3; +import onnx.OnnxMl; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java index a1133ee82..6cd2eec06 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Shape.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; @@ -87,7 +87,7 @@ public class Shape extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { throw new NoOpNameFoundException("No onnx name found for shape " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java index 241cc950f..55d9dd806 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java index 1ba9156bc..71b52a92a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Size.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index 44cb0539c..6cd09f9bd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import 
org.nd4j.base.Preconditions; @@ -93,7 +93,7 @@ public class Stack extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { throw new UnsupportedOperationException("No analog found for onnx for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 965d071c3..2de0a29c5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import com.google.common.primitives.Ints; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; @@ -156,7 +156,7 @@ public class Transpose extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { if (!attributesForNode.containsKey("perm")) { } else diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java index 3d7e07a72..9dd6b6338 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Unstack.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -127,7 +127,7 @@ public class Unstack extends DynamicCustomOp { @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { throw new UnsupportedOperationException("No analog found for onnx for " + opName()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/ConcatBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/ConcatBp.java index ead0f2747..70bc1b087 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/ConcatBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/ConcatBp.java @@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape.bp; import lombok.extern.slf4j.Slf4j; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -71,7 +71,7 
@@ public class ConcatBp extends DynamicCustomOp { @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { //No op } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java index 07bdab586..7759f96dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayConcat.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape.tensorops; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -59,7 +59,7 @@ public class TensorArrayConcat extends BaseTensorOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { throw new UnsupportedOperationException(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java index 3ab0d91c9..9e7669725 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayGather.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape.tensorops; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -59,7 +59,7 @@ public class TensorArrayGather extends BaseTensorOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { throw new UnsupportedOperationException(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java index 619216813..6d8cff91c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayRead.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape.tensorops; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -54,7 +54,7 @@ public class 
TensorArrayRead extends BaseTensorOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayScatter.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayScatter.java index add288d89..9e1d93e2f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayScatter.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArrayScatter.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape.tensorops; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; @@ -52,7 +52,7 @@ public class TensorArrayScatter extends BaseTensorOp { @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java index 9734515d3..276dadcab 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySize.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape.tensorops; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; @@ -58,7 +58,7 @@ public class TensorArraySize extends BaseTensorOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySplit.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySplit.java index fb52c78a7..589805641 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySplit.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArraySplit.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape.tensorops; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; @@ -52,7 +52,7 @@ public class TensorArraySplit extends BaseTensorOp { @Override - public void 
initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java index 1cda0257d..59b7ec2f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.clip; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -64,7 +64,7 @@ public class ClipByNorm extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { throw new UnsupportedOperationException("Not yet implemented"); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java index d25b0df62..11d3e9004 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.clip; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -77,7 +77,7 @@ public class ClipByValue extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { throw new UnsupportedOperationException("Not yet implemented"); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java index 7ca0b342a..35c209870 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -62,7 +62,7 @@ public class Assign extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { 
+ public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java index d1d0176ef..9c04aeb12 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -132,7 +132,7 @@ public class CumProd extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java index 2b62b73cf..b8c7d5c51 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -133,7 +133,7 @@ public class CumSum extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java index db95ee728..af4097870 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -80,7 +80,7 @@ public class Fill extends DynamicCustomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, 
Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java index 4bd56ea4d..da439cec7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java @@ -16,7 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -81,7 +81,7 @@ public class RectifiedTanh extends BaseTransformStrictOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOutInverted.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOutInverted.java index bb4b86f12..6b174ae63 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOutInverted.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOutInverted.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.random.impl; import lombok.NonNull; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -75,7 +75,7 @@ public class DropOutInverted extends BaseRandomOp { } @Override - public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) { + public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) { super.initFromOnnx(node, initWith, attributesForNode, graph); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java index 27e9d9f3c..c3670b52f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java @@ -17,7 +17,7 @@ package org.nd4j.linalg.api.ops.random.impl; import lombok.val; -import onnx.OnnxProto3; +import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-ml.proto3 b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-ml.proto similarity index 100% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-ml.proto3 rename to 
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-ml.proto diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-operators.proto3 b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-operators.proto similarity index 99% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-operators.proto3 rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-operators.proto index a8db3ca23..48890a516 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-operators.proto3 +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx-operators.proto @@ -9,7 +9,7 @@ syntax = "proto3"; package onnx; -import "onnx.proto3"; +import "onnx.proto"; // // This file contains the proto definitions for OperatorSetProto and diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx.proto3 b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx.proto similarity index 100% rename from nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx.proto3 rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/protobuf/onnx/onnx.proto diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java index af1af3a75..ee188605d 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/cpujava/org/nd4j/tensorflow/conversion/GraphRunnerTest.java @@ -16,7 +16,7 @@ package org.nd4j.tensorflow.conversion; -import com.github.os72.protobuf351.util.JsonFormat; +import org.nd4j.shade.protobuf.util.JsonFormat; import org.apache.commons.io.IOUtils; import org.junit.Ignore; import org.junit.Rule; diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java index accde5b1b..1ecc0e39a 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java @@ -16,7 +16,7 @@ package org.nd4j.tensorflow.conversion; -import com.github.os72.protobuf351.util.JsonFormat; +import org.nd4j.shade.protobuf.util.JsonFormat; import org.apache.commons.io.IOUtils; import org.junit.Ignore; import org.junit.Test; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index c2f5dedc5..6c4595a0c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -732,4 +732,20 @@ public class CustomOpsTests extends BaseNd4jTest { fail("Failed datatypes: " + failed.toString()); } } + + @Test + public void testMaxPool2Dbp_1() { + val x = Nd4j.create(DataType.HALF, 2,3,16,16).assign(Double.NaN); + val y = Nd4j.create(DataType.HALF, 2,3,8,8).assign(Double.NaN); + val z = Nd4j.create(DataType.HALF, 2,3,16,16); + + val op = 
DynamicCustomOp.builder("maxpool2d_bp") + .addInputs(x, y) + .addOutputs(z) + .addIntegerArguments(2, 2, 2, 2, 8,8, 1,1,1, 0,0) + .build(); + + Nd4j.exec(op); + Nd4j.getExecutioner().commit(); + } } diff --git a/nd4j/nd4j-shade/pom.xml b/nd4j/nd4j-shade/pom.xml index 4a2c4ca1b..36b58087b 100644 --- a/nd4j/nd4j-shade/pom.xml +++ b/nd4j/nd4j-shade/pom.xml @@ -29,6 +29,7 @@ pom jackson + protobuf diff --git a/nd4j/nd4j-shade/protobuf/pom.xml b/nd4j/nd4j-shade/protobuf/pom.xml new file mode 100644 index 000000000..1cbd7d5a8 --- /dev/null +++ b/nd4j/nd4j-shade/protobuf/pom.xml @@ -0,0 +1,228 @@ + + + + nd4j-shade + org.nd4j + 1.0.0-SNAPSHOT + + 4.0.0 + + protobuf + + + true + + + + + com.google.protobuf + protobuf-java + 3.8.0 + + + com.google.protobuf + protobuf-java-util + 3.8.0 + + + + + + + custom-lifecycle + + + !skip.custom.lifecycle + + + + + + org.apache.portals.jetspeed-2 + jetspeed-mvn-maven-plugin + 2.3.1 + + + compile-and-pack + compile + + mvn + + + + + + org.apache.maven.shared + maven-invoker + 2.2 + + + + + + + create-shaded-jars + @rootdir@/nd4j/nd4j-shade/protobuf/ + clean,compile,package + + true + + + + + create-shaded-jars + + + + + + + + + + + + + + com.lewisd + lint-maven-plugin + 0.0.11 + + + pom-lint + none + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + ${maven-shade-plugin.version} + + + package + + shade + + + + + reference.conf + + + + + + + + + + + + false + true + true + + + + com.google.protobuf:* + com.google.protobuf.*:* + + + + + + + com.google.protobuf + org.nd4j.shade.protobuf + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + true + + + + empty-javadoc-jar + package + + jar + + + javadoc + ${basedir}/javadoc + + + + empty-sources-jar + package + + jar + + + sources + ${basedir}/src + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + 3.0.0 + + + unpack + package + + unpack + + + + + org.nd4j + protobuf + ${project.version} + jar + false + ${project.build.directory}/classes/ + **/*.class,**/*.xml + + + + + + + + + + \ No newline at end of file diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java index b47cd30d1..6eff18ecc 100644 --- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java +++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java @@ -16,7 +16,7 @@ package org.nd4j.tensorflow.conversion; -import com.github.os72.protobuf351.InvalidProtocolBufferException; +import org.nd4j.shade.protobuf.InvalidProtocolBufferException; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.*; import org.nd4j.linalg.api.buffer.DataBuffer; diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java index 633535197..79d45f781 100644 --- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java +++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java @@ -16,9 +16,9 @@ package org.nd4j.tensorflow.conversion.graphrunner; -import com.github.os72.protobuf351.ByteString; -import com.github.os72.protobuf351.InvalidProtocolBufferException; -import com.github.os72.protobuf351.util.JsonFormat; +import org.nd4j.shade.protobuf.ByteString; +import 
org.nd4j.shade.protobuf.InvalidProtocolBufferException; +import org.nd4j.shade.protobuf.util.JsonFormat; import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; @@ -638,7 +638,7 @@ public class GraphRunner implements Closeable { /** * Convert a json string written out - * by {@link com.github.os72.protobuf351.util.JsonFormat} + * by {@link org.nd4j.shade.protobuf.util.JsonFormat} * to a {@link org.bytedeco.tensorflow.ConfigProto} * @param json the json to read * @return the config proto to use
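
Reviewer note on the ONNX type renames above: protoc derives a generated file's default outer class name from the proto file name, so renaming src/main/protobuf/onnx/onnx.proto3 to onnx.proto (and onnx-ml.proto3 to onnx-ml.proto) changes the generated Java holder classes from onnx.OnnxProto3 / onnx.OnnxMlProto3 to onnx.Onnx / onnx.OnnxMl. Every import and initFromOnnx override touched above changes type names only; the contract is unchanged. A minimal sketch of what an op override looks like after the rename (MyNoOp is a hypothetical class for illustration, not part of this PR):

    import java.util.Map;

    import onnx.Onnx;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;

    public class MyNoOp extends DynamicCustomOp {
        @Override
        public String opName() {
            return "my_noop";
        }

        @Override
        public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith,
                                 Map<String, Onnx.AttributeProto> attributesForNode,
                                 Onnx.GraphProto graph) {
            // Same behaviour as before the rename; only the generated proto types differ.
            super.initFromOnnx(node, initWith, attributesForNode, graph);
        }
    }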
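Reviewer note on the protobuf shading: the new nd4j-shade/protobuf module bundles protobuf-java and protobuf-java-util 3.8.0 and relocates com.google.protobuf to org.nd4j.shade.protobuf via the maven-shade-plugin, replacing the previous com.github.os72.protobuf351 fork. Only the package prefix changes at call sites; the protobuf API itself is untouched. A sketch of the JSON round-trip GraphRunner performs, written against the shaded classes (this assumes nd4j's bundled org.tensorflow.framework.ConfigProto bindings are generated against the shaded runtime):

    import org.nd4j.shade.protobuf.util.JsonFormat;
    import org.tensorflow.framework.ConfigProto;

    public class ShadedProtobufExample {
        public static void main(String[] args) throws Exception {
            // Build a TF session config and print it as JSON through the shaded JsonFormat.
            ConfigProto config = ConfigProto.newBuilder()
                    .setAllowSoftPlacement(true)
                    .build();
            String json = JsonFormat.printer().print(config);

            // Parse it back; ignoringUnknownFields() keeps this tolerant of version skew.
            ConfigProto.Builder parsed = ConfigProto.newBuilder();
            JsonFormat.parser().ignoringUnknownFields().merge(json, parsed);
            System.out.println(parsed.build());
        }
    }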
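Reviewer note on testMaxPool2Dbp_1: the test deliberately has no assertions - it fills FP16 inputs with NaN and only checks that the maxpool2d_bp kernel executes and commits without crashing. My reading of the eleven integer arguments, following libnd4j's usual pool2d argument convention (worth verifying against the C++ op declaration before relying on these labels):

    // Hypothetical annotated rebuild of the op from the test, same values and order,
    // meant to sit inside the test body where x, y, z are defined:
    val op = DynamicCustomOp.builder("maxpool2d_bp")
            .addInputs(x, y)   // x: forward input [2,3,16,16]; y: gradient wrt pooled output [2,3,8,8]
            .addOutputs(z)     // z: gradient wrt x, same shape as x
            .addIntegerArguments(
                    2, 2,      // kernel height, width
                    2, 2,      // stride height, width
                    8, 8,      // padding height, width (recomputed when SAME mode is set)
                    1, 1,      // dilation height, width
                    1,         // padding mode: 1 = SAME
                    0,         // extraParam0 (unused by max pooling)
                    0)         // data format: 0 = NCHW
            .build();
    Nd4j.exec(op);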