Fix data type and roll
parent
04209693f5
commit
53bfdb9994
|
@ -51,9 +51,10 @@ namespace ops {
|
||||||
if (shift < 0) {
|
if (shift < 0) {
|
||||||
shift -= input->sizeAt(i) * (shift / inputLen - 1);
|
shift -= input->sizeAt(i) * (shift / inputLen - 1);
|
||||||
}
|
}
|
||||||
else {
|
else if(shift != 0) {
|
||||||
shift %= input->sizeAt(i);
|
shift %= input->sizeAt(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
shifts[i] = shift;
|
shifts[i] = shift;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +65,7 @@ namespace ops {
|
||||||
// convert shift to positive value between 1 and inputLen - 1
|
// convert shift to positive value between 1 and inputLen - 1
|
||||||
shift -= inputLen * (shift / inputLen - 1);
|
shift -= inputLen * (shift / inputLen - 1);
|
||||||
}
|
}
|
||||||
else
|
else if(shift != 0)
|
||||||
// cut shift to value between 1 and inputLen - 1
|
// cut shift to value between 1 and inputLen - 1
|
||||||
shift %= inputLen;
|
shift %= inputLen;
|
||||||
axes.resize(block.getIArguments()->size() - 1);
|
axes.resize(block.getIArguments()->size() - 1);
|
||||||
|
@ -87,6 +88,21 @@ namespace ops {
|
||||||
if (block.isInplace()) output = input;
|
if (block.isInplace()) output = input;
|
||||||
|
|
||||||
shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1);
|
shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1);
|
||||||
|
nd4j_debug("Roll: Shift is linear %d Shift is %d, first dimension is %d\n",shiftIsLinear,shifts[0],axes[0]);
|
||||||
|
bool shiftsSumZero = false;
|
||||||
|
auto shiftSum = 0;
|
||||||
|
for (auto& s: shifts) {
|
||||||
|
shiftSum += s;
|
||||||
|
nd4j_debug("Roll: Shift is %d\n",s);
|
||||||
|
}
|
||||||
|
//all zeros is no op
|
||||||
|
if(shiftSum < 1) {
|
||||||
|
nd4j_debug("Roll: No shift needed. Shift total was %d\n",shiftSum);
|
||||||
|
if(!block.isInplace()) {
|
||||||
|
output->assign(input);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
if (shiftIsLinear) {
|
if (shiftIsLinear) {
|
||||||
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());
|
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());
|
||||||
|
|
|
@ -73,7 +73,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// stage 3) swap remainer of items.
|
// stage 3) swap remainder of items.
|
||||||
if (remainShift && shiftCount)
|
if (remainShift && shiftCount)
|
||||||
for (int i = actualShift; i < 2 * actualShift; ++i) {
|
for (int i = actualShift; i < 2 * actualShift; ++i) {
|
||||||
auto _e0 = output->e<T>(i);
|
auto _e0 = output->e<T>(i);
|
||||||
|
@ -95,10 +95,11 @@ namespace helpers {
|
||||||
auto source = output; //input;
|
auto source = output; //input;
|
||||||
for (size_t i = 0; i < axes.size(); i++) {
|
for (size_t i = 0; i < axes.size(); i++) {
|
||||||
int axe = axes[i];
|
int axe = axes[i];
|
||||||
if (axe == source->rankOf() - 1) {// last dimension
|
// if (axe == source->rankOf() - 1) {// last dimension
|
||||||
ResultSet listOfTensors = source->allTensorsAlongDimension({axe});
|
ResultSet listOfTensors = source->allTensorsAlongDimension({axe});
|
||||||
ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe});
|
ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe});
|
||||||
int fullLen = listOfTensors.size();
|
int fullLen = listOfTensors.size();
|
||||||
|
nd4j_debug("Roll: fullLen at last dimension is %d\n",fullLen);
|
||||||
int theShift = shifts[i];
|
int theShift = shifts[i];
|
||||||
if (theShift > 0) {
|
if (theShift > 0) {
|
||||||
theShift %= fullLen;
|
theShift %= fullLen;
|
||||||
|
@ -109,7 +110,7 @@ namespace helpers {
|
||||||
for (int k = 0; k < fullLen; k++) {
|
for (int k = 0; k < fullLen; k++) {
|
||||||
rollFunctorLinear(context, listOfTensors.at(k), listOfOutTensors.at(k), theShift, true);
|
rollFunctorLinear(context, listOfTensors.at(k), listOfOutTensors.at(k), theShift, true);
|
||||||
}
|
}
|
||||||
}
|
/* }
|
||||||
else {
|
else {
|
||||||
std::vector<int> dims(source->rankOf() - axe - 1);
|
std::vector<int> dims(source->rankOf() - axe - 1);
|
||||||
for (size_t i = 0; i < dims.size(); ++i)
|
for (size_t i = 0; i < dims.size(); ++i)
|
||||||
|
@ -120,6 +121,7 @@ namespace helpers {
|
||||||
//
|
//
|
||||||
int fullLen = listOfTensors.size();
|
int fullLen = listOfTensors.size();
|
||||||
int sizeAt = input->sizeAt(axe);
|
int sizeAt = input->sizeAt(axe);
|
||||||
|
nd4j_debug("Roll: fullLen at dimension %d is %d\n",i,fullLen);
|
||||||
|
|
||||||
int theShift = shifts[i];
|
int theShift = shifts[i];
|
||||||
|
|
||||||
|
@ -131,14 +133,14 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (theShift) {
|
if (theShift) {
|
||||||
for (int dim = 0; dim < fullLen / sizeAt; ++dim) {
|
for (size_t dim = 0; dim < fullLen / sizeAt; ++dim) {
|
||||||
for (int e = theShift; e < sizeAt - theShift; ++e) {
|
for (size_t e = theShift; e < sizeAt - theShift; ++e) {
|
||||||
auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift);
|
auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift);
|
||||||
auto targetM = listOfOutTensors.at(dim * sizeAt + e);
|
auto targetM = listOfOutTensors.at(dim * sizeAt + e);
|
||||||
sourceM->swapUnsafe(*targetM);
|
sourceM->swapUnsafe(*targetM);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int e = 0; e < theShift; ++e) {
|
for (size_t e = 0; e < theShift; ++e) {
|
||||||
int sourceIndex = dim * sizeAt + sizeAt - theShift + e;
|
int sourceIndex = dim * sizeAt + sizeAt - theShift + e;
|
||||||
auto sourceM = listOfTensors.at(sourceIndex);
|
auto sourceM = listOfTensors.at(sourceIndex);
|
||||||
auto targetM = listOfOutTensors.at(dim * sizeAt + e);
|
auto targetM = listOfOutTensors.at(dim * sizeAt + e);
|
||||||
|
@ -149,7 +151,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if (!inplace)
|
// if (!inplace)
|
||||||
// source = output;
|
// source = output;*/
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,10 +34,10 @@ public class Roll extends DynamicCustomOp {
|
||||||
|
|
||||||
public Roll() {}
|
public Roll() {}
|
||||||
|
|
||||||
public Roll(@NonNull INDArray input, @NonNull INDArray axes, @NonNull INDArray shifts) {
|
public Roll(@NonNull INDArray input, @NonNull INDArray shifts, @NonNull INDArray axes) {
|
||||||
Preconditions.checkArgument(axes.rank() == shifts.rank(), "Roll: shifts and axes should be the same rank");
|
Preconditions.checkArgument(axes.rank() == shifts.rank(), "Roll: shifts and axes should be the same rank");
|
||||||
Preconditions.checkArgument(axes.length() == shifts.length(), "Roll: shifts and axes should be the same length");
|
Preconditions.checkArgument(axes.length() == shifts.length(), "Roll: shifts and axes should be the same length");
|
||||||
addInputArgument(input, axes, shifts);
|
addInputArgument(input, shifts, axes);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Roll(@NonNull INDArray input, int shift) {
|
public Roll(@NonNull INDArray input, int shift) {
|
||||||
|
@ -49,8 +49,8 @@ public class Roll extends DynamicCustomOp {
|
||||||
super("", sameDiff, new SDVariable[]{input,shift});
|
super("", sameDiff, new SDVariable[]{input,shift});
|
||||||
}
|
}
|
||||||
|
|
||||||
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable axes, @NonNull SDVariable shift) {
|
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shift, @NonNull SDVariable axes) {
|
||||||
super("", sameDiff, new SDVariable[]{input,axes,shift});
|
super("", sameDiff, new SDVariable[]{input,shift,axes});
|
||||||
}
|
}
|
||||||
|
|
||||||
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) {
|
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) {
|
||||||
|
|
|
@ -284,7 +284,10 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
|
||||||
List<DataType> result = new ArrayList<>();
|
List<DataType> result = new ArrayList<>();
|
||||||
result.add(inputDataTypes.get(0));
|
result.add(inputDataTypes.get(0));
|
||||||
result.add(outputType == null ? DataType.INT : outputType);
|
if(dArguments.isEmpty())
|
||||||
|
result.add(outputType == null ? DataType.INT : outputType);
|
||||||
|
else
|
||||||
|
result.add(dArguments.get(0));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
-->
|
-->
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
@ -36,6 +36,17 @@
|
||||||
<name>nd4j-tests</name>
|
<name>nd4j-tests</name>
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
|
<kotlin.version>1.4.30-M1</kotlin.version>
|
||||||
|
<kotlin.compiler.jvmTarget>1.8</kotlin.compiler.jvmTarget>
|
||||||
|
<kotlin.compiler.incremental>true</kotlin.compiler.incremental>
|
||||||
|
<junit.version>4.12</junit.version>
|
||||||
|
<junit-jupiter.version>5.4.2</junit-jupiter.version>
|
||||||
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
|
<maven.compiler.source>1.8</maven.compiler.source>
|
||||||
|
<maven.compiler.target>1.8</maven.compiler.target>
|
||||||
|
<java.version>1.8</java.version>
|
||||||
|
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
|
||||||
|
|
||||||
<maven.compiler.source>1.8</maven.compiler.source>
|
<maven.compiler.source>1.8</maven.compiler.source>
|
||||||
<maven.compiler.target>1.8</maven.compiler.target>
|
<maven.compiler.target>1.8</maven.compiler.target>
|
||||||
<scala.binary.version>2.11</scala.binary.version>
|
<scala.binary.version>2.11</scala.binary.version>
|
||||||
|
@ -43,7 +54,213 @@
|
||||||
<maven.compiler.testSource>1.8</maven.compiler.testSource>
|
<maven.compiler.testSource>1.8</maven.compiler.testSource>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.projectlombok</groupId>
|
||||||
|
<artifactId>lombok-maven-plugin</artifactId>
|
||||||
|
<version>1.18.12.0</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>delombok</id>
|
||||||
|
<phase>generate-sources</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>delombok</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<formatPreferences>
|
||||||
|
<javaLangAsFQN>skip</javaLangAsFQN>
|
||||||
|
</formatPreferences>
|
||||||
|
<verbose>true</verbose>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
<execution>
|
||||||
|
<id>test-delombok</id>
|
||||||
|
<phase>generate-test-sources</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>testDelombok</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<verbose>true</verbose>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.codehaus.mojo</groupId>
|
||||||
|
<artifactId>build-helper-maven-plugin</artifactId>
|
||||||
|
<version>3.0.0</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>add-source</id>
|
||||||
|
<phase>generate-sources</phase>
|
||||||
|
<goals><goal>add-source</goal></goals>
|
||||||
|
<configuration>
|
||||||
|
<sources>
|
||||||
|
<source>src/main/stubs</source>
|
||||||
|
</sources>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-shade-plugin</artifactId>
|
||||||
|
<version>${maven-shade-plugin.version}</version>
|
||||||
|
<configuration>
|
||||||
|
<shadedArtifactAttached>true</shadedArtifactAttached>
|
||||||
|
<createDependencyReducedPom>false</createDependencyReducedPom>
|
||||||
|
<filters>
|
||||||
|
<filter>
|
||||||
|
<artifact>*:*</artifact>
|
||||||
|
<excludes>
|
||||||
|
<exclude>org/datanucleus/**</exclude>
|
||||||
|
<exclude>META-INF/*.SF</exclude>
|
||||||
|
<exclude>META-INF/*.DSA</exclude>
|
||||||
|
<exclude>META-INF/*.RSA</exclude>
|
||||||
|
</excludes>
|
||||||
|
</filter>
|
||||||
|
</filters>
|
||||||
|
|
||||||
|
</configuration>
|
||||||
|
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<phase>package</phase>
|
||||||
|
<goals>
|
||||||
|
<goal>shade</goal>
|
||||||
|
</goals>
|
||||||
|
<configuration>
|
||||||
|
<transformers>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
|
||||||
|
<resource>reference.conf</resource>
|
||||||
|
</transformer>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
|
||||||
|
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer" />
|
||||||
|
</transformers>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-maven-plugin</artifactId>
|
||||||
|
<version>1.4.30-M1</version>
|
||||||
|
<configuration>
|
||||||
|
<args>
|
||||||
|
<arg>-Xjsr305=strict</arg>
|
||||||
|
</args>
|
||||||
|
<compilerPlugins>
|
||||||
|
<plugin>spring</plugin>
|
||||||
|
<plugin>jpa</plugin>
|
||||||
|
</compilerPlugins>
|
||||||
|
</configuration>
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-maven-allopen</artifactId>
|
||||||
|
<version>${kotlin.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-maven-noarg</artifactId>
|
||||||
|
<version>${kotlin.version}</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>compile</id>
|
||||||
|
<goals> <goal>compile</goal> </goals>
|
||||||
|
<configuration>
|
||||||
|
<sourceDirs>
|
||||||
|
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/main/java</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
|
||||||
|
</sourceDirs>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
<execution>
|
||||||
|
<id>test-compile</id>
|
||||||
|
<goals> <goal>test-compile</goal> </goals>
|
||||||
|
<configuration>
|
||||||
|
<sourceDirs>
|
||||||
|
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/test/java</sourceDir>
|
||||||
|
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
|
||||||
|
</sourceDirs>
|
||||||
|
</configuration>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
|
|
||||||
|
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-compiler-plugin</artifactId>
|
||||||
|
<version>3.5.1</version>
|
||||||
|
<executions>
|
||||||
|
<!-- Replacing default-compile as it is treated specially by maven -->
|
||||||
|
<execution>
|
||||||
|
<id>default-compile</id>
|
||||||
|
<phase>none</phase>
|
||||||
|
</execution>
|
||||||
|
<!-- Replacing default-testCompile as it is treated specially by maven -->
|
||||||
|
<execution>
|
||||||
|
<id>default-testCompile</id>
|
||||||
|
<phase>none</phase>
|
||||||
|
</execution>
|
||||||
|
<execution>
|
||||||
|
<id>java-compile</id>
|
||||||
|
<phase>compile</phase>
|
||||||
|
<goals> <goal>compile</goal> </goals>
|
||||||
|
</execution>
|
||||||
|
<execution>
|
||||||
|
<id>java-test-compile</id>
|
||||||
|
<phase>test-compile</phase>
|
||||||
|
<goals> <goal>testCompile</goal> </goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
<configuration>
|
||||||
|
<source>${java.version}</source>
|
||||||
|
<target>${java.version}</target>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<!-- Test Dependencies -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-api</artifactId>
|
||||||
|
<version>${junit-jupiter.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
<artifactId>junit-jupiter-engine</artifactId>
|
||||||
|
<version>${junit-jupiter.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-stdlib-jdk8</artifactId>
|
||||||
|
<version>${kotlin.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.jetbrains.kotlin</groupId>
|
||||||
|
<artifactId>kotlin-test</artifactId>
|
||||||
|
<version>${kotlin.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>samediff-import-tensorflow</artifactId>
|
<artifactId>samediff-import-tensorflow</artifactId>
|
||||||
|
|
|
@ -68,13 +68,22 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
* all tests will trigger an assumeFalse(..) that indicates
|
* all tests will trigger an assumeFalse(..) that indicates
|
||||||
* the status of the test failing. No tests will run.
|
* the status of the test failing. No tests will run.
|
||||||
*/
|
*/
|
||||||
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList();
|
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
|
||||||
|
// "max_pool_with_argmax/int32_int64_padding_SAME",
|
||||||
|
// "fused_batch_norm/float32_nhwc",
|
||||||
|
// "max_pool_with_argmax/int64_int64_padding_SAME",
|
||||||
|
// "fused_batch_norm/float16_nhwc",
|
||||||
|
"roll/rank3_int32_axis",
|
||||||
|
"roll/rank3_int32_axis",
|
||||||
|
"roll/rank2_float32_zeroshift",
|
||||||
|
"roll/rank3_float64_axis"
|
||||||
|
);
|
||||||
|
|
||||||
public static final String[] IGNORE_REGEXES = new String[]{
|
public static final String[] IGNORE_REGEXES = new String[]{
|
||||||
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
||||||
// Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance
|
// Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance
|
||||||
"bincount/.*",
|
//Invalid test cases. Verified by running graph against actual TF.
|
||||||
|
"slogdet/.*",
|
||||||
|
|
||||||
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
|
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
|
||||||
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod
|
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod
|
||||||
|
|
|
@ -881,7 +881,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out);
|
BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out);
|
||||||
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
|
||||||
assertEquals(1, lsd.size());
|
assertEquals(1, lsd.size());
|
||||||
assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape());
|
assertArrayEquals(new long[]{1,10,2}, lsd.get(0).getShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -942,21 +942,21 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@Ignore("Failing with results that are close")
|
||||||
public void testFakeQuantAgainstTF_1() {
|
public void testFakeQuantAgainstTF_1() {
|
||||||
INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
|
INDArray x = Nd4j.createFromArray(new double[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
|
||||||
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
||||||
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
|
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
|
||||||
INDArray min = Nd4j.createFromArray(new float[]{ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
|
INDArray min = Nd4j.createFromArray(new double[]{ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
|
||||||
INDArray max = Nd4j.createFromArray(new float[]{ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
|
INDArray max = Nd4j.createFromArray(new double[]{ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
|
||||||
|
|
||||||
INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
|
INDArray expected = Nd4j.createFromArray(new double[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
|
||||||
0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
|
0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
|
||||||
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
|
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
|
||||||
|
|
||||||
INDArray out = Nd4j.createUninitialized(x.shape());
|
|
||||||
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max);
|
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max);
|
||||||
Nd4j.exec(op);
|
INDArray[] output = Nd4j.exec(op);
|
||||||
assertEquals(expected, out);
|
assertEquals(expected, output[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -971,8 +971,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testResizeBilinear1() {
|
public void testResizeBilinear1() {
|
||||||
|
INDArray x = Nd4j.rand(1, 10,10,4);
|
||||||
INDArray x = Nd4j.rand(1, 2,3,4);
|
|
||||||
INDArray z = Nd4j.createUninitialized(x.shape());
|
INDArray z = Nd4j.createUninitialized(x.shape());
|
||||||
boolean align = false;
|
boolean align = false;
|
||||||
val op = new ResizeBilinear(x, z, 10, 10, align, false);
|
val op = new ResizeBilinear(x, z, 10, 10, align, false);
|
||||||
|
@ -1082,24 +1081,8 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
INDArray distance = Nd4j.scalar(0.f);
|
INDArray distance = Nd4j.scalar(0.f);
|
||||||
|
|
||||||
Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance));
|
Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance));
|
||||||
// System.out.println(distance);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Ignore("2019/11/15 AS - https://github.com/eclipse/deeplearning4j/issues/8399")
|
|
||||||
@Test
|
|
||||||
public void testCropAndResize() {
|
|
||||||
INDArray image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1);
|
|
||||||
INDArray boxes = Nd4j.createFromArray(new float[]{1,2,3,4}).reshape(1,4);
|
|
||||||
INDArray box_indices = Nd4j.createFromArray(new int[]{1});
|
|
||||||
INDArray crop_size = Nd4j.createFromArray(new int[]{1,2}).reshape(1,2);
|
|
||||||
|
|
||||||
//Output shape mismatch - TF [2, 2, 1, 1] vs SD: [1, 2, 1, 1]
|
|
||||||
INDArray output = Nd4j.create(DataType.FLOAT, 2,2,1,1);
|
|
||||||
|
|
||||||
|
|
||||||
Nd4j.exec(new CropAndResize(image, boxes, box_indices, crop_size, CropAndResize.Method.BILINEAR, 0.5,
|
|
||||||
output));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLayersDropoutFail() {
|
public void testLayersDropoutFail() {
|
||||||
|
@ -1338,7 +1321,6 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(expected, ret[0]);
|
assertEquals(expected, ret[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8453")
|
|
||||||
@Test
|
@Test
|
||||||
public void testRoll1() {
|
public void testRoll1() {
|
||||||
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f});
|
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f});
|
||||||
|
@ -1346,6 +1328,10 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
INDArray[] ret = Nd4j.exec(op);
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
INDArray expected = Nd4j.createFromArray(new float[]{0.7244f, 0.2309f, 0.7788f, 0.8012f});
|
INDArray expected = Nd4j.createFromArray(new float[]{0.7244f, 0.2309f, 0.7788f, 0.8012f});
|
||||||
assertEquals(expected, ret[0]);
|
assertEquals(expected, ret[0]);
|
||||||
|
INDArray matrix = Nd4j.create(new double[]{0.7788,0.8012,0.7244,0.2309,0.7271,0.1804,0.5056,0.8925}).reshape(2,4);
|
||||||
|
Roll roll2 = new Roll(matrix,Nd4j.scalar(0),Nd4j.scalar(1));
|
||||||
|
INDArray[] outputs = Nd4j.exec(roll2);
|
||||||
|
System.out.println(outputs[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.custom
|
||||||
|
|
||||||
|
import junit.framework.Assert.assertEquals
|
||||||
|
import org.junit.Ignore
|
||||||
|
import org.junit.Test
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType
|
||||||
|
import org.nd4j.linalg.api.ops.impl.image.CropAndResize
|
||||||
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
import org.nd4j.samediff.frameworkimport.tensorflow.*
|
||||||
|
import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter
|
||||||
|
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
|
||||||
|
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
|
||||||
|
|
||||||
|
class CustomOpTensorflowInteropTests {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Ignore("Tensorflow expects different shape")
|
||||||
|
fun testCropAndResize() {
|
||||||
|
val image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1)
|
||||||
|
val boxes = Nd4j.createFromArray(*floatArrayOf(1f, 2f, 3f, 4f)).reshape(1, 4)
|
||||||
|
val box_indices = Nd4j.createFromArray(*intArrayOf(0))
|
||||||
|
val crop_size = Nd4j.createFromArray(*intArrayOf(1, 2)).reshape( 2)
|
||||||
|
val imageNode = NodeDef {
|
||||||
|
op = "Placeholder"
|
||||||
|
name = "image"
|
||||||
|
Attribute("dtype", AttrValue {
|
||||||
|
type = org.tensorflow.framework.DataType.DT_FLOAT
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
val boxesNode = NodeDef {
|
||||||
|
op = "Placeholder"
|
||||||
|
name = "boxes"
|
||||||
|
Attribute("dtype", AttrValue {
|
||||||
|
type = org.tensorflow.framework.DataType.DT_FLOAT
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
val boxIndicesNode = NodeDef {
|
||||||
|
op = "Placeholder"
|
||||||
|
name = "boxIndices"
|
||||||
|
Attribute("dtype", AttrValue {
|
||||||
|
type = org.tensorflow.framework.DataType.DT_INT32
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
val cropSizesNode = NodeDef {
|
||||||
|
op = "Placeholder"
|
||||||
|
name = "cropSize"
|
||||||
|
Attribute("dtype", AttrValue {
|
||||||
|
type = org.tensorflow.framework.DataType.DT_INT32
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
val opNode = NodeDef {
|
||||||
|
op = "CropAndResize"
|
||||||
|
name = "output"
|
||||||
|
Input("image")
|
||||||
|
Input("boxes")
|
||||||
|
Input("boxIndices")
|
||||||
|
Input("cropSize")
|
||||||
|
Attribute("extrapolation_value", AttrValue {
|
||||||
|
f = 0.5f
|
||||||
|
})
|
||||||
|
Attribute("T", AttrValue {
|
||||||
|
type = org.tensorflow.framework.DataType.DT_FLOAT
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
val graph = GraphDef {
|
||||||
|
Node(imageNode)
|
||||||
|
Node(boxesNode)
|
||||||
|
Node(boxIndicesNode)
|
||||||
|
Node(cropSizesNode)
|
||||||
|
Node(opNode)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
val importer = TensorflowFrameworkImporter()
|
||||||
|
val irGraph = TensorflowIRGraph(graph,importer.opDefList,importer.registry)
|
||||||
|
val runner = TensorflowIRGraphRunner(irGraph,listOf("image","boxes","boxIndices","cropSize"),listOf("output"))
|
||||||
|
val tfResult = runner.run(mapOf("image" to image,"boxes" to boxes,"boxIndices" to box_indices,"cropSize" to crop_size))
|
||||||
|
val outputArr = tfResult["output"]
|
||||||
|
//Output shape mismatch - TF [2, 2, 1, 1] vs SD: [1, 2, 1, 1]
|
||||||
|
val output = Nd4j.create(DataType.FLOAT, 2, 2, 1, 1)
|
||||||
|
Nd4j.exec(
|
||||||
|
CropAndResize(
|
||||||
|
image, boxes, box_indices, crop_size, CropAndResize.Method.BILINEAR, 0.5,
|
||||||
|
output
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assertEquals(outputArr,output)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,2 @@
|
||||||
Variable/read,Variable/read
|
in_0/read,in_0/read
|
||||||
Variable_1/read,Variable_1/read
|
Roll,Roll
|
||||||
floordiv/x,floordiv/x
|
|
||||||
floordiv/y,floordiv/y
|
|
||||||
floordiv,floordiv
|
|
||||||
|
|
|
@ -1,18 +1 @@
|
||||||
in_0/read,in_0/read
|
Sum,Sum
|
||||||
while/Enter,while/Enter
|
|
||||||
while/Enter_1,while/Enter_1
|
|
||||||
while/Merge,while/Merge
|
|
||||||
while/Merge_1,while/Merge_1
|
|
||||||
while/Less,while/Less
|
|
||||||
while/LoopCond,while/LoopCond
|
|
||||||
while/Switch,while/Switch
|
|
||||||
while/Switch:1,while/Switch
|
|
||||||
while/Switch_1,while/Switch_1
|
|
||||||
while/Switch_1:1,while/Switch_1
|
|
||||||
while/Identity,while/Identity
|
|
||||||
while/Exit,while/Exit
|
|
||||||
while/Identity_1,while/Identity_1
|
|
||||||
while/Exit_1,while/Exit_1
|
|
||||||
while/add,while/add
|
|
||||||
while/NextIteration_1,while/NextIteration_1
|
|
||||||
while/NextIteration,while/NextIteration
|
|
||||||
|
|
|
@ -207,12 +207,6 @@ open class ImportGraph <GRAPH_TYPE: GeneratedMessageV3,
|
||||||
OpMappingRegistry<GRAPH_TYPE, NODE_TYPE, OP_DEF_TYPE, TENSOR_TYPE,
|
OpMappingRegistry<GRAPH_TYPE, NODE_TYPE, OP_DEF_TYPE, TENSOR_TYPE,
|
||||||
DATA_TYPE, ATTR_DEF_TYPE, ATTR_VALUE_TYPE>): SameDiff {
|
DATA_TYPE, ATTR_DEF_TYPE, ATTR_VALUE_TYPE>): SameDiff {
|
||||||
|
|
||||||
/*
|
|
||||||
First, build an in-memory representation of the graph that allows us to build the graph incrementally
|
|
||||||
If we can build the graph incrementally, we can make sure that the added variables are set up with the correct
|
|
||||||
datatype and (once implemented) greedy shape inference
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
First, build an in-memory representation of the graph that allows us to build the graph incrementally
|
First, build an in-memory representation of the graph that allows us to build the graph incrementally
|
||||||
|
@ -442,19 +436,34 @@ open class ImportGraph <GRAPH_TYPE: GeneratedMessageV3,
|
||||||
val op = SameDiffOp.builder()
|
val op = SameDiffOp.builder()
|
||||||
.name(name)
|
.name(name)
|
||||||
.op(df)
|
.op(df)
|
||||||
.inputsToOp(inNames) //.outputsOfOp(outNames) //We'll set this later
|
|
||||||
.controlDeps(controlDeps)
|
.controlDeps(controlDeps)
|
||||||
.build()
|
.build()
|
||||||
|
//take only up to the inputs that are specified in the node/
|
||||||
|
//this is for cases where node inputs is > intended number for ops
|
||||||
|
//a common example is when ops convert input ndarrays to integers or float inputs
|
||||||
|
val numInputsToTake = importInfo[name]!!.second.argDescriptorList.filter { input -> input.argType == OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR }
|
||||||
|
.size
|
||||||
|
op.inputsToOp = inNames.subList(0,numInputsToTake)
|
||||||
|
|
||||||
//add nodes/other pre processing in order for this node to work
|
//add nodes/other pre processing in order for this node to work
|
||||||
var addToGraph = true
|
|
||||||
sd.ops[name] = op
|
sd.ops[name] = op
|
||||||
defaultRunner.initAttributes(df, sd, importInfo[name]!!)
|
//clear out inputs for variables as well to reflect the actual graph structure
|
||||||
|
if(numInputsToTake < numInputs) {
|
||||||
|
for(i in numInputsToTake until numInputs) {
|
||||||
|
if(sd.hasVariable(nd.inputAt(i))) {
|
||||||
|
val currInputVar = sd.variables[nd.inputAt(i)]!!
|
||||||
|
currInputVar.inputsForOp.remove(op.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//cache attributes just in case we have any rules so we don't create the rules more than once
|
//cache attributes just in case we have any rules so we don't create the rules more than once
|
||||||
val attributes = mappingContext.nodeAttributesAsMap()
|
val attributes = mappingContext.nodeAttributesAsMap()
|
||||||
mappingContext.relevantPrehookRules().forEach { rule ->
|
mappingContext.relevantPrehookRules().forEach { rule ->
|
||||||
rule.preProcess(op, sd,attributes)
|
rule.preProcess(op, sd,attributes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defaultRunner.initAttributes(df, sd, importInfo[name]!!)
|
||||||
|
|
||||||
|
|
||||||
//add nodes/other post processing in order for this node to work
|
//add nodes/other post processing in order for this node to work
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -33,6 +33,8 @@ class OnnxIRDataType(inputDataType: Onnx.TensorProto.DataType): IRDataType<Onnx.
|
||||||
Onnx.TensorProto.DataType.UINT64 -> return IRDataTypeValue.DT_UINT64
|
Onnx.TensorProto.DataType.UINT64 -> return IRDataTypeValue.DT_UINT64
|
||||||
Onnx.TensorProto.DataType.UINT32 -> return IRDataTypeValue.DT_UINT32
|
Onnx.TensorProto.DataType.UINT32 -> return IRDataTypeValue.DT_UINT32
|
||||||
Onnx.TensorProto.DataType.UINT16 -> return IRDataTypeValue.DT_UINT16
|
Onnx.TensorProto.DataType.UINT16 -> return IRDataTypeValue.DT_UINT16
|
||||||
|
Onnx.TensorProto.DataType.INT8 -> return IRDataTypeValue.DT_INT8
|
||||||
|
Onnx.TensorProto.DataType.UINT8 -> return IRDataTypeValue.DT_UINT8
|
||||||
Onnx.TensorProto.DataType.FLOAT16 -> return IRDataTypeValue.DT_HALF
|
Onnx.TensorProto.DataType.FLOAT16 -> return IRDataTypeValue.DT_HALF
|
||||||
Onnx.TensorProto.DataType.STRING -> return IRDataTypeValue.DT_STRING
|
Onnx.TensorProto.DataType.STRING -> return IRDataTypeValue.DT_STRING
|
||||||
Onnx.TensorProto.DataType.FLOAT -> return IRDataTypeValue.DT_FLOAT
|
Onnx.TensorProto.DataType.FLOAT -> return IRDataTypeValue.DT_FLOAT
|
||||||
|
@ -60,9 +62,11 @@ class OnnxIRDataType(inputDataType: Onnx.TensorProto.DataType): IRDataType<Onnx.
|
||||||
|
|
||||||
override fun nd4jDataType(): DataType {
|
override fun nd4jDataType(): DataType {
|
||||||
when(this.dataType) {
|
when(this.dataType) {
|
||||||
Onnx.TensorProto.DataType.UINT64 -> return return DataType.INT64
|
Onnx.TensorProto.DataType.INT8 -> return DataType.INT8
|
||||||
Onnx.TensorProto.DataType.UINT32 -> return return DataType.INT32
|
Onnx.TensorProto.DataType.UINT8 -> return DataType.UINT8
|
||||||
Onnx.TensorProto.DataType.UINT16 -> return return DataType.INT16
|
Onnx.TensorProto.DataType.UINT64 -> return DataType.UINT64
|
||||||
|
Onnx.TensorProto.DataType.UINT32 -> return DataType.INT32
|
||||||
|
Onnx.TensorProto.DataType.UINT16 -> return DataType.INT16
|
||||||
Onnx.TensorProto.DataType.FLOAT16 -> return return DataType.FLOAT16
|
Onnx.TensorProto.DataType.FLOAT16 -> return return DataType.FLOAT16
|
||||||
Onnx.TensorProto.DataType.STRING -> return return DataType.UTF8
|
Onnx.TensorProto.DataType.STRING -> return return DataType.UTF8
|
||||||
Onnx.TensorProto.DataType.FLOAT -> return return DataType.FLOAT
|
Onnx.TensorProto.DataType.FLOAT -> return return DataType.FLOAT
|
||||||
|
@ -81,8 +85,8 @@ class OnnxIRDataType(inputDataType: Onnx.TensorProto.DataType): IRDataType<Onnx.
|
||||||
override fun nameSpaceDataType(): TensorNamespace.DataType {
|
override fun nameSpaceDataType(): TensorNamespace.DataType {
|
||||||
when(this.dataType) {
|
when(this.dataType) {
|
||||||
Onnx.TensorProto.DataType.UINT64 -> return return TensorNamespace.DataType.INT64
|
Onnx.TensorProto.DataType.UINT64 -> return return TensorNamespace.DataType.INT64
|
||||||
Onnx.TensorProto.DataType.UINT32 -> return return TensorNamespace.DataType.INT32
|
Onnx.TensorProto.DataType.UINT32 -> return TensorNamespace.DataType.INT32
|
||||||
Onnx.TensorProto.DataType.UINT16 -> return return TensorNamespace.DataType.INT16
|
Onnx.TensorProto.DataType.UINT16 -> return TensorNamespace.DataType.INT16
|
||||||
Onnx.TensorProto.DataType.FLOAT16 -> return return TensorNamespace.DataType.FLOAT16
|
Onnx.TensorProto.DataType.FLOAT16 -> return return TensorNamespace.DataType.FLOAT16
|
||||||
Onnx.TensorProto.DataType.STRING -> return return TensorNamespace.DataType.STRING
|
Onnx.TensorProto.DataType.STRING -> return return TensorNamespace.DataType.STRING
|
||||||
Onnx.TensorProto.DataType.FLOAT -> return TensorNamespace.DataType.FLOAT
|
Onnx.TensorProto.DataType.FLOAT -> return TensorNamespace.DataType.FLOAT
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
Placeholder,data
|
Const,in_0
|
||||||
Placeholder,partitions
|
Const,Roll/shift
|
||||||
DynamicPartition,output
|
Const,Roll/axis
|
||||||
Identity,out0
|
Identity,in_0/read
|
||||||
Identity,out1
|
Roll,Roll
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
Const,in_0
|
Const,in_0
|
||||||
Const,eye/ones
|
Const,Roll/shift
|
||||||
|
Const,Roll/axis
|
||||||
Identity,in_0/read
|
Identity,in_0/read
|
||||||
MatrixDiag,eye/diag
|
Roll,Roll
|
||||||
Add,Add
|
|
||||||
Svd,Svd
|
|
||||||
Abs,Abs
|
|
||||||
|
|
|
@ -1,3 +1,2 @@
|
||||||
DynamicPartition,output
|
Identity,in_0/read
|
||||||
Identity,out0
|
Roll,Roll
|
||||||
Identity,out1
|
|
||||||
|
|
|
@ -1,5 +1,2 @@
|
||||||
Identity,in_0/read
|
Identity,in_0/read
|
||||||
MatrixDiag,eye/diag
|
Roll,Roll
|
||||||
Add,Add
|
|
||||||
Svd,Svd
|
|
||||||
Abs,Abs
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
data
|
in_0
|
||||||
partitions
|
Roll/shift
|
||||||
output
|
Roll/axis
|
||||||
out0
|
in_0/read
|
||||||
out1
|
Roll
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
in_0
|
in_0
|
||||||
eye/ones
|
Roll/shift
|
||||||
|
Roll/axis
|
||||||
in_0/read
|
in_0/read
|
||||||
eye/diag
|
Roll
|
||||||
Add
|
|
||||||
Svd
|
|
||||||
Abs
|
|
||||||
|
|
|
@ -358,7 +358,7 @@ val bitCast = TensorflowMappingProcess(
|
||||||
opMappingRegistry = tensorflowOpRegistry,
|
opMappingRegistry = tensorflowOpRegistry,
|
||||||
inputFrameworkOpName = "Bitcast",
|
inputFrameworkOpName = "Bitcast",
|
||||||
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input"))),
|
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input"))),
|
||||||
attributeMappingRules = listOf(dataTypeToInt(mutableMapOf("newType" to "type")), valueMapping(mutableMapOf("dataType" to "type")))
|
attributeMappingRules = listOf(dataTypeToInt(mutableMapOf("newType" to "type")), valueMapping(mutableMapOf("dtype" to "type")))
|
||||||
)
|
)
|
||||||
|
|
||||||
val bitwiseAnd = TensorflowMappingProcess(
|
val bitwiseAnd = TensorflowMappingProcess(
|
||||||
|
@ -1070,7 +1070,7 @@ val fill = TensorflowMappingProcess(
|
||||||
inputFrameworkOpName = "Fill",
|
inputFrameworkOpName = "Fill",
|
||||||
opMappingRegistry = tensorflowOpRegistry,
|
opMappingRegistry = tensorflowOpRegistry,
|
||||||
attributeMappingRules = listOf(convertNDArrayInputToNumericalAttr(mutableMapOf("value" to "value")),
|
attributeMappingRules = listOf(convertNDArrayInputToNumericalAttr(mutableMapOf("value" to "value")),
|
||||||
dataTypeToInt(mutableMapOf("dtype" to "T")),
|
dataTypeToInt(mutableMapOf("outputDataType" to "T")),
|
||||||
valueMapping(mutableMapOf("dtype" to "T"))),
|
valueMapping(mutableMapOf("dtype" to "T"))),
|
||||||
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("shapeArray" to "dims")))
|
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("shapeArray" to "dims")))
|
||||||
)
|
)
|
||||||
|
@ -1381,6 +1381,7 @@ val maxPoolArgmax = multipleNameMapping(
|
||||||
intConstant(inputName = "extraParam0",constantValue = 0 ,argumentIndex = 9)[0],
|
intConstant(inputName = "extraParam0",constantValue = 0 ,argumentIndex = 9)[0],
|
||||||
intConstant(inputName = "isNHWC",argumentIndex = 10,constantValue = 1 )[0],
|
intConstant(inputName = "isNHWC",argumentIndex = 10,constantValue = 1 )[0],
|
||||||
intConstant(inputName = "sameMode",argumentIndex = 8,constantValue = 8 )[0],
|
intConstant(inputName = "sameMode",argumentIndex = 8,constantValue = 8 )[0],
|
||||||
|
valueMapping(mutableMapOf("dtype" to "T"))
|
||||||
)
|
)
|
||||||
,tensorflowOpRegistry = tensorflowOpRegistry
|
,tensorflowOpRegistry = tensorflowOpRegistry
|
||||||
)
|
)
|
||||||
|
@ -1455,6 +1456,13 @@ val mirrorPadding = mapTensorNamesWithOp(inputFrameworkOpName = "MirrorPad",opNa
|
||||||
* v1 and 2 which do not have a scoreThreshold to map. V3 does.
|
* v1 and 2 which do not have a scoreThreshold to map. V3 does.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
val matrixBandPart = mapTensorNamesWithOp(inputFrameworkOpName = "MatrixBandPart",opName = "matrix_band_part",
|
||||||
|
tensorNames = mutableMapOf("input" to "input","minLowerT" to "num_lower",
|
||||||
|
"maxUpperT" to "num_upper"),
|
||||||
|
attributeMappingRules = listOf()
|
||||||
|
,tensorflowOpRegistry = tensorflowOpRegistry)
|
||||||
|
|
||||||
|
|
||||||
val nonMaxSuppressionV1 = multipleNameMapping(inputFrameworkOpNames = listOf("NonMaxSuppression"),
|
val nonMaxSuppressionV1 = multipleNameMapping(inputFrameworkOpNames = listOf("NonMaxSuppression"),
|
||||||
opName = "non_max_suppression",
|
opName = "non_max_suppression",
|
||||||
tensorNames = mutableMapOf("boxes" to "boxes","scales" to "scores",
|
tensorNames = mutableMapOf("boxes" to "boxes","scales" to "scores",
|
||||||
|
@ -1654,7 +1662,8 @@ val randomCrop = mapTensorNamesWithOp(inputFrameworkOpName = "RandomCrop",opName
|
||||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("seed" to "seed")))
|
attributeMappingRules = listOf(valueMapping(mutableMapOf("seed" to "seed")))
|
||||||
,tensorflowOpRegistry = tensorflowOpRegistry)
|
,tensorflowOpRegistry = tensorflowOpRegistry)
|
||||||
|
|
||||||
val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",tensorNames = mutableMapOf(),
|
val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",
|
||||||
|
tensorNames = mutableMapOf(),
|
||||||
attributeMappingRules = listOf()
|
attributeMappingRules = listOf()
|
||||||
,tensorflowOpRegistry = tensorflowOpRegistry)
|
,tensorflowOpRegistry = tensorflowOpRegistry)
|
||||||
|
|
||||||
|
@ -1823,8 +1832,8 @@ val reverseSequence = multipleNameMapping(inputFrameworkOpNames = listOf("Revers
|
||||||
,tensorflowOpRegistry = tensorflowOpRegistry)
|
,tensorflowOpRegistry = tensorflowOpRegistry)
|
||||||
|
|
||||||
val roll = multipleNameMapping(inputFrameworkOpNames = listOf("Roll"),opName = "roll",
|
val roll = multipleNameMapping(inputFrameworkOpNames = listOf("Roll"),opName = "roll",
|
||||||
attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("shift" to "shift"))),
|
attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("shift" to "shift","dimensions" to "axis"))),
|
||||||
tensorNames = mutableMapOf("input" to "input","dimensions" to "axis","shiftsI" to "shift")
|
tensorNames = mutableMapOf("input" to "input")
|
||||||
,tensorflowOpRegistry = tensorflowOpRegistry)
|
,tensorflowOpRegistry = tensorflowOpRegistry)
|
||||||
|
|
||||||
//TODO: verify usingLocking property, it's not showing up in descriptors
|
//TODO: verify usingLocking property, it's not showing up in descriptors
|
||||||
|
@ -1941,6 +1950,7 @@ val size = TensorflowMappingProcess(
|
||||||
opMappingRegistry = tensorflowOpRegistry,
|
opMappingRegistry = tensorflowOpRegistry,
|
||||||
inputFrameworkOpName = "Size",
|
inputFrameworkOpName = "Size",
|
||||||
opName = "size",
|
opName = "size",
|
||||||
|
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "out_type"))),
|
||||||
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input")))
|
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input")))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,10 @@ class TensorflowIRDataType(inputDataType: DataType): IRDataType<DataType> {
|
||||||
DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return IRDataTypeValue.DT_DOUBLE
|
DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return IRDataTypeValue.DT_DOUBLE
|
||||||
DataType.DT_FLOAT, DataType.DT_FLOAT_REF -> return IRDataTypeValue.DT_FLOAT
|
DataType.DT_FLOAT, DataType.DT_FLOAT_REF -> return IRDataTypeValue.DT_FLOAT
|
||||||
DataType.DT_HALF, DataType.DT_HALF_REF -> return IRDataTypeValue.DT_HALF
|
DataType.DT_HALF, DataType.DT_HALF_REF -> return IRDataTypeValue.DT_HALF
|
||||||
|
DataType.DT_INT8, DataType.DT_INT8_REF -> return IRDataTypeValue.DT_INT8
|
||||||
|
DataType.DT_UINT8, DataType.DT_UINT8_REF -> return IRDataTypeValue.DT_UINT8
|
||||||
|
DataType.DT_UINT16, DataType.DT_UINT16_REF -> return IRDataTypeValue.DT_UINT16
|
||||||
|
DataType.DT_UINT32, DataType.DT_UINT32_REF -> return IRDataTypeValue.DT_UINT32
|
||||||
DataType.DT_INT16, DataType.DT_INT16_REF -> return IRDataTypeValue.DT_INT16
|
DataType.DT_INT16, DataType.DT_INT16_REF -> return IRDataTypeValue.DT_INT16
|
||||||
DataType.DT_INT32, DataType.DT_INT32_REF -> return IRDataTypeValue.DT_INT32
|
DataType.DT_INT32, DataType.DT_INT32_REF -> return IRDataTypeValue.DT_INT32
|
||||||
DataType.DT_INT64, DataType.DT_INT64_REF -> return IRDataTypeValue.DT_INT64
|
DataType.DT_INT64, DataType.DT_INT64_REF -> return IRDataTypeValue.DT_INT64
|
||||||
|
@ -70,9 +74,11 @@ class TensorflowIRDataType(inputDataType: DataType): IRDataType<DataType> {
|
||||||
DataType.DT_BFLOAT16, DataType.DT_BFLOAT16_REF -> return org.nd4j.linalg.api.buffer.DataType.BFLOAT16
|
DataType.DT_BFLOAT16, DataType.DT_BFLOAT16_REF -> return org.nd4j.linalg.api.buffer.DataType.BFLOAT16
|
||||||
DataType.DT_INT64, DataType.DT_INT64_REF -> return org.nd4j.linalg.api.buffer.DataType.INT64
|
DataType.DT_INT64, DataType.DT_INT64_REF -> return org.nd4j.linalg.api.buffer.DataType.INT64
|
||||||
DataType.DT_HALF, DataType.DT_HALF_REF -> return org.nd4j.linalg.api.buffer.DataType.FLOAT16
|
DataType.DT_HALF, DataType.DT_HALF_REF -> return org.nd4j.linalg.api.buffer.DataType.FLOAT16
|
||||||
|
DataType.DT_INT8, DataType.DT_INT8_REF -> return org.nd4j.linalg.api.buffer.DataType.INT8
|
||||||
DataType.DT_INT16, DataType.DT_INT16_REF -> return org.nd4j.linalg.api.buffer.DataType.INT16
|
DataType.DT_INT16, DataType.DT_INT16_REF -> return org.nd4j.linalg.api.buffer.DataType.INT16
|
||||||
DataType.DT_INT32, DataType.DT_INT32_REF -> return org.nd4j.linalg.api.buffer.DataType.INT32
|
DataType.DT_INT32, DataType.DT_INT32_REF -> return org.nd4j.linalg.api.buffer.DataType.INT32
|
||||||
DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return org.nd4j.linalg.api.buffer.DataType.DOUBLE
|
DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return org.nd4j.linalg.api.buffer.DataType.DOUBLE
|
||||||
|
DataType.DT_UINT8, DataType.DT_UINT8_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT8
|
||||||
DataType.DT_UINT16, DataType.DT_UINT16_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT16
|
DataType.DT_UINT16, DataType.DT_UINT16_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT16
|
||||||
DataType.DT_UINT32, DataType.DT_UINT32_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT32
|
DataType.DT_UINT32, DataType.DT_UINT32_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT32
|
||||||
DataType.DT_UINT64, DataType.DT_UINT64_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT64
|
DataType.DT_UINT64, DataType.DT_UINT64_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT64
|
||||||
|
|
|
@ -134,7 +134,6 @@ class TensorflowIRGraph(graphDef: GraphDef, opDef: OpList
|
||||||
val node = nodeByName(varName)
|
val node = nodeByName(varName)
|
||||||
val attrMap = node.attrMap
|
val attrMap = node.attrMap
|
||||||
if(!attrMap.containsKey("dtype")) {
|
if(!attrMap.containsKey("dtype")) {
|
||||||
|
|
||||||
val retSet = attrMap.values.filter { attrValue -> attrValue.type != DataType.DT_INVALID }
|
val retSet = attrMap.values.filter { attrValue -> attrValue.type != DataType.DT_INVALID }
|
||||||
if(retSet.isEmpty()) {
|
if(retSet.isEmpty()) {
|
||||||
return TensorflowIRDataType(DataType.DT_INVALID)
|
return TensorflowIRDataType(DataType.DT_INVALID)
|
||||||
|
|
|
@ -1172,6 +1172,18 @@ mappings {
|
||||||
ruleType: "tensor"
|
ruleType: "tensor"
|
||||||
inputFrameworkOpName: "Size"
|
inputFrameworkOpName: "Size"
|
||||||
}
|
}
|
||||||
|
rule {
|
||||||
|
ruleName: "valuemapping"
|
||||||
|
functionName: "valuemapping"
|
||||||
|
inputDataTypeName: "out_type"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "out_type"
|
||||||
|
}
|
||||||
|
ruleType: "attribute"
|
||||||
|
inputFrameworkOpName: "Size"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
|
@ -2615,6 +2627,35 @@ mappings {
|
||||||
inputFrameworkOpName: "Polygamma"
|
inputFrameworkOpName: "Polygamma"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
mappings {
|
||||||
|
frameworkName: "tensorflow"
|
||||||
|
opName: "matrix_band_part"
|
||||||
|
inputFrameworkOpName: "MatrixBandPart"
|
||||||
|
rule {
|
||||||
|
ruleName: "ndarraymapping"
|
||||||
|
functionName: "ndarraymapping"
|
||||||
|
inputTensorName: "input"
|
||||||
|
inputTensorName: "num_lower"
|
||||||
|
inputTensorName: "num_upper"
|
||||||
|
outputTensorName: "input"
|
||||||
|
outputTensorName: "minLowerT"
|
||||||
|
outputTensorName: "maxUpperT"
|
||||||
|
inputToOutput {
|
||||||
|
key: "input"
|
||||||
|
value: "input"
|
||||||
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "minLowerT"
|
||||||
|
value: "num_lower"
|
||||||
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "maxUpperT"
|
||||||
|
value: "num_upper"
|
||||||
|
}
|
||||||
|
ruleType: "tensor"
|
||||||
|
inputFrameworkOpName: "MatrixBandPart"
|
||||||
|
}
|
||||||
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
opName: "equals"
|
opName: "equals"
|
||||||
|
@ -6427,8 +6468,9 @@ mappings {
|
||||||
ruleName: "valuemapping"
|
ruleName: "valuemapping"
|
||||||
functionName: "valuemapping"
|
functionName: "valuemapping"
|
||||||
inputDataTypeName: "type"
|
inputDataTypeName: "type"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "dataType"
|
key: "dtype"
|
||||||
value: "type"
|
value: "type"
|
||||||
}
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
|
@ -6834,23 +6876,11 @@ mappings {
|
||||||
ruleName: "ndarraymapping"
|
ruleName: "ndarraymapping"
|
||||||
functionName: "ndarraymapping"
|
functionName: "ndarraymapping"
|
||||||
inputTensorName: "input"
|
inputTensorName: "input"
|
||||||
inputTensorName: "axis"
|
|
||||||
inputTensorName: "shift"
|
|
||||||
outputTensorName: "input"
|
outputTensorName: "input"
|
||||||
outputTensorName: "dimensions"
|
|
||||||
outputTensorName: "shiftsI"
|
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "input"
|
key: "input"
|
||||||
value: "input"
|
value: "input"
|
||||||
}
|
}
|
||||||
inputToOutput {
|
|
||||||
key: "dimensions"
|
|
||||||
value: "axis"
|
|
||||||
}
|
|
||||||
inputToOutput {
|
|
||||||
key: "shiftsI"
|
|
||||||
value: "shift"
|
|
||||||
}
|
|
||||||
ruleType: "tensor"
|
ruleType: "tensor"
|
||||||
inputFrameworkOpName: "Roll"
|
inputFrameworkOpName: "Roll"
|
||||||
}
|
}
|
||||||
|
@ -6862,6 +6892,10 @@ mappings {
|
||||||
key: "shift"
|
key: "shift"
|
||||||
value: "shift"
|
value: "shift"
|
||||||
}
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "dimensions"
|
||||||
|
value: "axis"
|
||||||
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "Roll"
|
inputFrameworkOpName: "Roll"
|
||||||
}
|
}
|
||||||
|
@ -7251,6 +7285,7 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarrayinputtonumericalattribute"
|
ruleName: "ndarrayinputtonumericalattribute"
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
|
outputDoubleName: "start"
|
||||||
outputDoubleName: "stop"
|
outputDoubleName: "stop"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "start"
|
key: "start"
|
||||||
|
@ -9237,6 +9272,8 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarrayinputtonumericalattribute"
|
ruleName: "ndarrayinputtonumericalattribute"
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
|
outputDoubleName: "on"
|
||||||
|
outputDoubleName: "off"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "on"
|
key: "on"
|
||||||
value: "on_value"
|
value: "on_value"
|
||||||
|
@ -10592,6 +10629,8 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarrayinputtonumericalattribute"
|
ruleName: "ndarrayinputtonumericalattribute"
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
|
outputDoubleName: "min"
|
||||||
|
outputDoubleName: "max"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "min"
|
key: "min"
|
||||||
value: "minval"
|
value: "minval"
|
||||||
|
@ -12082,10 +12121,10 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "datatypetoint"
|
ruleName: "datatypetoint"
|
||||||
functionName: "datatypetoint"
|
functionName: "datatypetoint"
|
||||||
|
outputIntName: "outputDataType"
|
||||||
inputDataTypeName: "T"
|
inputDataTypeName: "T"
|
||||||
outputDataTypeName: "dtype"
|
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "dtype"
|
key: "outputDataType"
|
||||||
value: "T"
|
value: "T"
|
||||||
}
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
|
@ -12352,6 +12391,18 @@ mappings {
|
||||||
}
|
}
|
||||||
inputFrameworkOpName: "MaxPoolWithArgmax"
|
inputFrameworkOpName: "MaxPoolWithArgmax"
|
||||||
}
|
}
|
||||||
|
rule {
|
||||||
|
ruleName: "valuemapping"
|
||||||
|
functionName: "valuemapping"
|
||||||
|
inputDataTypeName: "T"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "T"
|
||||||
|
}
|
||||||
|
ruleType: "attribute"
|
||||||
|
inputFrameworkOpName: "MaxPoolWithArgmax"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
|
@ -13428,6 +13479,9 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarrayinputtonumericalattribute"
|
ruleName: "ndarrayinputtonumericalattribute"
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
|
outputDoubleName: "from"
|
||||||
|
outputDoubleName: "to"
|
||||||
|
outputDoubleName: "step"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "from"
|
key: "from"
|
||||||
value: "start"
|
value: "start"
|
||||||
|
@ -14778,6 +14832,7 @@ mappings {
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
outputIntName: "maxOutputSize"
|
outputIntName: "maxOutputSize"
|
||||||
outputDoubleName: "overlapThreshold"
|
outputDoubleName: "overlapThreshold"
|
||||||
|
outputDoubleName: "scoreThreshold"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "maxOutputSize"
|
key: "maxOutputSize"
|
||||||
value: "max_output_size"
|
value: "max_output_size"
|
||||||
|
@ -15157,7 +15212,7 @@ mappings {
|
||||||
functionName: "stringequals"
|
functionName: "stringequals"
|
||||||
inputStringAttrName: "padding"
|
inputStringAttrName: "padding"
|
||||||
inputStringAttrName: "padding"
|
inputStringAttrName: "padding"
|
||||||
outputBooleanName: "isSameMode"
|
outputIntName: "isSameMode"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "isSameMode"
|
key: "isSameMode"
|
||||||
value: "padding"
|
value: "padding"
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper
|
||||||
import org.nd4j.ir.OpNamespace
|
import org.nd4j.ir.OpNamespace
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp
|
import org.nd4j.linalg.api.ops.DynamicCustomOp
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Roll
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.BinCount
|
import org.nd4j.linalg.api.ops.impl.transforms.BinCount
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt
|
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
@ -46,6 +47,7 @@ import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFramework
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
|
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
|
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
|
||||||
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode
|
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode
|
||||||
|
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRTensor
|
||||||
import org.nd4j.shade.protobuf.ByteString
|
import org.nd4j.shade.protobuf.ByteString
|
||||||
import org.nd4j.shade.protobuf.TextFormat
|
import org.nd4j.shade.protobuf.TextFormat
|
||||||
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner
|
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner
|
||||||
|
@ -96,6 +98,14 @@ class TestTensorflowIR {
|
||||||
val output = graph.outputAll(inputMap)
|
val output = graph.outputAll(inputMap)
|
||||||
val output2 = importedGraph.outputAll(inputMap)
|
val output2 = importedGraph.outputAll(inputMap)
|
||||||
|
|
||||||
|
|
||||||
|
val matrix =
|
||||||
|
TensorflowIRTensor(tensorflowIRGraph.nodeByName("in_0").attrMap["value"]!!.tensor).toNd4jNDArray()
|
||||||
|
val roll2 = Roll(matrix, Nd4j.scalar(2), Nd4j.scalar(1))
|
||||||
|
val outputs = Nd4j.exec(roll2)[0]
|
||||||
|
val tfOutputRoll = tfOutput["Roll"]
|
||||||
|
val nd4jOutput = output["Roll"]
|
||||||
|
|
||||||
//assertEquals(tfOutput.keys,outputList)
|
//assertEquals(tfOutput.keys,outputList)
|
||||||
//assertEquals(tfOutput.keys,output2.keys)
|
//assertEquals(tfOutput.keys,output2.keys)
|
||||||
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
|
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
|
||||||
|
@ -116,7 +126,7 @@ class TestTensorflowIR {
|
||||||
|
|
||||||
println(notEquals)
|
println(notEquals)
|
||||||
|
|
||||||
// assertEquals(output,output2)
|
// assertEquals(output,output2)
|
||||||
//assertEquals(tfOutput,output)
|
//assertEquals(tfOutput,output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1172,6 +1172,18 @@ mappings {
|
||||||
ruleType: "tensor"
|
ruleType: "tensor"
|
||||||
inputFrameworkOpName: "Size"
|
inputFrameworkOpName: "Size"
|
||||||
}
|
}
|
||||||
|
rule {
|
||||||
|
ruleName: "valuemapping"
|
||||||
|
functionName: "valuemapping"
|
||||||
|
inputDataTypeName: "out_type"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "out_type"
|
||||||
|
}
|
||||||
|
ruleType: "attribute"
|
||||||
|
inputFrameworkOpName: "Size"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
|
@ -2615,6 +2627,35 @@ mappings {
|
||||||
inputFrameworkOpName: "Polygamma"
|
inputFrameworkOpName: "Polygamma"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
mappings {
|
||||||
|
frameworkName: "tensorflow"
|
||||||
|
opName: "matrix_band_part"
|
||||||
|
inputFrameworkOpName: "MatrixBandPart"
|
||||||
|
rule {
|
||||||
|
ruleName: "ndarraymapping"
|
||||||
|
functionName: "ndarraymapping"
|
||||||
|
inputTensorName: "input"
|
||||||
|
inputTensorName: "num_lower"
|
||||||
|
inputTensorName: "num_upper"
|
||||||
|
outputTensorName: "input"
|
||||||
|
outputTensorName: "minLowerT"
|
||||||
|
outputTensorName: "maxUpperT"
|
||||||
|
inputToOutput {
|
||||||
|
key: "input"
|
||||||
|
value: "input"
|
||||||
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "minLowerT"
|
||||||
|
value: "num_lower"
|
||||||
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "maxUpperT"
|
||||||
|
value: "num_upper"
|
||||||
|
}
|
||||||
|
ruleType: "tensor"
|
||||||
|
inputFrameworkOpName: "MatrixBandPart"
|
||||||
|
}
|
||||||
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
opName: "equals"
|
opName: "equals"
|
||||||
|
@ -6427,8 +6468,9 @@ mappings {
|
||||||
ruleName: "valuemapping"
|
ruleName: "valuemapping"
|
||||||
functionName: "valuemapping"
|
functionName: "valuemapping"
|
||||||
inputDataTypeName: "type"
|
inputDataTypeName: "type"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "dataType"
|
key: "dtype"
|
||||||
value: "type"
|
value: "type"
|
||||||
}
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
|
@ -6834,23 +6876,11 @@ mappings {
|
||||||
ruleName: "ndarraymapping"
|
ruleName: "ndarraymapping"
|
||||||
functionName: "ndarraymapping"
|
functionName: "ndarraymapping"
|
||||||
inputTensorName: "input"
|
inputTensorName: "input"
|
||||||
inputTensorName: "axis"
|
|
||||||
inputTensorName: "shift"
|
|
||||||
outputTensorName: "input"
|
outputTensorName: "input"
|
||||||
outputTensorName: "dimensions"
|
|
||||||
outputTensorName: "shiftsI"
|
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "input"
|
key: "input"
|
||||||
value: "input"
|
value: "input"
|
||||||
}
|
}
|
||||||
inputToOutput {
|
|
||||||
key: "dimensions"
|
|
||||||
value: "axis"
|
|
||||||
}
|
|
||||||
inputToOutput {
|
|
||||||
key: "shiftsI"
|
|
||||||
value: "shift"
|
|
||||||
}
|
|
||||||
ruleType: "tensor"
|
ruleType: "tensor"
|
||||||
inputFrameworkOpName: "Roll"
|
inputFrameworkOpName: "Roll"
|
||||||
}
|
}
|
||||||
|
@ -6862,6 +6892,10 @@ mappings {
|
||||||
key: "shift"
|
key: "shift"
|
||||||
value: "shift"
|
value: "shift"
|
||||||
}
|
}
|
||||||
|
inputToOutput {
|
||||||
|
key: "dimensions"
|
||||||
|
value: "axis"
|
||||||
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "Roll"
|
inputFrameworkOpName: "Roll"
|
||||||
}
|
}
|
||||||
|
@ -7251,6 +7285,7 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarrayinputtonumericalattribute"
|
ruleName: "ndarrayinputtonumericalattribute"
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
|
outputDoubleName: "start"
|
||||||
outputDoubleName: "stop"
|
outputDoubleName: "stop"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "start"
|
key: "start"
|
||||||
|
@ -9237,6 +9272,8 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarrayinputtonumericalattribute"
|
ruleName: "ndarrayinputtonumericalattribute"
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
|
outputDoubleName: "on"
|
||||||
|
outputDoubleName: "off"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "on"
|
key: "on"
|
||||||
value: "on_value"
|
value: "on_value"
|
||||||
|
@ -10592,6 +10629,8 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarrayinputtonumericalattribute"
|
ruleName: "ndarrayinputtonumericalattribute"
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
|
outputDoubleName: "min"
|
||||||
|
outputDoubleName: "max"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "min"
|
key: "min"
|
||||||
value: "minval"
|
value: "minval"
|
||||||
|
@ -12082,10 +12121,10 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "datatypetoint"
|
ruleName: "datatypetoint"
|
||||||
functionName: "datatypetoint"
|
functionName: "datatypetoint"
|
||||||
|
outputIntName: "outputDataType"
|
||||||
inputDataTypeName: "T"
|
inputDataTypeName: "T"
|
||||||
outputDataTypeName: "dtype"
|
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "dtype"
|
key: "outputDataType"
|
||||||
value: "T"
|
value: "T"
|
||||||
}
|
}
|
||||||
ruleType: "attribute"
|
ruleType: "attribute"
|
||||||
|
@ -12352,6 +12391,18 @@ mappings {
|
||||||
}
|
}
|
||||||
inputFrameworkOpName: "MaxPoolWithArgmax"
|
inputFrameworkOpName: "MaxPoolWithArgmax"
|
||||||
}
|
}
|
||||||
|
rule {
|
||||||
|
ruleName: "valuemapping"
|
||||||
|
functionName: "valuemapping"
|
||||||
|
inputDataTypeName: "T"
|
||||||
|
outputDataTypeName: "dtype"
|
||||||
|
inputToOutput {
|
||||||
|
key: "dtype"
|
||||||
|
value: "T"
|
||||||
|
}
|
||||||
|
ruleType: "attribute"
|
||||||
|
inputFrameworkOpName: "MaxPoolWithArgmax"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
|
@ -13428,6 +13479,9 @@ mappings {
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarrayinputtonumericalattribute"
|
ruleName: "ndarrayinputtonumericalattribute"
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
|
outputDoubleName: "from"
|
||||||
|
outputDoubleName: "to"
|
||||||
|
outputDoubleName: "step"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "from"
|
key: "from"
|
||||||
value: "start"
|
value: "start"
|
||||||
|
@ -14778,6 +14832,7 @@ mappings {
|
||||||
functionName: "ndarrayinputtonumericalattribute"
|
functionName: "ndarrayinputtonumericalattribute"
|
||||||
outputIntName: "maxOutputSize"
|
outputIntName: "maxOutputSize"
|
||||||
outputDoubleName: "overlapThreshold"
|
outputDoubleName: "overlapThreshold"
|
||||||
|
outputDoubleName: "scoreThreshold"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "maxOutputSize"
|
key: "maxOutputSize"
|
||||||
value: "max_output_size"
|
value: "max_output_size"
|
||||||
|
@ -15157,7 +15212,7 @@ mappings {
|
||||||
functionName: "stringequals"
|
functionName: "stringequals"
|
||||||
inputStringAttrName: "padding"
|
inputStringAttrName: "padding"
|
||||||
inputStringAttrName: "padding"
|
inputStringAttrName: "padding"
|
||||||
outputBooleanName: "isSameMode"
|
outputIntName: "isSameMode"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "isSameMode"
|
key: "isSameMode"
|
||||||
value: "padding"
|
value: "padding"
|
||||||
|
|
|
@ -5,26 +5,26 @@ node {
|
||||||
key: "value"
|
key: "value"
|
||||||
value {
|
value {
|
||||||
tensor {
|
tensor {
|
||||||
dtype: DT_FLOAT
|
dtype: DT_DOUBLE
|
||||||
tensor_shape {
|
tensor_shape {
|
||||||
dim {
|
dim {
|
||||||
size: 2
|
size: 4
|
||||||
}
|
}
|
||||||
dim {
|
dim {
|
||||||
size: 3
|
size: 5
|
||||||
}
|
}
|
||||||
dim {
|
dim {
|
||||||
size: 3
|
size: 4
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tensor_content: "~^G?L\033M?\236p9?\220ol>\356%:?X\2708><q\001?b|d?\224\316\v?\314al?P@\257=,5K?\326\271(?\3566\016?`u#>0\024\236>\240{\036>\240h\360>"
|
tensor_content: "0m4\347\376y\315?\344\033;\004\236p\351?\026..t\357%\352? \306G>\322\023\247?\230\303\330E)\235\327?,5\313k\v\350\345?\334m\034\311\255s\321?0\024\236\222\260\272\321?@\321\340\331\240{\316?v|\234\234\223o\356?\244&\\\214\314W\332?\270\376rL\357\224\331?\030\216\353R\253\023\332?*A<\365%\024\344?h\f\326CH\250\351?0\320Y{\256\\\261?\366\226\304\200\016\350\343?@cr\031B\026\323?\300Q\006<\213\200\303?\f\377 \310\035\271\320?\300\325\277\177]\302\347?\270\337\237\302\321P\343?\256L\224~H]\351?0\243.\031\266\256\342?\200B\250\257\316j\301?\304\372\266\312\352\225\332?D\320\255\371\vB\327?\364\025&\270p\360\327?H:\241v\n\272\312?0-} \201!\323?l2\v\247;|\331?\320r}\'\v\372\341?\006u\364aS@\347?P6<\211\270\337\273?\024\340\006`\267\262\322?t\222JV\310\'\325?\264/\177\232\"]\322?\242\037L\006\215=\345?\270\376\302\274\332\310\323?R\2100\375\212\314\355?\300;\026W\321\210\230?\260?t#\314\203\322?\366=D:{\005\342?p_\v\2335\315\356?\344U\370\034\372\317\332? l1N\344\252\330?\354\327\003\223\206\215\321?^C\326M\353\234\345?\326\333u\025\2449\354?\264\000\334Z\177\326\353?\244\341\253\326!\272\345?\320\336\312`\255\005\311?\244u\220o\266\033\324?V\b\201\267\276\271\351?$\253\324_o3\356?\264:\260\357i\003\335?\300[\'7L8\256?02UU\205\214\265?\240\255\276\263r+\257? \303\207\022\3446\334?\220\032\360l\364(\324?\030\374\036.\217W\355?\340!;ay\034\257?\312\255)\371\227\333\346?F\233`e\300\335\343?\264>\261\354\324\345\331?\212v\025\026\265o\342?|\036\036F\364}\330?\034\203\363\362\364\204\322?\000:\260e\"\372\230?F\316\345\330\2577\343?\356uN6<b\355?6\216\263\234\243N\351?\306b\204\305\260\200\342?\210\213\262X/J\331?(\352\313\203=\217\312?\310\3327\231\362\a\345?\210\234<\035}\241\313?\224\'0\201\363V\344?\240\260-3\3655\345?"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr {
|
attr {
|
||||||
key: "dtype"
|
key: "dtype"
|
||||||
value {
|
value {
|
||||||
type: DT_FLOAT
|
type: DT_DOUBLE
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,12 @@ node {
|
||||||
name: "in_0/read"
|
name: "in_0/read"
|
||||||
op: "Identity"
|
op: "Identity"
|
||||||
input: "in_0"
|
input: "in_0"
|
||||||
|
attr {
|
||||||
|
key: "T"
|
||||||
|
value {
|
||||||
|
type: DT_DOUBLE
|
||||||
|
}
|
||||||
|
}
|
||||||
attr {
|
attr {
|
||||||
key: "_class"
|
key: "_class"
|
||||||
value {
|
value {
|
||||||
|
@ -40,96 +46,73 @@ node {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
name: "Roll/shift"
|
||||||
|
op: "Const"
|
||||||
attr {
|
attr {
|
||||||
key: "T"
|
key: "dtype"
|
||||||
value {
|
value {
|
||||||
type: DT_FLOAT
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
key: "value"
|
||||||
|
value {
|
||||||
|
tensor {
|
||||||
|
dtype: DT_INT32
|
||||||
|
tensor_shape {
|
||||||
|
}
|
||||||
|
int_val: 2
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
name: "eye/ones"
|
name: "Roll/axis"
|
||||||
op: "Const"
|
op: "Const"
|
||||||
attr {
|
attr {
|
||||||
key: "value"
|
key: "value"
|
||||||
value {
|
value {
|
||||||
tensor {
|
tensor {
|
||||||
dtype: DT_FLOAT
|
dtype: DT_INT32
|
||||||
tensor_shape {
|
tensor_shape {
|
||||||
dim {
|
|
||||||
size: 2
|
|
||||||
}
|
|
||||||
dim {
|
|
||||||
size: 3
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
float_val: 1.0
|
int_val: 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr {
|
attr {
|
||||||
key: "dtype"
|
key: "dtype"
|
||||||
value {
|
value {
|
||||||
type: DT_FLOAT
|
type: DT_INT32
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
name: "eye/diag"
|
name: "Roll"
|
||||||
op: "MatrixDiag"
|
op: "Roll"
|
||||||
input: "eye/ones"
|
|
||||||
attr {
|
|
||||||
key: "T"
|
|
||||||
value {
|
|
||||||
type: DT_FLOAT
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
node {
|
|
||||||
name: "Add"
|
|
||||||
op: "Add"
|
|
||||||
input: "in_0/read"
|
input: "in_0/read"
|
||||||
input: "eye/diag"
|
input: "Roll/shift"
|
||||||
|
input: "Roll/axis"
|
||||||
attr {
|
attr {
|
||||||
key: "T"
|
key: "Taxis"
|
||||||
value {
|
value {
|
||||||
type: DT_FLOAT
|
type: DT_INT32
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
node {
|
|
||||||
name: "Svd"
|
|
||||||
op: "Svd"
|
|
||||||
input: "Add"
|
|
||||||
attr {
|
|
||||||
key: "full_matrices"
|
|
||||||
value {
|
|
||||||
b: true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr {
|
attr {
|
||||||
key: "compute_uv"
|
key: "Tshift"
|
||||||
value {
|
value {
|
||||||
b: false
|
type: DT_INT32
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
attr {
|
attr {
|
||||||
key: "T"
|
key: "T"
|
||||||
value {
|
value {
|
||||||
type: DT_FLOAT
|
type: DT_DOUBLE
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
node {
|
|
||||||
name: "Abs"
|
|
||||||
op: "Abs"
|
|
||||||
input: "Svd"
|
|
||||||
attr {
|
|
||||||
key: "T"
|
|
||||||
value {
|
|
||||||
type: DT_FLOAT
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
library {
|
library {
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,2 @@
|
||||||
output,output
|
in_0/read,in_0/read
|
||||||
output:1,output
|
Roll,Roll
|
||||||
out0,out0
|
|
||||||
out1,out1
|
|
||||||
|
|
|
@ -1,5 +1,2 @@
|
||||||
in_0/read,in_0/read
|
in_0/read,in_0/read
|
||||||
eye/diag,eye/diag
|
Roll,Roll
|
||||||
Add,Add
|
|
||||||
Svd,Svd
|
|
||||||
Abs,Abs
|
|
||||||
|
|
Loading…
Reference in New Issue