Upgrade protobuf version (#162)

* First steps for protobuf version upgrade

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Phase 2

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update imports to shaded protobuf

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Version fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Switch to single execution for protobuf codegen to work around plugin bug

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Automatically delete old PB generated files after name change

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2019-08-24 19:22:36 +10:00 committed by GitHub
parent b85238a6df
commit a9b08cc163
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
80 changed files with 487 additions and 228 deletions

View File

@ -31,10 +31,35 @@
<build>
<plugins>
<!-- AB 2019/08/24 This plugin is to be added TEMPORARILY due to a change in the filenames of the generated ONNX -->
<!-- Normal "mvn clean" etc won't delete these files, and any users who have built ND4J even once before the
change will run into a compilation error. This can be removed after a few weeks.-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-antrun-plugin</artifactId>
<version>1.8</version>
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>run</goal>
</goals>
<configuration>
<target>
<delete file="${project.build.sourceDirectory}/onnx/OnnxMlProto3.java" />
<delete file="${project.build.sourceDirectory}/onnx/OnnxOperatorsProto3.java" />
<delete file="${project.build.sourceDirectory}/onnx/OnnxProto3.java" />
</target>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>com.github.os72</groupId>
<artifactId>protoc-jar-maven-plugin</artifactId>
<version>3.5.1.1</version>
<version>3.8.0</version>
<executions>
<execution>
<id>tensorflow</id>
@ -43,30 +68,14 @@
<goal>run</goal>
</goals>
<configuration>
<type>java-shaded</type>
<protocVersion>3.5.1</protocVersion>
<protocVersion>3.8.0</protocVersion>
<extension>.proto</extension>
<includeDirectories>
<include>src/main/protobuf/tf</include>
<include>src/main/protobuf/onnx</include>
</includeDirectories>
<inputDirectories>
<include>src/main/protobuf/tf/tensorflow</include>
</inputDirectories>
<addSources>main</addSources>
<cleanOutputFolder>false</cleanOutputFolder>
<outputDirectory>src/main/java/</outputDirectory>
</configuration>
</execution>
<execution>
<id>onnx</id>
<phase>generate-sources</phase>
<goals>
<goal>run</goal>
</goals>
<configuration>
<type>java-shaded</type>
<extension>.proto3</extension>
<protocVersion>3.5.1</protocVersion>
<inputDirectories>
<include>src/main/protobuf/onnx</include>
</inputDirectories>
<addSources>main</addSources>
@ -76,6 +85,32 @@
</execution>
</executions>
</plugin>
<plugin>
<groupId>com.google.code.maven-replacer-plugin</groupId>
<artifactId>replacer</artifactId>
<version>1.5.3</version>
<configuration>
<includes>
<include>${project.build.sourceDirectory}/org/tensorflow/**</include>
<include>${project.build.sourceDirectory}/tensorflow/**</include>
<include>${project.build.sourceDirectory}/onnx/**</include>
</includes>
<token>com.google.protobuf.</token>
<value>org.nd4j.shade.protobuf.</value>
</configuration>
<executions>
<execution>
<id>replace-imports</id>
<phase>generate-sources</phase>
<goals>
<goal>replace</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
@ -148,20 +183,15 @@
<version>${flatbuffers.version}</version>
</dependency>
<!-- Note that this is shaded flatbuffers, see the protoc declaration above
mentioning java-shaded as the type for why we use this instead of google's (mainly due ot other systems packaging
their own older protobuf versions-->
<!-- Note that this is shaded protobuf. We use this instead of google's version mainly due ot other systems packaging
their own older (incompatible) protobuf versions-->
<dependency>
<groupId>com.github.os72</groupId>
<artifactId>protobuf-java-shaded-351</artifactId>
<version>0.9</version>
</dependency>
<dependency>
<groupId>com.github.os72</groupId>
<artifactId>protobuf-java-util-shaded-351</artifactId>
<version>0.9</version>
<groupId>org.nd4j</groupId>
<artifactId>protobuf</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.objenesis</groupId>
<artifactId>objenesis</artifactId>

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
@ -101,10 +101,10 @@ public abstract class DifferentialFunction {
/**
* Initialize the function from the given
* {@link onnx.OnnxProto3.NodeProto}
* {@link onnx.Onnx.NodeProto}
* @param node
*/
public DifferentialFunction(SameDiff sameDiff,onnx.OnnxProto3.NodeProto node,Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public DifferentialFunction(SameDiff sameDiff,onnx.Onnx.NodeProto node,Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
this.sameDiff = sameDiff;
setInstanceId();
initFromOnnx(node, sameDiff, attributesForNode, graph);
@ -731,13 +731,13 @@ public abstract class DifferentialFunction {
/**
* Iniitialize the function from the given
* {@link onnx.OnnxProto3.NodeProto}
* {@link onnx.Onnx.NodeProto}
* @param node
* @param initWith
* @param attributesForNode
* @param graph
*/
public abstract void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph);
public abstract void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph);

View File

@ -19,7 +19,7 @@ package org.nd4j.autodiff.samediff;
import java.util.Objects;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;

View File

@ -16,7 +16,7 @@
package org.nd4j.imports.descriptors.tensorflow;
import com.github.os72.protobuf351.TextFormat;
import org.nd4j.shade.protobuf.TextFormat;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.io.ClassPathResource;
import org.tensorflow.framework.OpDef;

View File

@ -16,8 +16,8 @@
package org.nd4j.imports.graphmapper;
import com.github.os72.protobuf351.Message;
import com.github.os72.protobuf351.TextFormat;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;

View File

@ -16,7 +16,7 @@
package org.nd4j.imports.graphmapper;
import com.github.os72.protobuf351.Message;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;

View File

@ -16,13 +16,13 @@
package org.nd4j.imports.graphmapper.onnx;
import com.github.os72.protobuf351.ByteString;
import com.github.os72.protobuf351.Message;
import org.nd4j.shade.protobuf.ByteString;
import org.nd4j.shade.protobuf.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -52,7 +52,7 @@ import java.util.*;
*
* @author Adam Gibson
*/
public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto, onnx.OnnxProto3.TypeProto.Tensor> {
public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto, onnx.Onnx.TypeProto.Tensor> {
private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();
@ -64,9 +64,9 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
try {
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(inputFile);
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile);
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
bufferedWriter.write(node.toString() + "\n");
}
@ -88,7 +88,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
* @param node
* @param graph
*/
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) {
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) {
val properties = on.mappingsForFunction();
val tfProperties = properties.get(mappedTfName);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
@ -170,18 +170,18 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public boolean isOpIgnoreException(OnnxProto3.NodeProto node) {
public boolean isOpIgnoreException(Onnx.NodeProto node) {
return false;
}
@Override
public String getTargetMappingForOp(DifferentialFunction function, OnnxProto3.NodeProto node) {
public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) {
return function.opName();
}
@Override
public void mapProperty(String name, DifferentialFunction on, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
/**
@ -263,7 +263,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public OnnxProto3.NodeProto getNodeWithNameFromGraph(OnnxProto3.GraphProto graph, String name) {
public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) {
for(int i = 0; i < graph.getNodeCount(); i++) {
val node = graph.getNode(i);
if(node.getName().equals(name))
@ -274,21 +274,21 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public boolean isPlaceHolderNode(OnnxProto3.TypeProto.Tensor node) {
public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) {
return false;
}
@Override
public List<String> getControlDependencies(OnnxProto3.NodeProto node) {
public List<String> getControlDependencies(Onnx.NodeProto node) {
throw new UnsupportedOperationException("Not yet implemented");
}
@Override
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
try {
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
bufferedWriter.write(node.toString());
}
@ -316,12 +316,12 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public Map<String,onnx.OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) {
public Map<String,onnx.Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto) {
/**
* Need to figure out why
* gpu_0/conv1_1 isn't present in VGG
*/
Map<String,onnx.OnnxProto3.TypeProto.Tensor> ret = new HashMap<>();
Map<String,onnx.Onnx.TypeProto.Tensor> ret = new HashMap<>();
for(int i = 0; i < graphProto.getInputCount(); i++) {
ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
}
@ -356,19 +356,19 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public String translateToSameDiffName(String name, OnnxProto3.NodeProto node) {
public String translateToSameDiffName(String name, Onnx.NodeProto node) {
return null;
}
protected void addDummyTensor(String name, Map<String, OnnxProto3.TypeProto.Tensor> to) {
OnnxProto3.TensorShapeProto.Dimension dim = OnnxProto3.TensorShapeProto.Dimension.
protected void addDummyTensor(String name, Map<String, Onnx.TypeProto.Tensor> to) {
Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension.
newBuilder()
.setDimValue(-1)
.build();
OnnxProto3.TypeProto.Tensor typeProto = OnnxProto3.TypeProto.Tensor.newBuilder()
Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder()
.setShape(
OnnxProto3.TensorShapeProto.newBuilder()
Onnx.TensorShapeProto.newBuilder()
.addDim(dim)
.addDim(dim).build())
.build();
@ -377,23 +377,23 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public Message.Builder getNewGraphBuilder() {
return OnnxProto3.GraphProto.newBuilder();
return Onnx.GraphProto.newBuilder();
}
@Override
public OnnxProto3.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
}
@Override
public OnnxProto3.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
}
@Override
public void mapNodeType(OnnxProto3.NodeProto tfNode, ImportState<OnnxProto3.GraphProto, OnnxProto3.TypeProto.Tensor> importState,
OpImportOverride<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opImportOverride,
OpImportFilter<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opFilter) {
public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto, Onnx.TypeProto.Tensor> importState,
OpImportOverride<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opImportOverride,
OpImportFilter<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opFilter) {
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
if(differentialFunction == null) {
throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
@ -425,13 +425,13 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public DataType dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensorProto, int outputNum) {
public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) {
return nd4jTypeFromOnnxType(tensorProto.getElemType());
}
@Override
public boolean isStringType(OnnxProto3.TypeProto.Tensor tensor) {
return tensor.getElemType() == OnnxProto3.TensorProto.DataType.STRING;
public boolean isStringType(Onnx.TypeProto.Tensor tensor) {
return tensor.getElemType() == Onnx.TensorProto.DataType.STRING;
}
@ -440,7 +440,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
* @param dataType the data type to convert
* @return the nd4j type for the onnx type
*/
public DataType nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType) {
public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) {
switch (dataType) {
case DOUBLE: return DataType.DOUBLE;
case FLOAT: return DataType.FLOAT;
@ -452,8 +452,8 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) {
for(OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) {
for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) {
if(attributeProto.getName().equals(key)) {
return attributeProto.getS().toString();
}
@ -463,29 +463,29 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public long[] getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto) {
public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) {
return Longs.toArray(attributeProto.getT().getDimsList());
}
@Override
public boolean isPlaceHolder(OnnxProto3.TypeProto.Tensor nodeType) {
public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) {
return false;
}
@Override
public boolean isConstant(OnnxProto3.TypeProto.Tensor nodeType) {
public boolean isConstant(Onnx.TypeProto.Tensor nodeType) {
return false;
}
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) {
DataType type = dataTypeForTensor(tensorProto, 0);
if(!tensorProto.isInitialized()) {
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
}
OnnxProto3.TensorProto tensor = null;
Onnx.TensorProto tensor = null;
for(int i = 0; i < graph.getInitializerCount(); i++) {
val initializer = graph.getInitializer(i);
if(initializer.getName().equals(tensorName)) {
@ -508,7 +508,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
return arr;
}
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) {
public INDArray mapTensorProto(Onnx.TensorProto tensor) {
if(tensor == null)
return null;
@ -527,7 +527,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
}
@Override
public long[] getShapeFromTensor(onnx.OnnxProto3.TypeProto.Tensor tensorProto) {
public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) {
val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
int dimCount = tensorProto.getShape().getDimCount();
if(dimCount >= 2)
@ -548,11 +548,11 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
/**
* Get the shape from a tensor proto.
* Note that this is different from {@link #getShapeFromTensor(OnnxProto3.TensorProto)}
* Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)}
* @param tensorProto the tensor to get the shape from
* @return
*/
public long[] getShapeFromTensor(OnnxProto3.TensorProto tensorProto) {
public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) {
val ret = new long[Math.max(2,tensorProto.getDimsCount())];
int dimCount = tensorProto.getDimsCount();
if(dimCount >= 2)
@ -577,74 +577,74 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
@Override
public String getInputFromNode(OnnxProto3.NodeProto node, int index) {
public String getInputFromNode(Onnx.NodeProto node, int index) {
return node.getInput(index);
}
@Override
public int numInputsFor(OnnxProto3.NodeProto nodeProto) {
public int numInputsFor(Onnx.NodeProto nodeProto) {
return nodeProto.getInputCount();
}
@Override
public long[] getShapeFromAttr(OnnxProto3.AttributeProto attr) {
public long[] getShapeFromAttr(Onnx.AttributeProto attr) {
return Longs.toArray(attr.getT().getDimsList());
}
@Override
public Map<String, OnnxProto3.AttributeProto> getAttrMap(OnnxProto3.NodeProto nodeProto) {
Map<String,OnnxProto3.AttributeProto> proto = new HashMap<>();
public Map<String, Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto) {
Map<String,Onnx.AttributeProto> proto = new HashMap<>();
for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
OnnxProto3.AttributeProto attributeProto = nodeProto.getAttribute(i);
Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i);
proto.put(attributeProto.getName(),attributeProto);
}
return proto;
}
@Override
public String getName(OnnxProto3.NodeProto nodeProto) {
public String getName(Onnx.NodeProto nodeProto) {
return nodeProto.getName();
}
@Override
public boolean alreadySeen(OnnxProto3.NodeProto nodeProto) {
public boolean alreadySeen(Onnx.NodeProto nodeProto) {
return false;
}
@Override
public boolean isVariableNode(OnnxProto3.NodeProto nodeProto) {
public boolean isVariableNode(Onnx.NodeProto nodeProto) {
return nodeProto.getOpType().contains("Var");
}
@Override
public boolean shouldSkip(OnnxProto3.NodeProto opType) {
public boolean shouldSkip(Onnx.NodeProto opType) {
return false;
}
@Override
public boolean hasShape(OnnxProto3.NodeProto nodeProto) {
public boolean hasShape(Onnx.NodeProto nodeProto) {
return false;
}
@Override
public long[] getShape(OnnxProto3.NodeProto nodeProto) {
public long[] getShape(Onnx.NodeProto nodeProto) {
return null;
}
@Override
public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graph) {
public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) {
return null;
}
@Override
public String getOpType(OnnxProto3.NodeProto nodeProto) {
public String getOpType(Onnx.NodeProto nodeProto) {
return nodeProto.getOpType();
}
@Override
public List<OnnxProto3.NodeProto> getNodeList(OnnxProto3.GraphProto graphProto) {
public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto) {
return graphProto.getNodeList();
}

View File

@ -16,7 +16,7 @@
package org.nd4j.imports.graphmapper.tf;
import com.github.os72.protobuf351.Message;
import org.nd4j.shade.protobuf.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import lombok.extern.slf4j.Slf4j;

View File

@ -1,6 +1,6 @@
package org.nd4j.imports.graphmapper.tf.tensors;
import com.github.os72.protobuf351.Descriptors;
import org.nd4j.shade.protobuf.Descriptors;
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.nd4j.linalg.api.buffer.DataType;

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -205,7 +205,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -200,7 +200,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -20,7 +20,7 @@ import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -134,7 +134,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
@ -218,7 +218,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
if (!attributesForNode.containsKey("axes")) {
this.dimensions = new int[] { Integer.MAX_VALUE };
}

View File

@ -21,7 +21,7 @@ import com.google.common.primitives.Doubles;
import com.google.common.primitives.Longs;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -603,7 +603,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
@ -61,7 +61,7 @@ public class NoOp extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -367,7 +367,7 @@ public class If extends DifferentialFunction implements CustomOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -468,7 +468,7 @@ public class While extends DifferentialFunction implements CustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -122,7 +122,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers;
import lombok.Builder;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -96,7 +96,7 @@ public class Linear extends BaseModule {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -260,7 +260,7 @@ public class AvgPooling2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
val padding = !attributesForNode.containsKey("pads") ? Arrays.asList(1L) : attributesForNode.get("pads").getIntsList();

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -78,7 +78,7 @@ public class AvgPooling3D extends Pooling3D {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented");
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
@ -139,7 +139,7 @@ public class BatchNorm extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs();
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -127,7 +127,7 @@ public class Conv2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs();
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -247,7 +247,7 @@ public class DeConv2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val autoPad = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
val dilations = attributesForNode.get("dilations");
val dilationY = dilations == null ? 1 : dilations.getIntsList().get(0).intValue();

View File

@ -20,7 +20,7 @@ import lombok.Builder;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -151,7 +151,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
addArgs();
}

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -115,7 +115,7 @@ public class LocalResponseNormalization extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val aAlpha = attributesForNode.get("alpha");
val aBeta = attributesForNode.get("beta");
val aBias = attributesForNode.get("bias");

View File

@ -21,7 +21,7 @@ import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -221,7 +221,7 @@ public class MaxPooling2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
val isSameNode = paddingVal.equals("SAME");
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -78,7 +78,7 @@ public class MaxPooling3D extends Pooling3D {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented");
}

View File

@ -20,7 +20,7 @@ import lombok.Builder;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -183,7 +183,7 @@ public class Pooling2D extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val isSameNode = attributesForNode.get("auto_pad").getS().equals("SAME");
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
val padding = attributesForNode.get("pads").getIntsList();

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
@ -73,7 +73,7 @@ public class LSTMCell extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -65,7 +65,7 @@ public class SRU extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -66,7 +66,7 @@ public class SRUCell extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce;
import lombok.EqualsAndHashCode;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -204,7 +204,7 @@ public class Mmul extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
MMulTranspose mMulTranspose = MMulTranspose.builder()

View File

@ -20,7 +20,7 @@ import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
@ -283,7 +283,7 @@ public class TensorMmul extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
MMulTranspose mMulTranspose = MMulTranspose.builder()

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -163,7 +163,7 @@ public class Concat extends DynamicCustomOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -77,7 +77,7 @@ public class Diag extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -79,7 +79,7 @@ public class DiagPart extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
@ -78,7 +78,7 @@ public class Gather extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -65,7 +65,7 @@ public class MergeAvg extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -64,7 +64,7 @@ public class MergeMax extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -66,7 +66,7 @@ public class MergeSum extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -68,7 +68,7 @@ public class ParallelStack extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -66,7 +66,7 @@ public class Rank extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -106,7 +106,7 @@ public class Repeat extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -126,7 +126,7 @@ public class Reshape extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
val shape = new OnnxGraphMapper().getShape(node);
this.shape = shape;
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.val;
import onnx.OnnxMlProto3;
import onnx.OnnxMl;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
@ -87,7 +87,7 @@ public class Shape extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -93,7 +93,7 @@ public class Stack extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
import com.google.common.primitives.Ints;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
@ -156,7 +156,7 @@ public class Transpose extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
if (!attributesForNode.containsKey("perm")) {
} else

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -127,7 +127,7 @@ public class Unstack extends DynamicCustomOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
}

View File

@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape.bp;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -71,7 +71,7 @@ public class ConcatBp extends DynamicCustomOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
//No op
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -59,7 +59,7 @@ public class TensorArrayConcat extends BaseTensorOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException();
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -59,7 +59,7 @@ public class TensorArrayGather extends BaseTensorOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException();
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -54,7 +54,7 @@ public class TensorArrayRead extends BaseTensorOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
@ -52,7 +52,7 @@ public class TensorArrayScatter extends BaseTensorOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
@ -58,7 +58,7 @@ public class TensorArraySize extends BaseTensorOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
@ -52,7 +52,7 @@ public class TensorArraySplit extends BaseTensorOp {
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.clip;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -64,7 +64,7 @@ public class ClipByNorm extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented");
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.clip;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -77,7 +77,7 @@ public class ClipByValue extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
throw new UnsupportedOperationException("Not yet implemented");
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -62,7 +62,7 @@ public class Assign extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -132,7 +132,7 @@ public class CumProd extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -133,7 +133,7 @@ public class CumSum extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -80,7 +80,7 @@ public class Fill extends DynamicCustomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -16,7 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -81,7 +81,7 @@ public class RectifiedTanh extends BaseTransformStrictOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.random.impl;
import lombok.NonNull;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -75,7 +75,7 @@ public class DropOutInverted extends BaseRandomOp {
}
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}

View File

@ -17,7 +17,7 @@
package org.nd4j.linalg.api.ops.random.impl;
import lombok.val;
import onnx.OnnxProto3;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

View File

@ -9,7 +9,7 @@
syntax = "proto3";
package onnx;
import "onnx.proto3";
import "onnx.proto";
//
// This file contains the proto definitions for OperatorSetProto and

View File

@ -16,7 +16,7 @@
package org.nd4j.tensorflow.conversion;
import com.github.os72.protobuf351.util.JsonFormat;
import org.nd4j.shade.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils;
import org.junit.Ignore;
import org.junit.Rule;

View File

@ -16,7 +16,7 @@
package org.nd4j.tensorflow.conversion;
import com.github.os72.protobuf351.util.JsonFormat;
import org.nd4j.shade.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils;
import org.junit.Ignore;
import org.junit.Test;

View File

@ -29,6 +29,7 @@
<packaging>pom</packaging>
<modules>
<module>jackson</module>
<module>protobuf</module>
</modules>
<properties>

View File

@ -0,0 +1,228 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>nd4j-shade</artifactId>
<groupId>org.nd4j</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>protobuf</artifactId>
<properties>
<skipTestResourceEnforcement>true</skipTestResourceEnforcement>
</properties>
<dependencies>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>3.8.0</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
<version>3.8.0</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>custom-lifecycle</id>
<activation>
<property><name>!skip.custom.lifecycle</name></property>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.apache.portals.jetspeed-2</groupId>
<artifactId>jetspeed-mvn-maven-plugin</artifactId>
<version>2.3.1</version>
<executions>
<execution>
<id>compile-and-pack</id>
<phase>compile</phase>
<goals>
<goal>mvn</goal>
</goals>
</execution>
</executions>
<dependencies>
<dependency>
<groupId>org.apache.maven.shared</groupId>
<artifactId>maven-invoker</artifactId>
<version>2.2</version>
</dependency>
</dependencies>
<configuration>
<targets combine.children="merge">
<target>
<id>create-shaded-jars</id>
<dir>@rootdir@/nd4j/nd4j-shade/protobuf/</dir>
<goals>clean,compile,package</goals>
<properties>
<skip.custom.lifecycle>true</skip.custom.lifecycle>
</properties>
</target>
</targets>
<defaultTarget>create-shaded-jars</defaultTarget>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<build>
<plugins>
<!-- Disable Maven Lint plugin in this module. For some reason it chokes on this module (internal NPE) and we don't need it anyway here -->
<plugin>
<groupId>com.lewisd</groupId>
<artifactId>lint-maven-plugin</artifactId>
<version>0.0.11</version>
<executions>
<execution>
<id>pom-lint</id>
<phase>none</phase>
</execution>
</executions>
</plugin>
<!--
Use Maven Shade plugin to add a shaded version of the Protobuf dependencies, that can be imported by
including this module (org.nd4j.protobuf) as a dependency.
The standard com.google.protobuf dependencies will be provided, though are prefixed by org.nd4j.shade.protobuf
-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>${maven-shade-plugin.version}</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
</transformer>
</transformers>
</configuration>
</execution>
</executions>
<configuration>
<!--
Important configuration options here:
createDependencyReducedPom: remove the shaded artifacts from the module dependencies. Without this, the
original dependencies will be shaded, AND still included as transitive deps
in the final POM. This is not what we want.
shadedArtifactAttached: If true, the shaded artifact will be a separate JAR file for install, with
the original un-shaded JAR being separate. With this being set to false,
the original JAR will be modified, and no extra jar will be produced.
promoteTransitiveDependencies: This will promote the transitive dependencies of the shaded dependencies
to direct dependencies. Without this, we need to manually manage the transitive
dependencies of the shaded artifacts.
Note that using <optional>true</optional> in the dependencies also allows the deps to be shaded (and
original dependencies to not be included), but does NOT work with promoteTransitiveDependencies
-->
<shadedArtifactAttached>false</shadedArtifactAttached>
<createDependencyReducedPom>true</createDependencyReducedPom>
<promoteTransitiveDependencies>true</promoteTransitiveDependencies>
<artifactSet>
<includes>
<include>com.google.protobuf:*</include>
<include>com.google.protobuf.*:*</include>
</includes>
</artifactSet>
<relocations>
<!-- Protobuf dependencies -->
<relocation>
<pattern>com.google.protobuf</pattern>
<shadedPattern>org.nd4j.shade.protobuf</shadedPattern>
</relocation>
</relocations>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<forceCreation>true</forceCreation>
</configuration>
<executions>
<execution>
<id>empty-javadoc-jar</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<classifier>javadoc</classifier>
<classesDirectory>${basedir}/javadoc</classesDirectory>
</configuration>
</execution>
<execution>
<id>empty-sources-jar</id>
<phase>package</phase>
<goals>
<goal>jar</goal>
</goals>
<configuration>
<classifier>sources</classifier>
<classesDirectory>${basedir}/src</classesDirectory>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-dependency-plugin</artifactId>
<version>3.0.0</version>
<executions>
<execution>
<id>unpack</id>
<phase>package</phase>
<goals>
<goal>unpack</goal>
</goals>
<configuration>
<artifactItems>
<artifactItem>
<groupId>org.nd4j</groupId>
<artifactId>protobuf</artifactId>
<version>${project.version}</version>
<type>jar</type>
<overWrite>false</overWrite>
<outputDirectory>${project.build.directory}/classes/</outputDirectory>
<includes>**/*.class,**/*.xml</includes>
</artifactItem>
</artifactItems>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>

View File

@ -16,7 +16,7 @@
package org.nd4j.tensorflow.conversion;
import com.github.os72.protobuf351.InvalidProtocolBufferException;
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.nd4j.linalg.api.buffer.DataBuffer;

View File

@ -16,9 +16,9 @@
package org.nd4j.tensorflow.conversion.graphrunner;
import com.github.os72.protobuf351.ByteString;
import com.github.os72.protobuf351.InvalidProtocolBufferException;
import com.github.os72.protobuf351.util.JsonFormat;
import org.nd4j.shade.protobuf.ByteString;
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import org.nd4j.shade.protobuf.util.JsonFormat;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
@ -638,7 +638,7 @@ public class GraphRunner implements Closeable {
/**
* Convert a json string written out
* by {@link com.github.os72.protobuf351.util.JsonFormat}
* by {@link org.nd4j.shade.protobuf.util.JsonFormat}
* to a {@link org.bytedeco.tensorflow.ConfigProto}
* @param json the json to read
* @return the config proto to use