Upgrade protobuf version (#162)
* First steps for protobuf version upgrade Signed-off-by: AlexDBlack <blacka101@gmail.com> * Phase 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update imports to shaded protobuf Signed-off-by: AlexDBlack <blacka101@gmail.com> * Version fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Switch to single execution for protobuf codegen to work around plugin bug Signed-off-by: AlexDBlack <blacka101@gmail.com> * Automatically delete old PB generated files after name change Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
b85238a6df
commit
a9b08cc163
|
@ -31,10 +31,35 @@
|
|||
|
||||
<build>
|
||||
<plugins>
|
||||
|
||||
<!-- AB 2019/08/24 This plugin is to be added TEMPORARILY due to a change in the filenames of the generated ONNX -->
|
||||
<!-- Normal "mvn clean" etc won't delete these files, and any users who have built ND4J even once before the
|
||||
change will run into a compilation error. This can be removed after a few weeks.-->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-antrun-plugin</artifactId>
|
||||
<version>1.8</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>generate-sources</phase>
|
||||
<goals>
|
||||
<goal>run</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<target>
|
||||
<delete file="${project.build.sourceDirectory}/onnx/OnnxMlProto3.java" />
|
||||
<delete file="${project.build.sourceDirectory}/onnx/OnnxOperatorsProto3.java" />
|
||||
<delete file="${project.build.sourceDirectory}/onnx/OnnxProto3.java" />
|
||||
</target>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>com.github.os72</groupId>
|
||||
<artifactId>protoc-jar-maven-plugin</artifactId>
|
||||
<version>3.5.1.1</version>
|
||||
<version>3.8.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>tensorflow</id>
|
||||
|
@ -43,30 +68,14 @@
|
|||
<goal>run</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<type>java-shaded</type>
|
||||
<protocVersion>3.5.1</protocVersion>
|
||||
<protocVersion>3.8.0</protocVersion>
|
||||
<extension>.proto</extension>
|
||||
<includeDirectories>
|
||||
<include>src/main/protobuf/tf</include>
|
||||
<include>src/main/protobuf/onnx</include>
|
||||
</includeDirectories>
|
||||
<inputDirectories>
|
||||
<include>src/main/protobuf/tf/tensorflow</include>
|
||||
</inputDirectories>
|
||||
<addSources>main</addSources>
|
||||
<cleanOutputFolder>false</cleanOutputFolder>
|
||||
<outputDirectory>src/main/java/</outputDirectory>
|
||||
</configuration>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>onnx</id>
|
||||
<phase>generate-sources</phase>
|
||||
<goals>
|
||||
<goal>run</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<type>java-shaded</type>
|
||||
<extension>.proto3</extension>
|
||||
<protocVersion>3.5.1</protocVersion>
|
||||
<inputDirectories>
|
||||
<include>src/main/protobuf/onnx</include>
|
||||
</inputDirectories>
|
||||
<addSources>main</addSources>
|
||||
|
@ -76,6 +85,32 @@
|
|||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>com.google.code.maven-replacer-plugin</groupId>
|
||||
<artifactId>replacer</artifactId>
|
||||
<version>1.5.3</version>
|
||||
<configuration>
|
||||
<includes>
|
||||
<include>${project.build.sourceDirectory}/org/tensorflow/**</include>
|
||||
<include>${project.build.sourceDirectory}/tensorflow/**</include>
|
||||
<include>${project.build.sourceDirectory}/onnx/**</include>
|
||||
</includes>
|
||||
<token>com.google.protobuf.</token>
|
||||
<value>org.nd4j.shade.protobuf.</value>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>replace-imports</id>
|
||||
<phase>generate-sources</phase>
|
||||
<goals>
|
||||
<goal>replace</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
|
@ -148,20 +183,15 @@
|
|||
<version>${flatbuffers.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Note that this is shaded flatbuffers, see the protoc declaration above
|
||||
mentioning java-shaded as the type for why we use this instead of google's (mainly due ot other systems packaging
|
||||
their own older protobuf versions-->
|
||||
<!-- Note that this is shaded protobuf. We use this instead of google's version mainly due ot other systems packaging
|
||||
their own older (incompatible) protobuf versions-->
|
||||
<dependency>
|
||||
<groupId>com.github.os72</groupId>
|
||||
<artifactId>protobuf-java-shaded-351</artifactId>
|
||||
<version>0.9</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.os72</groupId>
|
||||
<artifactId>protobuf-java-util-shaded-351</artifactId>
|
||||
<version>0.9</version>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>protobuf</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>org.objenesis</groupId>
|
||||
<artifactId>objenesis</artifactId>
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||
|
@ -101,10 +101,10 @@ public abstract class DifferentialFunction {
|
|||
|
||||
/**
|
||||
* Initialize the function from the given
|
||||
* {@link onnx.OnnxProto3.NodeProto}
|
||||
* {@link onnx.Onnx.NodeProto}
|
||||
* @param node
|
||||
*/
|
||||
public DifferentialFunction(SameDiff sameDiff,onnx.OnnxProto3.NodeProto node,Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public DifferentialFunction(SameDiff sameDiff,onnx.Onnx.NodeProto node,Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
this.sameDiff = sameDiff;
|
||||
setInstanceId();
|
||||
initFromOnnx(node, sameDiff, attributesForNode, graph);
|
||||
|
@ -731,13 +731,13 @@ public abstract class DifferentialFunction {
|
|||
|
||||
/**
|
||||
* Iniitialize the function from the given
|
||||
* {@link onnx.OnnxProto3.NodeProto}
|
||||
* {@link onnx.Onnx.NodeProto}
|
||||
* @param node
|
||||
* @param initWith
|
||||
* @param attributesForNode
|
||||
* @param graph
|
||||
*/
|
||||
public abstract void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph);
|
||||
public abstract void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph);
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.autodiff.samediff;
|
|||
import java.util.Objects;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.imports.descriptors.tensorflow;
|
||||
|
||||
import com.github.os72.protobuf351.TextFormat;
|
||||
import org.nd4j.shade.protobuf.TextFormat;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.tensorflow.framework.OpDef;
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
package org.nd4j.imports.graphmapper;
|
||||
|
||||
import com.github.os72.protobuf351.Message;
|
||||
import com.github.os72.protobuf351.TextFormat;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import org.nd4j.shade.protobuf.TextFormat;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.imports.graphmapper;
|
||||
|
||||
import com.github.os72.protobuf351.Message;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
|
|
|
@ -16,13 +16,13 @@
|
|||
|
||||
package org.nd4j.imports.graphmapper.onnx;
|
||||
|
||||
import com.github.os72.protobuf351.ByteString;
|
||||
import com.github.os72.protobuf351.Message;
|
||||
import org.nd4j.shade.protobuf.ByteString;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import com.google.common.primitives.Floats;
|
||||
import com.google.common.primitives.Ints;
|
||||
import com.google.common.primitives.Longs;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -52,7 +52,7 @@ import java.util.*;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto, onnx.OnnxProto3.TypeProto.Tensor> {
|
||||
public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto, onnx.Onnx.TypeProto.Tensor> {
|
||||
private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();
|
||||
|
||||
|
||||
|
@ -64,9 +64,9 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
@Override
|
||||
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
|
||||
try {
|
||||
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(inputFile);
|
||||
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile);
|
||||
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
||||
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
bufferedWriter.write(node.toString() + "\n");
|
||||
}
|
||||
|
||||
|
@ -88,7 +88,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
* @param node
|
||||
* @param graph
|
||||
*/
|
||||
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph) {
|
||||
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) {
|
||||
val properties = on.mappingsForFunction();
|
||||
val tfProperties = properties.get(mappedTfName);
|
||||
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
||||
|
@ -170,18 +170,18 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public boolean isOpIgnoreException(OnnxProto3.NodeProto node) {
|
||||
public boolean isOpIgnoreException(Onnx.NodeProto node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getTargetMappingForOp(DifferentialFunction function, OnnxProto3.NodeProto node) {
|
||||
public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) {
|
||||
return function.opName();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void mapProperty(String name, DifferentialFunction on, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
|
||||
public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
|
||||
val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
|
||||
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
|
||||
/**
|
||||
|
@ -263,7 +263,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
|
||||
@Override
|
||||
public OnnxProto3.NodeProto getNodeWithNameFromGraph(OnnxProto3.GraphProto graph, String name) {
|
||||
public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) {
|
||||
for(int i = 0; i < graph.getNodeCount(); i++) {
|
||||
val node = graph.getNode(i);
|
||||
if(node.getName().equals(name))
|
||||
|
@ -274,21 +274,21 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public boolean isPlaceHolderNode(OnnxProto3.TypeProto.Tensor node) {
|
||||
public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> getControlDependencies(OnnxProto3.NodeProto node) {
|
||||
public List<String> getControlDependencies(Onnx.NodeProto node) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
|
||||
try {
|
||||
OnnxProto3.ModelProto graphDef = OnnxProto3.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
|
||||
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
|
||||
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
|
||||
for(OnnxProto3.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
|
||||
bufferedWriter.write(node.toString());
|
||||
}
|
||||
|
||||
|
@ -316,12 +316,12 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
|
||||
@Override
|
||||
public Map<String,onnx.OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) {
|
||||
public Map<String,onnx.Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto) {
|
||||
/**
|
||||
* Need to figure out why
|
||||
* gpu_0/conv1_1 isn't present in VGG
|
||||
*/
|
||||
Map<String,onnx.OnnxProto3.TypeProto.Tensor> ret = new HashMap<>();
|
||||
Map<String,onnx.Onnx.TypeProto.Tensor> ret = new HashMap<>();
|
||||
for(int i = 0; i < graphProto.getInputCount(); i++) {
|
||||
ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
|
||||
}
|
||||
|
@ -356,19 +356,19 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public String translateToSameDiffName(String name, OnnxProto3.NodeProto node) {
|
||||
public String translateToSameDiffName(String name, Onnx.NodeProto node) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
protected void addDummyTensor(String name, Map<String, OnnxProto3.TypeProto.Tensor> to) {
|
||||
OnnxProto3.TensorShapeProto.Dimension dim = OnnxProto3.TensorShapeProto.Dimension.
|
||||
protected void addDummyTensor(String name, Map<String, Onnx.TypeProto.Tensor> to) {
|
||||
Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension.
|
||||
newBuilder()
|
||||
.setDimValue(-1)
|
||||
.build();
|
||||
OnnxProto3.TypeProto.Tensor typeProto = OnnxProto3.TypeProto.Tensor.newBuilder()
|
||||
Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder()
|
||||
.setShape(
|
||||
OnnxProto3.TensorShapeProto.newBuilder()
|
||||
Onnx.TensorShapeProto.newBuilder()
|
||||
.addDim(dim)
|
||||
.addDim(dim).build())
|
||||
.build();
|
||||
|
@ -377,23 +377,23 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
@Override
|
||||
public Message.Builder getNewGraphBuilder() {
|
||||
return OnnxProto3.GraphProto.newBuilder();
|
||||
return Onnx.GraphProto.newBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OnnxProto3.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
|
||||
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
|
||||
public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
|
||||
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OnnxProto3.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
|
||||
return OnnxProto3.ModelProto.parseFrom(inputStream).getGraph();
|
||||
public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
|
||||
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mapNodeType(OnnxProto3.NodeProto tfNode, ImportState<OnnxProto3.GraphProto, OnnxProto3.TypeProto.Tensor> importState,
|
||||
OpImportOverride<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opImportOverride,
|
||||
OpImportFilter<OnnxProto3.GraphProto, OnnxProto3.NodeProto, OnnxProto3.AttributeProto> opFilter) {
|
||||
public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto, Onnx.TypeProto.Tensor> importState,
|
||||
OpImportOverride<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opImportOverride,
|
||||
OpImportFilter<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opFilter) {
|
||||
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
|
||||
if(differentialFunction == null) {
|
||||
throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
|
||||
|
@ -425,13 +425,13 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
|
||||
@Override
|
||||
public DataType dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensorProto, int outputNum) {
|
||||
public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) {
|
||||
return nd4jTypeFromOnnxType(tensorProto.getElemType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isStringType(OnnxProto3.TypeProto.Tensor tensor) {
|
||||
return tensor.getElemType() == OnnxProto3.TensorProto.DataType.STRING;
|
||||
public boolean isStringType(Onnx.TypeProto.Tensor tensor) {
|
||||
return tensor.getElemType() == Onnx.TensorProto.DataType.STRING;
|
||||
}
|
||||
|
||||
|
||||
|
@ -440,7 +440,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
* @param dataType the data type to convert
|
||||
* @return the nd4j type for the onnx type
|
||||
*/
|
||||
public DataType nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType) {
|
||||
public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) {
|
||||
switch (dataType) {
|
||||
case DOUBLE: return DataType.DOUBLE;
|
||||
case FLOAT: return DataType.FLOAT;
|
||||
|
@ -452,8 +452,8 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key) {
|
||||
for(OnnxProto3.AttributeProto attributeProto : nodeProto.getAttributeList()) {
|
||||
public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) {
|
||||
for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) {
|
||||
if(attributeProto.getName().equals(key)) {
|
||||
return attributeProto.getS().toString();
|
||||
}
|
||||
|
@ -463,29 +463,29 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public long[] getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto) {
|
||||
public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) {
|
||||
return Longs.toArray(attributeProto.getT().getDimsList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPlaceHolder(OnnxProto3.TypeProto.Tensor nodeType) {
|
||||
public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isConstant(OnnxProto3.TypeProto.Tensor nodeType) {
|
||||
public boolean isConstant(Onnx.TypeProto.Tensor nodeType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
|
||||
public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) {
|
||||
DataType type = dataTypeForTensor(tensorProto, 0);
|
||||
if(!tensorProto.isInitialized()) {
|
||||
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
|
||||
}
|
||||
|
||||
OnnxProto3.TensorProto tensor = null;
|
||||
Onnx.TensorProto tensor = null;
|
||||
for(int i = 0; i < graph.getInitializerCount(); i++) {
|
||||
val initializer = graph.getInitializer(i);
|
||||
if(initializer.getName().equals(tensorName)) {
|
||||
|
@ -508,7 +508,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
return arr;
|
||||
}
|
||||
|
||||
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) {
|
||||
public INDArray mapTensorProto(Onnx.TensorProto tensor) {
|
||||
if(tensor == null)
|
||||
return null;
|
||||
|
||||
|
@ -527,7 +527,7 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
}
|
||||
|
||||
@Override
|
||||
public long[] getShapeFromTensor(onnx.OnnxProto3.TypeProto.Tensor tensorProto) {
|
||||
public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) {
|
||||
val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
|
||||
int dimCount = tensorProto.getShape().getDimCount();
|
||||
if(dimCount >= 2)
|
||||
|
@ -548,11 +548,11 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
/**
|
||||
* Get the shape from a tensor proto.
|
||||
* Note that this is different from {@link #getShapeFromTensor(OnnxProto3.TensorProto)}
|
||||
* Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)}
|
||||
* @param tensorProto the tensor to get the shape from
|
||||
* @return
|
||||
*/
|
||||
public long[] getShapeFromTensor(OnnxProto3.TensorProto tensorProto) {
|
||||
public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) {
|
||||
val ret = new long[Math.max(2,tensorProto.getDimsCount())];
|
||||
int dimCount = tensorProto.getDimsCount();
|
||||
if(dimCount >= 2)
|
||||
|
@ -577,74 +577,74 @@ public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto, Onnx
|
|||
|
||||
|
||||
@Override
|
||||
public String getInputFromNode(OnnxProto3.NodeProto node, int index) {
|
||||
public String getInputFromNode(Onnx.NodeProto node, int index) {
|
||||
return node.getInput(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numInputsFor(OnnxProto3.NodeProto nodeProto) {
|
||||
public int numInputsFor(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getInputCount();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public long[] getShapeFromAttr(OnnxProto3.AttributeProto attr) {
|
||||
public long[] getShapeFromAttr(Onnx.AttributeProto attr) {
|
||||
return Longs.toArray(attr.getT().getDimsList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, OnnxProto3.AttributeProto> getAttrMap(OnnxProto3.NodeProto nodeProto) {
|
||||
Map<String,OnnxProto3.AttributeProto> proto = new HashMap<>();
|
||||
public Map<String, Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto) {
|
||||
Map<String,Onnx.AttributeProto> proto = new HashMap<>();
|
||||
for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
|
||||
OnnxProto3.AttributeProto attributeProto = nodeProto.getAttribute(i);
|
||||
Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i);
|
||||
proto.put(attributeProto.getName(),attributeProto);
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName(OnnxProto3.NodeProto nodeProto) {
|
||||
public String getName(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean alreadySeen(OnnxProto3.NodeProto nodeProto) {
|
||||
public boolean alreadySeen(Onnx.NodeProto nodeProto) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isVariableNode(OnnxProto3.NodeProto nodeProto) {
|
||||
public boolean isVariableNode(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getOpType().contains("Var");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean shouldSkip(OnnxProto3.NodeProto opType) {
|
||||
public boolean shouldSkip(Onnx.NodeProto opType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasShape(OnnxProto3.NodeProto nodeProto) {
|
||||
public boolean hasShape(Onnx.NodeProto nodeProto) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] getShape(OnnxProto3.NodeProto nodeProto) {
|
||||
public long[] getShape(Onnx.NodeProto nodeProto) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graph) {
|
||||
public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) {
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getOpType(OnnxProto3.NodeProto nodeProto) {
|
||||
public String getOpType(Onnx.NodeProto nodeProto) {
|
||||
return nodeProto.getOpType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<OnnxProto3.NodeProto> getNodeList(OnnxProto3.GraphProto graphProto) {
|
||||
public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto) {
|
||||
return graphProto.getNodeList();
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.imports.graphmapper.tf;
|
||||
|
||||
import com.github.os72.protobuf351.Message;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import com.google.common.primitives.Floats;
|
||||
import com.google.common.primitives.Ints;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package org.nd4j.imports.graphmapper.tf.tensors;
|
||||
|
||||
import com.github.os72.protobuf351.Descriptors;
|
||||
import org.nd4j.shade.protobuf.Descriptors;
|
||||
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer;
|
||||
import org.bytedeco.javacpp.indexer.HalfIndexer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -205,7 +205,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -200,7 +200,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Data;
|
|||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -134,7 +134,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
|
||||
|
@ -218,7 +218,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
if (!attributesForNode.containsKey("axes")) {
|
||||
this.dimensions = new int[] { Integer.MAX_VALUE };
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import com.google.common.primitives.Doubles;
|
|||
import com.google.common.primitives.Longs;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -603,7 +603,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -61,7 +61,7 @@ public class NoOp extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
|
|||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -367,7 +367,7 @@ public class If extends DifferentialFunction implements CustomOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow;
|
|||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -468,7 +468,7 @@ public class While extends DifferentialFunction implements CustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -122,7 +122,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.nd4j.linalg.api.ops.impl.layers;
|
|||
import lombok.Builder;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -96,7 +96,7 @@ public class Linear extends BaseModule {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -260,7 +260,7 @@ public class AvgPooling2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
||||
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
||||
val padding = !attributesForNode.containsKey("pads") ? Arrays.asList(1L) : attributesForNode.get("pads").getIntsList();
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -78,7 +78,7 @@ public class AvgPooling3D extends Pooling3D {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
|
@ -139,7 +139,7 @@ public class BatchNorm extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||
addArgs();
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -127,7 +127,7 @@ public class Conv2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||
addArgs();
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -247,7 +247,7 @@ public class DeConv2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val autoPad = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
||||
val dilations = attributesForNode.get("dilations");
|
||||
val dilationY = dilations == null ? 1 : dilations.getIntsList().get(0).intValue();
|
||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Builder;
|
|||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -151,7 +151,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||
addArgs();
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -115,7 +115,7 @@ public class LocalResponseNormalization extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val aAlpha = attributesForNode.get("alpha");
|
||||
val aBeta = attributesForNode.get("beta");
|
||||
val aBias = attributesForNode.get("bias");
|
||||
|
|
|
@ -21,7 +21,7 @@ import lombok.Getter;
|
|||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -221,7 +221,7 @@ public class MaxPooling2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
|
||||
val isSameNode = paddingVal.equals("SAME");
|
||||
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution;
|
|||
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -78,7 +78,7 @@ public class MaxPooling3D extends Pooling3D {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ import lombok.Builder;
|
|||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -183,7 +183,7 @@ public class Pooling2D extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val isSameNode = attributesForNode.get("auto_pad").getS().equals("SAME");
|
||||
val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
|
||||
val padding = attributesForNode.get("pads").getIntsList();
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
|
||||
|
@ -73,7 +73,7 @@ public class LSTMCell extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
@ -65,7 +65,7 @@ public class SRU extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.layers.recurrent;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
@ -66,7 +66,7 @@ public class SRUCell extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce;
|
|||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -204,7 +204,7 @@ public class Mmul extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
|
||||
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
|
||||
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
||||
|
|
|
@ -20,7 +20,7 @@ import com.google.common.primitives.Ints;
|
|||
import com.google.common.primitives.Longs;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||
|
@ -283,7 +283,7 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0;
|
||||
val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0;
|
||||
MMulTranspose mMulTranspose = MMulTranspose.builder()
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -163,7 +163,7 @@ public class Concat extends DynamicCustomOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -77,7 +77,7 @@ public class Diag extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -79,7 +79,7 @@ public class DiagPart extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
|
@ -78,7 +78,7 @@ public class Gather extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -65,7 +65,7 @@ public class MergeAvg extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -64,7 +64,7 @@ public class MergeMax extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -66,7 +66,7 @@ public class MergeSum extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -68,7 +68,7 @@ public class ParallelStack extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -66,7 +66,7 @@ public class Rank extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -106,7 +106,7 @@ public class Repeat extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -126,7 +126,7 @@ public class Reshape extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
val shape = new OnnxGraphMapper().getShape(node);
|
||||
this.shape = shape;
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import onnx.OnnxMlProto3;
|
||||
import onnx.OnnxMl;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||
|
@ -87,7 +87,7 @@ public class Shape extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -93,7 +93,7 @@ public class Stack extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape;
|
|||
|
||||
import com.google.common.primitives.Ints;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
|
@ -156,7 +156,7 @@ public class Transpose extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
if (!attributesForNode.containsKey("perm")) {
|
||||
|
||||
} else
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -127,7 +127,7 @@ public class Unstack extends DynamicCustomOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("No analog found for onnx for " + opName());
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4j.linalg.api.ops.impl.shape.bp;
|
|||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -71,7 +71,7 @@ public class ConcatBp extends DynamicCustomOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
//No op
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
|
@ -59,7 +59,7 @@ public class TensorArrayConcat extends BaseTensorOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
|
@ -59,7 +59,7 @@ public class TensorArrayGather extends BaseTensorOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -54,7 +54,7 @@ public class TensorArrayRead extends BaseTensorOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -52,7 +52,7 @@ public class TensorArrayScatter extends BaseTensorOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -58,7 +58,7 @@ public class TensorArraySize extends BaseTensorOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape.tensorops;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -52,7 +52,7 @@ public class TensorArraySplit extends BaseTensorOp {
|
|||
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -64,7 +64,7 @@ public class ClipByNorm extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.clip;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -77,7 +77,7 @@ public class ClipByValue extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -62,7 +62,7 @@ public class Assign extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -132,7 +132,7 @@ public class CumProd extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
||||
|
@ -133,7 +133,7 @@ public class CumSum extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -80,7 +80,7 @@ public class Fill extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.strict;
|
||||
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -81,7 +81,7 @@ public class RectifiedTanh extends BaseTransformStrictOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.random.impl;
|
||||
|
||||
import lombok.NonNull;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -75,7 +75,7 @@ public class DropOutInverted extends BaseRandomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
|
||||
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
|
||||
super.initFromOnnx(node, initWith, attributesForNode, graph);
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.nd4j.linalg.api.ops.random.impl;
|
||||
|
||||
import lombok.val;
|
||||
import onnx.OnnxProto3;
|
||||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package onnx;
|
||||
import "onnx.proto3";
|
||||
import "onnx.proto";
|
||||
|
||||
//
|
||||
// This file contains the proto definitions for OperatorSetProto and
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion;
|
||||
|
||||
import com.github.os72.protobuf351.util.JsonFormat;
|
||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Rule;
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion;
|
||||
|
||||
import com.github.os72.protobuf351.util.JsonFormat;
|
||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
<packaging>pom</packaging>
|
||||
<modules>
|
||||
<module>jackson</module>
|
||||
<module>protobuf</module>
|
||||
</modules>
|
||||
|
||||
<properties>
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<parent>
|
||||
<artifactId>nd4j-shade</artifactId>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>protobuf</artifactId>
|
||||
|
||||
<properties>
|
||||
<skipTestResourceEnforcement>true</skipTestResourceEnforcement>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java</artifactId>
|
||||
<version>3.8.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
<artifactId>protobuf-java-util</artifactId>
|
||||
<version>3.8.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>custom-lifecycle</id>
|
||||
|
||||
<activation>
|
||||
<property><name>!skip.custom.lifecycle</name></property>
|
||||
</activation>
|
||||
<build>
|
||||
<plugins>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.apache.portals.jetspeed-2</groupId>
|
||||
<artifactId>jetspeed-mvn-maven-plugin</artifactId>
|
||||
<version>2.3.1</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>compile-and-pack</id>
|
||||
<phase>compile</phase>
|
||||
<goals>
|
||||
<goal>mvn</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.apache.maven.shared</groupId>
|
||||
<artifactId>maven-invoker</artifactId>
|
||||
<version>2.2</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<configuration>
|
||||
<targets combine.children="merge">
|
||||
|
||||
<target>
|
||||
<id>create-shaded-jars</id>
|
||||
<dir>@rootdir@/nd4j/nd4j-shade/protobuf/</dir>
|
||||
<goals>clean,compile,package</goals>
|
||||
<properties>
|
||||
<skip.custom.lifecycle>true</skip.custom.lifecycle>
|
||||
</properties>
|
||||
</target>
|
||||
|
||||
</targets>
|
||||
<defaultTarget>create-shaded-jars</defaultTarget>
|
||||
</configuration>
|
||||
</plugin>
|
||||
|
||||
</plugins>
|
||||
</build>
|
||||
</profile>
|
||||
</profiles>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
|
||||
<!-- Disable Maven Lint plugin in this module. For some reason it chokes on this module (internal NPE) and we don't need it anyway here -->
|
||||
<plugin>
|
||||
<groupId>com.lewisd</groupId>
|
||||
<artifactId>lint-maven-plugin</artifactId>
|
||||
<version>0.0.11</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>pom-lint</id>
|
||||
<phase>none</phase>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
|
||||
<!--
|
||||
Use Maven Shade plugin to add a shaded version of the Protobuf dependencies, that can be imported by
|
||||
including this module (org.nd4j.protobuf) as a dependency.
|
||||
The standard com.google.protobuf dependencies will be provided, though are prefixed by org.nd4j.shade.protobuf
|
||||
-->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<version>${maven-shade-plugin.version}</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
|
||||
<resource>reference.conf</resource>
|
||||
</transformer>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
|
||||
</transformer>
|
||||
</transformers>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
|
||||
<configuration>
|
||||
<!--
|
||||
Important configuration options here:
|
||||
createDependencyReducedPom: remove the shaded artifacts from the module dependencies. Without this, the
|
||||
original dependencies will be shaded, AND still included as transitive deps
|
||||
in the final POM. This is not what we want.
|
||||
shadedArtifactAttached: If true, the shaded artifact will be a separate JAR file for install, with
|
||||
the original un-shaded JAR being separate. With this being set to false,
|
||||
the original JAR will be modified, and no extra jar will be produced.
|
||||
promoteTransitiveDependencies: This will promote the transitive dependencies of the shaded dependencies
|
||||
to direct dependencies. Without this, we need to manually manage the transitive
|
||||
dependencies of the shaded artifacts.
|
||||
|
||||
Note that using <optional>true</optional> in the dependencies also allows the deps to be shaded (and
|
||||
original dependencies to not be included), but does NOT work with promoteTransitiveDependencies
|
||||
-->
|
||||
<shadedArtifactAttached>false</shadedArtifactAttached>
|
||||
<createDependencyReducedPom>true</createDependencyReducedPom>
|
||||
<promoteTransitiveDependencies>true</promoteTransitiveDependencies>
|
||||
|
||||
<artifactSet>
|
||||
<includes>
|
||||
<include>com.google.protobuf:*</include>
|
||||
<include>com.google.protobuf.*:*</include>
|
||||
</includes>
|
||||
</artifactSet>
|
||||
|
||||
<relocations>
|
||||
<!-- Protobuf dependencies -->
|
||||
<relocation>
|
||||
<pattern>com.google.protobuf</pattern>
|
||||
<shadedPattern>org.nd4j.shade.protobuf</shadedPattern>
|
||||
</relocation>
|
||||
</relocations>
|
||||
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
<configuration>
|
||||
<forceCreation>true</forceCreation>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>empty-javadoc-jar</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>jar</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<classifier>javadoc</classifier>
|
||||
<classesDirectory>${basedir}/javadoc</classesDirectory>
|
||||
</configuration>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>empty-sources-jar</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>jar</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<classifier>sources</classifier>
|
||||
<classesDirectory>${basedir}/src</classesDirectory>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-dependency-plugin</artifactId>
|
||||
<version>3.0.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>unpack</id>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>unpack</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<artifactItems>
|
||||
<artifactItem>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>protobuf</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<type>jar</type>
|
||||
<overWrite>false</overWrite>
|
||||
<outputDirectory>${project.build.directory}/classes/</outputDirectory>
|
||||
<includes>**/*.class,**/*.xml</includes>
|
||||
</artifactItem>
|
||||
</artifactItems>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
</project>
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion;
|
||||
|
||||
import com.github.os72.protobuf351.InvalidProtocolBufferException;
|
||||
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
|
||||
import org.bytedeco.javacpp.*;
|
||||
import org.bytedeco.javacpp.indexer.*;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
|
|
|
@ -16,9 +16,9 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion.graphrunner;
|
||||
|
||||
import com.github.os72.protobuf351.ByteString;
|
||||
import com.github.os72.protobuf351.InvalidProtocolBufferException;
|
||||
import com.github.os72.protobuf351.util.JsonFormat;
|
||||
import org.nd4j.shade.protobuf.ByteString;
|
||||
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
|
||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -638,7 +638,7 @@ public class GraphRunner implements Closeable {
|
|||
|
||||
/**
|
||||
* Convert a json string written out
|
||||
* by {@link com.github.os72.protobuf351.util.JsonFormat}
|
||||
* by {@link org.nd4j.shade.protobuf.util.JsonFormat}
|
||||
* to a {@link org.bytedeco.tensorflow.ConfigProto}
|
||||
* @param json the json to read
|
||||
* @return the config proto to use
|
||||
|
|
Loading…
Reference in New Issue