Fix data type and roll

master
agibsonccc 2021-02-07 19:27:41 +09:00
parent 04209693f5
commit 53bfdb9994
28 changed files with 3403 additions and 2919 deletions

View File

@ -51,9 +51,10 @@ namespace ops {
if (shift < 0) { if (shift < 0) {
shift -= input->sizeAt(i) * (shift / inputLen - 1); shift -= input->sizeAt(i) * (shift / inputLen - 1);
} }
else { else if(shift != 0) {
shift %= input->sizeAt(i); shift %= input->sizeAt(i);
} }
shifts[i] = shift; shifts[i] = shift;
} }
@ -64,7 +65,7 @@ namespace ops {
// convert shift to positive value between 1 and inputLen - 1 // convert shift to positive value between 1 and inputLen - 1
shift -= inputLen * (shift / inputLen - 1); shift -= inputLen * (shift / inputLen - 1);
} }
else else if(shift != 0)
// cut shift to value between 1 and inputLen - 1 // cut shift to value between 1 and inputLen - 1
shift %= inputLen; shift %= inputLen;
axes.resize(block.getIArguments()->size() - 1); axes.resize(block.getIArguments()->size() - 1);
@ -87,6 +88,21 @@ namespace ops {
if (block.isInplace()) output = input; if (block.isInplace()) output = input;
shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1); shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1);
nd4j_debug("Roll: Shift is linear %d Shift is %d, first dimension is %d\n",shiftIsLinear,shifts[0],axes[0]);
bool shiftsSumZero = false;
auto shiftSum = 0;
for (auto& s: shifts) {
shiftSum += s;
nd4j_debug("Roll: Shift is %d\n",s);
}
//all zeros is no op
if(shiftSum < 1) {
nd4j_debug("Roll: No shift needed. Shift total was %d\n",shiftSum);
if(!block.isInplace()) {
output->assign(input);
}
return Status::OK();
}
if (shiftIsLinear) { if (shiftIsLinear) {
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace()); helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());

View File

@ -73,7 +73,7 @@ namespace helpers {
} }
} }
// stage 3) swap remainer of items. // stage 3) swap remainder of items.
if (remainShift && shiftCount) if (remainShift && shiftCount)
for (int i = actualShift; i < 2 * actualShift; ++i) { for (int i = actualShift; i < 2 * actualShift; ++i) {
auto _e0 = output->e<T>(i); auto _e0 = output->e<T>(i);
@ -95,10 +95,11 @@ namespace helpers {
auto source = output; //input; auto source = output; //input;
for (size_t i = 0; i < axes.size(); i++) { for (size_t i = 0; i < axes.size(); i++) {
int axe = axes[i]; int axe = axes[i];
if (axe == source->rankOf() - 1) {// last dimension // if (axe == source->rankOf() - 1) {// last dimension
ResultSet listOfTensors = source->allTensorsAlongDimension({axe}); ResultSet listOfTensors = source->allTensorsAlongDimension({axe});
ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe}); ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe});
int fullLen = listOfTensors.size(); int fullLen = listOfTensors.size();
nd4j_debug("Roll: fullLen at last dimension is %d\n",fullLen);
int theShift = shifts[i]; int theShift = shifts[i];
if (theShift > 0) { if (theShift > 0) {
theShift %= fullLen; theShift %= fullLen;
@ -109,7 +110,7 @@ namespace helpers {
for (int k = 0; k < fullLen; k++) { for (int k = 0; k < fullLen; k++) {
rollFunctorLinear(context, listOfTensors.at(k), listOfOutTensors.at(k), theShift, true); rollFunctorLinear(context, listOfTensors.at(k), listOfOutTensors.at(k), theShift, true);
} }
} /* }
else { else {
std::vector<int> dims(source->rankOf() - axe - 1); std::vector<int> dims(source->rankOf() - axe - 1);
for (size_t i = 0; i < dims.size(); ++i) for (size_t i = 0; i < dims.size(); ++i)
@ -120,6 +121,7 @@ namespace helpers {
// //
int fullLen = listOfTensors.size(); int fullLen = listOfTensors.size();
int sizeAt = input->sizeAt(axe); int sizeAt = input->sizeAt(axe);
nd4j_debug("Roll: fullLen at dimension %d is %d\n",i,fullLen);
int theShift = shifts[i]; int theShift = shifts[i];
@ -131,14 +133,14 @@ namespace helpers {
} }
if (theShift) { if (theShift) {
for (int dim = 0; dim < fullLen / sizeAt; ++dim) { for (size_t dim = 0; dim < fullLen / sizeAt; ++dim) {
for (int e = theShift; e < sizeAt - theShift; ++e) { for (size_t e = theShift; e < sizeAt - theShift; ++e) {
auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift); auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift);
auto targetM = listOfOutTensors.at(dim * sizeAt + e); auto targetM = listOfOutTensors.at(dim * sizeAt + e);
sourceM->swapUnsafe(*targetM); sourceM->swapUnsafe(*targetM);
} }
for (int e = 0; e < theShift; ++e) { for (size_t e = 0; e < theShift; ++e) {
int sourceIndex = dim * sizeAt + sizeAt - theShift + e; int sourceIndex = dim * sizeAt + sizeAt - theShift + e;
auto sourceM = listOfTensors.at(sourceIndex); auto sourceM = listOfTensors.at(sourceIndex);
auto targetM = listOfOutTensors.at(dim * sizeAt + e); auto targetM = listOfOutTensors.at(dim * sizeAt + e);
@ -149,7 +151,7 @@ namespace helpers {
} }
} }
// if (!inplace) // if (!inplace)
// source = output; // source = output;*/
} }
} }

View File

@ -34,10 +34,10 @@ public class Roll extends DynamicCustomOp {
public Roll() {} public Roll() {}
public Roll(@NonNull INDArray input, @NonNull INDArray axes, @NonNull INDArray shifts) { public Roll(@NonNull INDArray input, @NonNull INDArray shifts, @NonNull INDArray axes) {
Preconditions.checkArgument(axes.rank() == shifts.rank(), "Roll: shifts and axes should be the same rank"); Preconditions.checkArgument(axes.rank() == shifts.rank(), "Roll: shifts and axes should be the same rank");
Preconditions.checkArgument(axes.length() == shifts.length(), "Roll: shifts and axes should be the same length"); Preconditions.checkArgument(axes.length() == shifts.length(), "Roll: shifts and axes should be the same length");
addInputArgument(input, axes, shifts); addInputArgument(input, shifts, axes);
} }
public Roll(@NonNull INDArray input, int shift) { public Roll(@NonNull INDArray input, int shift) {
@ -49,8 +49,8 @@ public class Roll extends DynamicCustomOp {
super("", sameDiff, new SDVariable[]{input,shift}); super("", sameDiff, new SDVariable[]{input,shift});
} }
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable axes, @NonNull SDVariable shift) { public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shift, @NonNull SDVariable axes) {
super("", sameDiff, new SDVariable[]{input,axes,shift}); super("", sameDiff, new SDVariable[]{input,shift,axes});
} }
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) { public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) {

View File

@ -284,7 +284,10 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
List<DataType> result = new ArrayList<>(); List<DataType> result = new ArrayList<>();
result.add(inputDataTypes.get(0)); result.add(inputDataTypes.get(0));
if(dArguments.isEmpty())
result.add(outputType == null ? DataType.INT : outputType); result.add(outputType == null ? DataType.INT : outputType);
else
result.add(dArguments.get(0));
return result; return result;
} }
} }

View File

@ -36,6 +36,17 @@
<name>nd4j-tests</name> <name>nd4j-tests</name>
<properties> <properties>
<kotlin.version>1.4.30-M1</kotlin.version>
<kotlin.compiler.jvmTarget>1.8</kotlin.compiler.jvmTarget>
<kotlin.compiler.incremental>true</kotlin.compiler.incremental>
<junit.version>4.12</junit.version>
<junit-jupiter.version>5.4.2</junit-jupiter.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<java.version>1.8</java.version>
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
<maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target> <maven.compiler.target>1.8</maven.compiler.target>
<scala.binary.version>2.11</scala.binary.version> <scala.binary.version>2.11</scala.binary.version>
@ -43,7 +54,213 @@
<maven.compiler.testSource>1.8</maven.compiler.testSource> <maven.compiler.testSource>1.8</maven.compiler.testSource>
</properties> </properties>
<build>
<plugins>
<plugin>
<groupId>org.projectlombok</groupId>
<artifactId>lombok-maven-plugin</artifactId>
<version>1.18.12.0</version>
<executions>
<execution>
<id>delombok</id>
<phase>generate-sources</phase>
<goals>
<goal>delombok</goal>
</goals>
<configuration>
<formatPreferences>
<javaLangAsFQN>skip</javaLangAsFQN>
</formatPreferences>
<verbose>true</verbose>
</configuration>
</execution>
<execution>
<id>test-delombok</id>
<phase>generate-test-sources</phase>
<goals>
<goal>testDelombok</goal>
</goals>
<configuration>
<verbose>true</verbose>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<version>3.0.0</version>
<executions>
<execution>
<id>add-source</id>
<phase>generate-sources</phase>
<goals><goal>add-source</goal></goals>
<configuration>
<sources>
<source>src/main/stubs</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>${maven-shade-plugin.version}</version>
<configuration>
<shadedArtifactAttached>true</shadedArtifactAttached>
<createDependencyReducedPom>false</createDependencyReducedPom>
<filters>
<filter>
<artifact>*:*</artifact>
<excludes>
<exclude>org/datanucleus/**</exclude>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</filter>
</filters>
</configuration>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer" />
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-plugin</artifactId>
<version>1.4.30-M1</version>
<configuration>
<args>
<arg>-Xjsr305=strict</arg>
</args>
<compilerPlugins>
<plugin>spring</plugin>
<plugin>jpa</plugin>
</compilerPlugins>
</configuration>
<dependencies> <dependencies>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-allopen</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-noarg</artifactId>
<version>${kotlin.version}</version>
</dependency>
</dependencies>
<executions>
<execution>
<id>compile</id>
<goals> <goal>compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/main/java</sourceDir>
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
<execution>
<id>test-compile</id>
<goals> <goal>test-compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/test/java</sourceDir>
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
</executions>
</plugin>
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<executions>
<!-- Replacing default-compile as it is treated specially by maven -->
<execution>
<id>default-compile</id>
<phase>none</phase>
</execution>
<!-- Replacing default-testCompile as it is treated specially by maven -->
<execution>
<id>default-testCompile</id>
<phase>none</phase>
</execution>
<execution>
<id>java-compile</id>
<phase>compile</phase>
<goals> <goal>compile</goal> </goals>
</execution>
<execution>
<id>java-test-compile</id>
<phase>test-compile</phase>
<goals> <goal>testCompile</goal> </goals>
</execution>
</executions>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins>
</build>
<dependencies>
<!-- Test Dependencies -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>${junit-jupiter.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${junit-jupiter.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib-jdk8</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-test</artifactId>
<version>${kotlin.version}</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>org.nd4j</groupId> <groupId>org.nd4j</groupId>
<artifactId>samediff-import-tensorflow</artifactId> <artifactId>samediff-import-tensorflow</artifactId>

View File

@ -68,13 +68,22 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
* all tests will trigger an assumeFalse(..) that indicates * all tests will trigger an assumeFalse(..) that indicates
* the status of the test failing. No tests will run. * the status of the test failing. No tests will run.
*/ */
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(); public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
// "max_pool_with_argmax/int32_int64_padding_SAME",
// "fused_batch_norm/float32_nhwc",
// "max_pool_with_argmax/int64_int64_padding_SAME",
// "fused_batch_norm/float16_nhwc",
"roll/rank3_int32_axis",
"roll/rank3_int32_axis",
"roll/rank2_float32_zeroshift",
"roll/rank3_float64_axis"
);
public static final String[] IGNORE_REGEXES = new String[]{ public static final String[] IGNORE_REGEXES = new String[]{
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965 //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
// Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance // Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance
"bincount/.*", //Invalid test cases. Verified by running graph against actual TF.
"slogdet/.*",
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too //TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod // Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod

View File

@ -881,7 +881,7 @@ public class CustomOpsTests extends BaseNd4jTest {
BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out); BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out);
List<LongShapeDescriptor> lsd = op.calculateOutputShape(); List<LongShapeDescriptor> lsd = op.calculateOutputShape();
assertEquals(1, lsd.size()); assertEquals(1, lsd.size());
assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape()); assertArrayEquals(new long[]{1,10,2}, lsd.get(0).getShape());
} }
@Test @Test
@ -942,21 +942,21 @@ public class CustomOpsTests extends BaseNd4jTest {
} }
@Test @Test
@Ignore("Failing with results that are close")
public void testFakeQuantAgainstTF_1() { public void testFakeQuantAgainstTF_1() {
INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, INDArray x = Nd4j.createFromArray(new double[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5); 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
INDArray min = Nd4j.createFromArray(new float[]{ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); INDArray min = Nd4j.createFromArray(new double[]{ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
INDArray max = Nd4j.createFromArray(new float[]{ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); INDArray max = Nd4j.createFromArray(new double[]{ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f, INDArray expected = Nd4j.createFromArray(new double[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5); 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
INDArray out = Nd4j.createUninitialized(x.shape());
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max); val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max);
Nd4j.exec(op); INDArray[] output = Nd4j.exec(op);
assertEquals(expected, out); assertEquals(expected, output[0]);
} }
@Test @Test
@ -971,8 +971,7 @@ public class CustomOpsTests extends BaseNd4jTest {
@Test @Test
public void testResizeBilinear1() { public void testResizeBilinear1() {
INDArray x = Nd4j.rand(1, 10,10,4);
INDArray x = Nd4j.rand(1, 2,3,4);
INDArray z = Nd4j.createUninitialized(x.shape()); INDArray z = Nd4j.createUninitialized(x.shape());
boolean align = false; boolean align = false;
val op = new ResizeBilinear(x, z, 10, 10, align, false); val op = new ResizeBilinear(x, z, 10, 10, align, false);
@ -1082,24 +1081,8 @@ public class CustomOpsTests extends BaseNd4jTest {
INDArray distance = Nd4j.scalar(0.f); INDArray distance = Nd4j.scalar(0.f);
Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance)); Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance));
// System.out.println(distance);
} }
@Ignore("2019/11/15 AS - https://github.com/eclipse/deeplearning4j/issues/8399")
@Test
public void testCropAndResize() {
INDArray image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1);
INDArray boxes = Nd4j.createFromArray(new float[]{1,2,3,4}).reshape(1,4);
INDArray box_indices = Nd4j.createFromArray(new int[]{1});
INDArray crop_size = Nd4j.createFromArray(new int[]{1,2}).reshape(1,2);
//Output shape mismatch - TF [2, 2, 1, 1] vs SD: [1, 2, 1, 1]
INDArray output = Nd4j.create(DataType.FLOAT, 2,2,1,1);
Nd4j.exec(new CropAndResize(image, boxes, box_indices, crop_size, CropAndResize.Method.BILINEAR, 0.5,
output));
}
@Test @Test
public void testLayersDropoutFail() { public void testLayersDropoutFail() {
@ -1338,7 +1321,6 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(expected, ret[0]); assertEquals(expected, ret[0]);
} }
@Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8453")
@Test @Test
public void testRoll1() { public void testRoll1() {
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f}); INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f});
@ -1346,6 +1328,10 @@ public class CustomOpsTests extends BaseNd4jTest {
INDArray[] ret = Nd4j.exec(op); INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.7244f, 0.2309f, 0.7788f, 0.8012f}); INDArray expected = Nd4j.createFromArray(new float[]{0.7244f, 0.2309f, 0.7788f, 0.8012f});
assertEquals(expected, ret[0]); assertEquals(expected, ret[0]);
INDArray matrix = Nd4j.create(new double[]{0.7788,0.8012,0.7244,0.2309,0.7271,0.1804,0.5056,0.8925}).reshape(2,4);
Roll roll2 = new Roll(matrix,Nd4j.scalar(0),Nd4j.scalar(1));
INDArray[] outputs = Nd4j.exec(roll2);
System.out.println(outputs[0]);
} }
@Test @Test

View File

@ -0,0 +1,118 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.custom
import junit.framework.Assert.assertEquals
import org.junit.Ignore
import org.junit.Test
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ops.impl.image.CropAndResize
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.samediff.frameworkimport.tensorflow.*
import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
class CustomOpTensorflowInteropTests {
@Test
@Ignore("Tensorflow expects different shape")
fun testCropAndResize() {
val image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1)
val boxes = Nd4j.createFromArray(*floatArrayOf(1f, 2f, 3f, 4f)).reshape(1, 4)
val box_indices = Nd4j.createFromArray(*intArrayOf(0))
val crop_size = Nd4j.createFromArray(*intArrayOf(1, 2)).reshape( 2)
val imageNode = NodeDef {
op = "Placeholder"
name = "image"
Attribute("dtype", AttrValue {
type = org.tensorflow.framework.DataType.DT_FLOAT
})
}
val boxesNode = NodeDef {
op = "Placeholder"
name = "boxes"
Attribute("dtype", AttrValue {
type = org.tensorflow.framework.DataType.DT_FLOAT
})
}
val boxIndicesNode = NodeDef {
op = "Placeholder"
name = "boxIndices"
Attribute("dtype", AttrValue {
type = org.tensorflow.framework.DataType.DT_INT32
})
}
val cropSizesNode = NodeDef {
op = "Placeholder"
name = "cropSize"
Attribute("dtype", AttrValue {
type = org.tensorflow.framework.DataType.DT_INT32
})
}
val opNode = NodeDef {
op = "CropAndResize"
name = "output"
Input("image")
Input("boxes")
Input("boxIndices")
Input("cropSize")
Attribute("extrapolation_value", AttrValue {
f = 0.5f
})
Attribute("T", AttrValue {
type = org.tensorflow.framework.DataType.DT_FLOAT
})
}
val graph = GraphDef {
Node(imageNode)
Node(boxesNode)
Node(boxIndicesNode)
Node(cropSizesNode)
Node(opNode)
}
val importer = TensorflowFrameworkImporter()
val irGraph = TensorflowIRGraph(graph,importer.opDefList,importer.registry)
val runner = TensorflowIRGraphRunner(irGraph,listOf("image","boxes","boxIndices","cropSize"),listOf("output"))
val tfResult = runner.run(mapOf("image" to image,"boxes" to boxes,"boxIndices" to box_indices,"cropSize" to crop_size))
val outputArr = tfResult["output"]
//Output shape mismatch - TF [2, 2, 1, 1] vs SD: [1, 2, 1, 1]
val output = Nd4j.create(DataType.FLOAT, 2, 2, 1, 1)
Nd4j.exec(
CropAndResize(
image, boxes, box_indices, crop_size, CropAndResize.Method.BILINEAR, 0.5,
output
)
)
assertEquals(outputArr,output)
}
}

View File

@ -1,5 +1,2 @@
Variable/read,Variable/read in_0/read,in_0/read
Variable_1/read,Variable_1/read Roll,Roll
floordiv/x,floordiv/x
floordiv/y,floordiv/y
floordiv,floordiv

View File

@ -1,18 +1 @@
in_0/read,in_0/read Sum,Sum
while/Enter,while/Enter
while/Enter_1,while/Enter_1
while/Merge,while/Merge
while/Merge_1,while/Merge_1
while/Less,while/Less
while/LoopCond,while/LoopCond
while/Switch,while/Switch
while/Switch:1,while/Switch
while/Switch_1,while/Switch_1
while/Switch_1:1,while/Switch_1
while/Identity,while/Identity
while/Exit,while/Exit
while/Identity_1,while/Identity_1
while/Exit_1,while/Exit_1
while/add,while/add
while/NextIteration_1,while/NextIteration_1
while/NextIteration,while/NextIteration

View File

@ -207,12 +207,6 @@ open class ImportGraph <GRAPH_TYPE: GeneratedMessageV3,
OpMappingRegistry<GRAPH_TYPE, NODE_TYPE, OP_DEF_TYPE, TENSOR_TYPE, OpMappingRegistry<GRAPH_TYPE, NODE_TYPE, OP_DEF_TYPE, TENSOR_TYPE,
DATA_TYPE, ATTR_DEF_TYPE, ATTR_VALUE_TYPE>): SameDiff { DATA_TYPE, ATTR_DEF_TYPE, ATTR_VALUE_TYPE>): SameDiff {
/*
First, build an in-memory representation of the graph that allows us to build the graph incrementally
If we can build the graph incrementally, we can make sure that the added variables are set up with the correct
datatype and (once implemented) greedy shape inference
*/
/* /*
First, build an in-memory representation of the graph that allows us to build the graph incrementally First, build an in-memory representation of the graph that allows us to build the graph incrementally
@ -442,19 +436,34 @@ open class ImportGraph <GRAPH_TYPE: GeneratedMessageV3,
val op = SameDiffOp.builder() val op = SameDiffOp.builder()
.name(name) .name(name)
.op(df) .op(df)
.inputsToOp(inNames) //.outputsOfOp(outNames) //We'll set this later
.controlDeps(controlDeps) .controlDeps(controlDeps)
.build() .build()
//take only up to the inputs that are specified in the node/
//this is for cases where node inputs is > intended number for ops
//a common example is when ops convert input ndarrays to integers or float inputs
val numInputsToTake = importInfo[name]!!.second.argDescriptorList.filter { input -> input.argType == OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR }
.size
op.inputsToOp = inNames.subList(0,numInputsToTake)
//add nodes/other pre processing in order for this node to work //add nodes/other pre processing in order for this node to work
var addToGraph = true
sd.ops[name] = op sd.ops[name] = op
defaultRunner.initAttributes(df, sd, importInfo[name]!!) //clear out inputs for variables as well to reflect the actual graph structure
if(numInputsToTake < numInputs) {
for(i in numInputsToTake until numInputs) {
if(sd.hasVariable(nd.inputAt(i))) {
val currInputVar = sd.variables[nd.inputAt(i)]!!
currInputVar.inputsForOp.remove(op.name)
}
}
}
//cache attributes just in case we have any rules so we don't create the rules more than once //cache attributes just in case we have any rules so we don't create the rules more than once
val attributes = mappingContext.nodeAttributesAsMap() val attributes = mappingContext.nodeAttributesAsMap()
mappingContext.relevantPrehookRules().forEach { rule -> mappingContext.relevantPrehookRules().forEach { rule ->
rule.preProcess(op, sd,attributes) rule.preProcess(op, sd,attributes)
} }
defaultRunner.initAttributes(df, sd, importInfo[name]!!)
//add nodes/other post processing in order for this node to work //add nodes/other post processing in order for this node to work

View File

@ -33,6 +33,8 @@ class OnnxIRDataType(inputDataType: Onnx.TensorProto.DataType): IRDataType<Onnx.
Onnx.TensorProto.DataType.UINT64 -> return IRDataTypeValue.DT_UINT64 Onnx.TensorProto.DataType.UINT64 -> return IRDataTypeValue.DT_UINT64
Onnx.TensorProto.DataType.UINT32 -> return IRDataTypeValue.DT_UINT32 Onnx.TensorProto.DataType.UINT32 -> return IRDataTypeValue.DT_UINT32
Onnx.TensorProto.DataType.UINT16 -> return IRDataTypeValue.DT_UINT16 Onnx.TensorProto.DataType.UINT16 -> return IRDataTypeValue.DT_UINT16
Onnx.TensorProto.DataType.INT8 -> return IRDataTypeValue.DT_INT8
Onnx.TensorProto.DataType.UINT8 -> return IRDataTypeValue.DT_UINT8
Onnx.TensorProto.DataType.FLOAT16 -> return IRDataTypeValue.DT_HALF Onnx.TensorProto.DataType.FLOAT16 -> return IRDataTypeValue.DT_HALF
Onnx.TensorProto.DataType.STRING -> return IRDataTypeValue.DT_STRING Onnx.TensorProto.DataType.STRING -> return IRDataTypeValue.DT_STRING
Onnx.TensorProto.DataType.FLOAT -> return IRDataTypeValue.DT_FLOAT Onnx.TensorProto.DataType.FLOAT -> return IRDataTypeValue.DT_FLOAT
@ -60,9 +62,11 @@ class OnnxIRDataType(inputDataType: Onnx.TensorProto.DataType): IRDataType<Onnx.
override fun nd4jDataType(): DataType { override fun nd4jDataType(): DataType {
when(this.dataType) { when(this.dataType) {
Onnx.TensorProto.DataType.UINT64 -> return return DataType.INT64 Onnx.TensorProto.DataType.INT8 -> return DataType.INT8
Onnx.TensorProto.DataType.UINT32 -> return return DataType.INT32 Onnx.TensorProto.DataType.UINT8 -> return DataType.UINT8
Onnx.TensorProto.DataType.UINT16 -> return return DataType.INT16 Onnx.TensorProto.DataType.UINT64 -> return DataType.UINT64
Onnx.TensorProto.DataType.UINT32 -> return DataType.INT32
Onnx.TensorProto.DataType.UINT16 -> return DataType.INT16
Onnx.TensorProto.DataType.FLOAT16 -> return return DataType.FLOAT16 Onnx.TensorProto.DataType.FLOAT16 -> return return DataType.FLOAT16
Onnx.TensorProto.DataType.STRING -> return return DataType.UTF8 Onnx.TensorProto.DataType.STRING -> return return DataType.UTF8
Onnx.TensorProto.DataType.FLOAT -> return return DataType.FLOAT Onnx.TensorProto.DataType.FLOAT -> return return DataType.FLOAT
@ -81,8 +85,8 @@ class OnnxIRDataType(inputDataType: Onnx.TensorProto.DataType): IRDataType<Onnx.
override fun nameSpaceDataType(): TensorNamespace.DataType { override fun nameSpaceDataType(): TensorNamespace.DataType {
when(this.dataType) { when(this.dataType) {
Onnx.TensorProto.DataType.UINT64 -> return return TensorNamespace.DataType.INT64 Onnx.TensorProto.DataType.UINT64 -> return return TensorNamespace.DataType.INT64
Onnx.TensorProto.DataType.UINT32 -> return return TensorNamespace.DataType.INT32 Onnx.TensorProto.DataType.UINT32 -> return TensorNamespace.DataType.INT32
Onnx.TensorProto.DataType.UINT16 -> return return TensorNamespace.DataType.INT16 Onnx.TensorProto.DataType.UINT16 -> return TensorNamespace.DataType.INT16
Onnx.TensorProto.DataType.FLOAT16 -> return return TensorNamespace.DataType.FLOAT16 Onnx.TensorProto.DataType.FLOAT16 -> return return TensorNamespace.DataType.FLOAT16
Onnx.TensorProto.DataType.STRING -> return return TensorNamespace.DataType.STRING Onnx.TensorProto.DataType.STRING -> return return TensorNamespace.DataType.STRING
Onnx.TensorProto.DataType.FLOAT -> return TensorNamespace.DataType.FLOAT Onnx.TensorProto.DataType.FLOAT -> return TensorNamespace.DataType.FLOAT

View File

@ -1,5 +1,5 @@
Placeholder,data Const,in_0
Placeholder,partitions Const,Roll/shift
DynamicPartition,output Const,Roll/axis
Identity,out0 Identity,in_0/read
Identity,out1 Roll,Roll

View File

@ -1,7 +1,5 @@
Const,in_0 Const,in_0
Const,eye/ones Const,Roll/shift
Const,Roll/axis
Identity,in_0/read Identity,in_0/read
MatrixDiag,eye/diag Roll,Roll
Add,Add
Svd,Svd
Abs,Abs

View File

@ -1,3 +1,2 @@
DynamicPartition,output Identity,in_0/read
Identity,out0 Roll,Roll
Identity,out1

View File

@ -1,5 +1,2 @@
Identity,in_0/read Identity,in_0/read
MatrixDiag,eye/diag Roll,Roll
Add,Add
Svd,Svd
Abs,Abs

View File

@ -1,5 +1,5 @@
data in_0
partitions Roll/shift
output Roll/axis
out0 in_0/read
out1 Roll

View File

@ -1,7 +1,5 @@
in_0 in_0
eye/ones Roll/shift
Roll/axis
in_0/read in_0/read
eye/diag Roll
Add
Svd
Abs

View File

@ -358,7 +358,7 @@ val bitCast = TensorflowMappingProcess(
opMappingRegistry = tensorflowOpRegistry, opMappingRegistry = tensorflowOpRegistry,
inputFrameworkOpName = "Bitcast", inputFrameworkOpName = "Bitcast",
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input"))), tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input"))),
attributeMappingRules = listOf(dataTypeToInt(mutableMapOf("newType" to "type")), valueMapping(mutableMapOf("dataType" to "type"))) attributeMappingRules = listOf(dataTypeToInt(mutableMapOf("newType" to "type")), valueMapping(mutableMapOf("dtype" to "type")))
) )
val bitwiseAnd = TensorflowMappingProcess( val bitwiseAnd = TensorflowMappingProcess(
@ -1070,7 +1070,7 @@ val fill = TensorflowMappingProcess(
inputFrameworkOpName = "Fill", inputFrameworkOpName = "Fill",
opMappingRegistry = tensorflowOpRegistry, opMappingRegistry = tensorflowOpRegistry,
attributeMappingRules = listOf(convertNDArrayInputToNumericalAttr(mutableMapOf("value" to "value")), attributeMappingRules = listOf(convertNDArrayInputToNumericalAttr(mutableMapOf("value" to "value")),
dataTypeToInt(mutableMapOf("dtype" to "T")), dataTypeToInt(mutableMapOf("outputDataType" to "T")),
valueMapping(mutableMapOf("dtype" to "T"))), valueMapping(mutableMapOf("dtype" to "T"))),
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("shapeArray" to "dims"))) tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("shapeArray" to "dims")))
) )
@ -1381,6 +1381,7 @@ val maxPoolArgmax = multipleNameMapping(
intConstant(inputName = "extraParam0",constantValue = 0 ,argumentIndex = 9)[0], intConstant(inputName = "extraParam0",constantValue = 0 ,argumentIndex = 9)[0],
intConstant(inputName = "isNHWC",argumentIndex = 10,constantValue = 1 )[0], intConstant(inputName = "isNHWC",argumentIndex = 10,constantValue = 1 )[0],
intConstant(inputName = "sameMode",argumentIndex = 8,constantValue = 8 )[0], intConstant(inputName = "sameMode",argumentIndex = 8,constantValue = 8 )[0],
valueMapping(mutableMapOf("dtype" to "T"))
) )
,tensorflowOpRegistry = tensorflowOpRegistry ,tensorflowOpRegistry = tensorflowOpRegistry
) )
@ -1455,6 +1456,13 @@ val mirrorPadding = mapTensorNamesWithOp(inputFrameworkOpName = "MirrorPad",opNa
* v1 and 2 which do not have a scoreThreshold to map. V3 does. * v1 and 2 which do not have a scoreThreshold to map. V3 does.
*/ */
val matrixBandPart = mapTensorNamesWithOp(inputFrameworkOpName = "MatrixBandPart",opName = "matrix_band_part",
tensorNames = mutableMapOf("input" to "input","minLowerT" to "num_lower",
"maxUpperT" to "num_upper"),
attributeMappingRules = listOf()
,tensorflowOpRegistry = tensorflowOpRegistry)
val nonMaxSuppressionV1 = multipleNameMapping(inputFrameworkOpNames = listOf("NonMaxSuppression"), val nonMaxSuppressionV1 = multipleNameMapping(inputFrameworkOpNames = listOf("NonMaxSuppression"),
opName = "non_max_suppression", opName = "non_max_suppression",
tensorNames = mutableMapOf("boxes" to "boxes","scales" to "scores", tensorNames = mutableMapOf("boxes" to "boxes","scales" to "scores",
@ -1654,7 +1662,8 @@ val randomCrop = mapTensorNamesWithOp(inputFrameworkOpName = "RandomCrop",opName
attributeMappingRules = listOf(valueMapping(mutableMapOf("seed" to "seed"))) attributeMappingRules = listOf(valueMapping(mutableMapOf("seed" to "seed")))
,tensorflowOpRegistry = tensorflowOpRegistry) ,tensorflowOpRegistry = tensorflowOpRegistry)
val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",tensorNames = mutableMapOf(), val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",
tensorNames = mutableMapOf(),
attributeMappingRules = listOf() attributeMappingRules = listOf()
,tensorflowOpRegistry = tensorflowOpRegistry) ,tensorflowOpRegistry = tensorflowOpRegistry)
@ -1823,8 +1832,8 @@ val reverseSequence = multipleNameMapping(inputFrameworkOpNames = listOf("Revers
,tensorflowOpRegistry = tensorflowOpRegistry) ,tensorflowOpRegistry = tensorflowOpRegistry)
val roll = multipleNameMapping(inputFrameworkOpNames = listOf("Roll"),opName = "roll", val roll = multipleNameMapping(inputFrameworkOpNames = listOf("Roll"),opName = "roll",
attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("shift" to "shift"))), attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("shift" to "shift","dimensions" to "axis"))),
tensorNames = mutableMapOf("input" to "input","dimensions" to "axis","shiftsI" to "shift") tensorNames = mutableMapOf("input" to "input")
,tensorflowOpRegistry = tensorflowOpRegistry) ,tensorflowOpRegistry = tensorflowOpRegistry)
//TODO: verify usingLocking property, it's not showing up in descriptors //TODO: verify usingLocking property, it's not showing up in descriptors
@ -1941,6 +1950,7 @@ val size = TensorflowMappingProcess(
opMappingRegistry = tensorflowOpRegistry, opMappingRegistry = tensorflowOpRegistry,
inputFrameworkOpName = "Size", inputFrameworkOpName = "Size",
opName = "size", opName = "size",
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "out_type"))),
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input"))) tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input")))
) )

View File

@ -36,6 +36,10 @@ class TensorflowIRDataType(inputDataType: DataType): IRDataType<DataType> {
DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return IRDataTypeValue.DT_DOUBLE DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return IRDataTypeValue.DT_DOUBLE
DataType.DT_FLOAT, DataType.DT_FLOAT_REF -> return IRDataTypeValue.DT_FLOAT DataType.DT_FLOAT, DataType.DT_FLOAT_REF -> return IRDataTypeValue.DT_FLOAT
DataType.DT_HALF, DataType.DT_HALF_REF -> return IRDataTypeValue.DT_HALF DataType.DT_HALF, DataType.DT_HALF_REF -> return IRDataTypeValue.DT_HALF
DataType.DT_INT8, DataType.DT_INT8_REF -> return IRDataTypeValue.DT_INT8
DataType.DT_UINT8, DataType.DT_UINT8_REF -> return IRDataTypeValue.DT_UINT8
DataType.DT_UINT16, DataType.DT_UINT16_REF -> return IRDataTypeValue.DT_UINT16
DataType.DT_UINT32, DataType.DT_UINT32_REF -> return IRDataTypeValue.DT_UINT32
DataType.DT_INT16, DataType.DT_INT16_REF -> return IRDataTypeValue.DT_INT16 DataType.DT_INT16, DataType.DT_INT16_REF -> return IRDataTypeValue.DT_INT16
DataType.DT_INT32, DataType.DT_INT32_REF -> return IRDataTypeValue.DT_INT32 DataType.DT_INT32, DataType.DT_INT32_REF -> return IRDataTypeValue.DT_INT32
DataType.DT_INT64, DataType.DT_INT64_REF -> return IRDataTypeValue.DT_INT64 DataType.DT_INT64, DataType.DT_INT64_REF -> return IRDataTypeValue.DT_INT64
@ -70,9 +74,11 @@ class TensorflowIRDataType(inputDataType: DataType): IRDataType<DataType> {
DataType.DT_BFLOAT16, DataType.DT_BFLOAT16_REF -> return org.nd4j.linalg.api.buffer.DataType.BFLOAT16 DataType.DT_BFLOAT16, DataType.DT_BFLOAT16_REF -> return org.nd4j.linalg.api.buffer.DataType.BFLOAT16
DataType.DT_INT64, DataType.DT_INT64_REF -> return org.nd4j.linalg.api.buffer.DataType.INT64 DataType.DT_INT64, DataType.DT_INT64_REF -> return org.nd4j.linalg.api.buffer.DataType.INT64
DataType.DT_HALF, DataType.DT_HALF_REF -> return org.nd4j.linalg.api.buffer.DataType.FLOAT16 DataType.DT_HALF, DataType.DT_HALF_REF -> return org.nd4j.linalg.api.buffer.DataType.FLOAT16
DataType.DT_INT8, DataType.DT_INT8_REF -> return org.nd4j.linalg.api.buffer.DataType.INT8
DataType.DT_INT16, DataType.DT_INT16_REF -> return org.nd4j.linalg.api.buffer.DataType.INT16 DataType.DT_INT16, DataType.DT_INT16_REF -> return org.nd4j.linalg.api.buffer.DataType.INT16
DataType.DT_INT32, DataType.DT_INT32_REF -> return org.nd4j.linalg.api.buffer.DataType.INT32 DataType.DT_INT32, DataType.DT_INT32_REF -> return org.nd4j.linalg.api.buffer.DataType.INT32
DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return org.nd4j.linalg.api.buffer.DataType.DOUBLE DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return org.nd4j.linalg.api.buffer.DataType.DOUBLE
DataType.DT_UINT8, DataType.DT_UINT8_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT8
DataType.DT_UINT16, DataType.DT_UINT16_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT16 DataType.DT_UINT16, DataType.DT_UINT16_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT16
DataType.DT_UINT32, DataType.DT_UINT32_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT32 DataType.DT_UINT32, DataType.DT_UINT32_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT32
DataType.DT_UINT64, DataType.DT_UINT64_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT64 DataType.DT_UINT64, DataType.DT_UINT64_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT64

View File

@ -134,7 +134,6 @@ class TensorflowIRGraph(graphDef: GraphDef, opDef: OpList
val node = nodeByName(varName) val node = nodeByName(varName)
val attrMap = node.attrMap val attrMap = node.attrMap
if(!attrMap.containsKey("dtype")) { if(!attrMap.containsKey("dtype")) {
val retSet = attrMap.values.filter { attrValue -> attrValue.type != DataType.DT_INVALID } val retSet = attrMap.values.filter { attrValue -> attrValue.type != DataType.DT_INVALID }
if(retSet.isEmpty()) { if(retSet.isEmpty()) {
return TensorflowIRDataType(DataType.DT_INVALID) return TensorflowIRDataType(DataType.DT_INVALID)

View File

@ -1172,6 +1172,18 @@ mappings {
ruleType: "tensor" ruleType: "tensor"
inputFrameworkOpName: "Size" inputFrameworkOpName: "Size"
} }
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "out_type"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "out_type"
}
ruleType: "attribute"
inputFrameworkOpName: "Size"
}
} }
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
@ -2615,6 +2627,35 @@ mappings {
inputFrameworkOpName: "Polygamma" inputFrameworkOpName: "Polygamma"
} }
} }
mappings {
frameworkName: "tensorflow"
opName: "matrix_band_part"
inputFrameworkOpName: "MatrixBandPart"
rule {
ruleName: "ndarraymapping"
functionName: "ndarraymapping"
inputTensorName: "input"
inputTensorName: "num_lower"
inputTensorName: "num_upper"
outputTensorName: "input"
outputTensorName: "minLowerT"
outputTensorName: "maxUpperT"
inputToOutput {
key: "input"
value: "input"
}
inputToOutput {
key: "minLowerT"
value: "num_lower"
}
inputToOutput {
key: "maxUpperT"
value: "num_upper"
}
ruleType: "tensor"
inputFrameworkOpName: "MatrixBandPart"
}
}
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
opName: "equals" opName: "equals"
@ -6427,8 +6468,9 @@ mappings {
ruleName: "valuemapping" ruleName: "valuemapping"
functionName: "valuemapping" functionName: "valuemapping"
inputDataTypeName: "type" inputDataTypeName: "type"
outputDataTypeName: "dtype"
inputToOutput { inputToOutput {
key: "dataType" key: "dtype"
value: "type" value: "type"
} }
ruleType: "attribute" ruleType: "attribute"
@ -6834,23 +6876,11 @@ mappings {
ruleName: "ndarraymapping" ruleName: "ndarraymapping"
functionName: "ndarraymapping" functionName: "ndarraymapping"
inputTensorName: "input" inputTensorName: "input"
inputTensorName: "axis"
inputTensorName: "shift"
outputTensorName: "input" outputTensorName: "input"
outputTensorName: "dimensions"
outputTensorName: "shiftsI"
inputToOutput { inputToOutput {
key: "input" key: "input"
value: "input" value: "input"
} }
inputToOutput {
key: "dimensions"
value: "axis"
}
inputToOutput {
key: "shiftsI"
value: "shift"
}
ruleType: "tensor" ruleType: "tensor"
inputFrameworkOpName: "Roll" inputFrameworkOpName: "Roll"
} }
@ -6862,6 +6892,10 @@ mappings {
key: "shift" key: "shift"
value: "shift" value: "shift"
} }
inputToOutput {
key: "dimensions"
value: "axis"
}
ruleType: "attribute" ruleType: "attribute"
inputFrameworkOpName: "Roll" inputFrameworkOpName: "Roll"
} }
@ -7251,6 +7285,7 @@ mappings {
rule { rule {
ruleName: "ndarrayinputtonumericalattribute" ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "start"
outputDoubleName: "stop" outputDoubleName: "stop"
inputToOutput { inputToOutput {
key: "start" key: "start"
@ -9237,6 +9272,8 @@ mappings {
rule { rule {
ruleName: "ndarrayinputtonumericalattribute" ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "on"
outputDoubleName: "off"
inputToOutput { inputToOutput {
key: "on" key: "on"
value: "on_value" value: "on_value"
@ -10592,6 +10629,8 @@ mappings {
rule { rule {
ruleName: "ndarrayinputtonumericalattribute" ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "min"
outputDoubleName: "max"
inputToOutput { inputToOutput {
key: "min" key: "min"
value: "minval" value: "minval"
@ -12082,10 +12121,10 @@ mappings {
rule { rule {
ruleName: "datatypetoint" ruleName: "datatypetoint"
functionName: "datatypetoint" functionName: "datatypetoint"
outputIntName: "outputDataType"
inputDataTypeName: "T" inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput { inputToOutput {
key: "dtype" key: "outputDataType"
value: "T" value: "T"
} }
ruleType: "attribute" ruleType: "attribute"
@ -12352,6 +12391,18 @@ mappings {
} }
inputFrameworkOpName: "MaxPoolWithArgmax" inputFrameworkOpName: "MaxPoolWithArgmax"
} }
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "MaxPoolWithArgmax"
}
} }
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
@ -13428,6 +13479,9 @@ mappings {
rule { rule {
ruleName: "ndarrayinputtonumericalattribute" ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "from"
outputDoubleName: "to"
outputDoubleName: "step"
inputToOutput { inputToOutput {
key: "from" key: "from"
value: "start" value: "start"
@ -14778,6 +14832,7 @@ mappings {
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputIntName: "maxOutputSize" outputIntName: "maxOutputSize"
outputDoubleName: "overlapThreshold" outputDoubleName: "overlapThreshold"
outputDoubleName: "scoreThreshold"
inputToOutput { inputToOutput {
key: "maxOutputSize" key: "maxOutputSize"
value: "max_output_size" value: "max_output_size"
@ -15157,7 +15212,7 @@ mappings {
functionName: "stringequals" functionName: "stringequals"
inputStringAttrName: "padding" inputStringAttrName: "padding"
inputStringAttrName: "padding" inputStringAttrName: "padding"
outputBooleanName: "isSameMode" outputIntName: "isSameMode"
inputToOutput { inputToOutput {
key: "isSameMode" key: "isSameMode"
value: "padding" value: "padding"

View File

@ -32,6 +32,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper
import org.nd4j.ir.OpNamespace import org.nd4j.ir.OpNamespace
import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.DynamicCustomOp import org.nd4j.linalg.api.ops.DynamicCustomOp
import org.nd4j.linalg.api.ops.custom.Roll
import org.nd4j.linalg.api.ops.impl.transforms.BinCount import org.nd4j.linalg.api.ops.impl.transforms.BinCount
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
@ -46,6 +47,7 @@ import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFramework
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRTensor
import org.nd4j.shade.protobuf.ByteString import org.nd4j.shade.protobuf.ByteString
import org.nd4j.shade.protobuf.TextFormat import org.nd4j.shade.protobuf.TextFormat
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner
@ -96,6 +98,14 @@ class TestTensorflowIR {
val output = graph.outputAll(inputMap) val output = graph.outputAll(inputMap)
val output2 = importedGraph.outputAll(inputMap) val output2 = importedGraph.outputAll(inputMap)
val matrix =
TensorflowIRTensor(tensorflowIRGraph.nodeByName("in_0").attrMap["value"]!!.tensor).toNd4jNDArray()
val roll2 = Roll(matrix, Nd4j.scalar(2), Nd4j.scalar(1))
val outputs = Nd4j.exec(roll2)[0]
val tfOutputRoll = tfOutput["Roll"]
val nd4jOutput = output["Roll"]
//assertEquals(tfOutput.keys,outputList) //assertEquals(tfOutput.keys,outputList)
//assertEquals(tfOutput.keys,output2.keys) //assertEquals(tfOutput.keys,output2.keys)
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() } val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }

View File

@ -1172,6 +1172,18 @@ mappings {
ruleType: "tensor" ruleType: "tensor"
inputFrameworkOpName: "Size" inputFrameworkOpName: "Size"
} }
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "out_type"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "out_type"
}
ruleType: "attribute"
inputFrameworkOpName: "Size"
}
} }
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
@ -2615,6 +2627,35 @@ mappings {
inputFrameworkOpName: "Polygamma" inputFrameworkOpName: "Polygamma"
} }
} }
mappings {
frameworkName: "tensorflow"
opName: "matrix_band_part"
inputFrameworkOpName: "MatrixBandPart"
rule {
ruleName: "ndarraymapping"
functionName: "ndarraymapping"
inputTensorName: "input"
inputTensorName: "num_lower"
inputTensorName: "num_upper"
outputTensorName: "input"
outputTensorName: "minLowerT"
outputTensorName: "maxUpperT"
inputToOutput {
key: "input"
value: "input"
}
inputToOutput {
key: "minLowerT"
value: "num_lower"
}
inputToOutput {
key: "maxUpperT"
value: "num_upper"
}
ruleType: "tensor"
inputFrameworkOpName: "MatrixBandPart"
}
}
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
opName: "equals" opName: "equals"
@ -6427,8 +6468,9 @@ mappings {
ruleName: "valuemapping" ruleName: "valuemapping"
functionName: "valuemapping" functionName: "valuemapping"
inputDataTypeName: "type" inputDataTypeName: "type"
outputDataTypeName: "dtype"
inputToOutput { inputToOutput {
key: "dataType" key: "dtype"
value: "type" value: "type"
} }
ruleType: "attribute" ruleType: "attribute"
@ -6834,23 +6876,11 @@ mappings {
ruleName: "ndarraymapping" ruleName: "ndarraymapping"
functionName: "ndarraymapping" functionName: "ndarraymapping"
inputTensorName: "input" inputTensorName: "input"
inputTensorName: "axis"
inputTensorName: "shift"
outputTensorName: "input" outputTensorName: "input"
outputTensorName: "dimensions"
outputTensorName: "shiftsI"
inputToOutput { inputToOutput {
key: "input" key: "input"
value: "input" value: "input"
} }
inputToOutput {
key: "dimensions"
value: "axis"
}
inputToOutput {
key: "shiftsI"
value: "shift"
}
ruleType: "tensor" ruleType: "tensor"
inputFrameworkOpName: "Roll" inputFrameworkOpName: "Roll"
} }
@ -6862,6 +6892,10 @@ mappings {
key: "shift" key: "shift"
value: "shift" value: "shift"
} }
inputToOutput {
key: "dimensions"
value: "axis"
}
ruleType: "attribute" ruleType: "attribute"
inputFrameworkOpName: "Roll" inputFrameworkOpName: "Roll"
} }
@ -7251,6 +7285,7 @@ mappings {
rule { rule {
ruleName: "ndarrayinputtonumericalattribute" ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "start"
outputDoubleName: "stop" outputDoubleName: "stop"
inputToOutput { inputToOutput {
key: "start" key: "start"
@ -9237,6 +9272,8 @@ mappings {
rule { rule {
ruleName: "ndarrayinputtonumericalattribute" ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "on"
outputDoubleName: "off"
inputToOutput { inputToOutput {
key: "on" key: "on"
value: "on_value" value: "on_value"
@ -10592,6 +10629,8 @@ mappings {
rule { rule {
ruleName: "ndarrayinputtonumericalattribute" ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "min"
outputDoubleName: "max"
inputToOutput { inputToOutput {
key: "min" key: "min"
value: "minval" value: "minval"
@ -12082,10 +12121,10 @@ mappings {
rule { rule {
ruleName: "datatypetoint" ruleName: "datatypetoint"
functionName: "datatypetoint" functionName: "datatypetoint"
outputIntName: "outputDataType"
inputDataTypeName: "T" inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput { inputToOutput {
key: "dtype" key: "outputDataType"
value: "T" value: "T"
} }
ruleType: "attribute" ruleType: "attribute"
@ -12352,6 +12391,18 @@ mappings {
} }
inputFrameworkOpName: "MaxPoolWithArgmax" inputFrameworkOpName: "MaxPoolWithArgmax"
} }
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "MaxPoolWithArgmax"
}
} }
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
@ -13428,6 +13479,9 @@ mappings {
rule { rule {
ruleName: "ndarrayinputtonumericalattribute" ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "from"
outputDoubleName: "to"
outputDoubleName: "step"
inputToOutput { inputToOutput {
key: "from" key: "from"
value: "start" value: "start"
@ -14778,6 +14832,7 @@ mappings {
functionName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute"
outputIntName: "maxOutputSize" outputIntName: "maxOutputSize"
outputDoubleName: "overlapThreshold" outputDoubleName: "overlapThreshold"
outputDoubleName: "scoreThreshold"
inputToOutput { inputToOutput {
key: "maxOutputSize" key: "maxOutputSize"
value: "max_output_size" value: "max_output_size"
@ -15157,7 +15212,7 @@ mappings {
functionName: "stringequals" functionName: "stringequals"
inputStringAttrName: "padding" inputStringAttrName: "padding"
inputStringAttrName: "padding" inputStringAttrName: "padding"
outputBooleanName: "isSameMode" outputIntName: "isSameMode"
inputToOutput { inputToOutput {
key: "isSameMode" key: "isSameMode"
value: "padding" value: "padding"

View File

@ -5,26 +5,26 @@ node {
key: "value" key: "value"
value { value {
tensor { tensor {
dtype: DT_FLOAT dtype: DT_DOUBLE
tensor_shape { tensor_shape {
dim { dim {
size: 2 size: 4
} }
dim { dim {
size: 3 size: 5
} }
dim { dim {
size: 3 size: 4
} }
} }
tensor_content: "~^G?L\033M?\236p9?\220ol>\356%:?X\2708><q\001?b|d?\224\316\v?\314al?P@\257=,5K?\326\271(?\3566\016?`u#>0\024\236>\240{\036>\240h\360>" tensor_content: "0m4\347\376y\315?\344\033;\004\236p\351?\026..t\357%\352? \306G>\322\023\247?\230\303\330E)\235\327?,5\313k\v\350\345?\334m\034\311\255s\321?0\024\236\222\260\272\321?@\321\340\331\240{\316?v|\234\234\223o\356?\244&\\\214\314W\332?\270\376rL\357\224\331?\030\216\353R\253\023\332?*A<\365%\024\344?h\f\326CH\250\351?0\320Y{\256\\\261?\366\226\304\200\016\350\343?@cr\031B\026\323?\300Q\006<\213\200\303?\f\377 \310\035\271\320?\300\325\277\177]\302\347?\270\337\237\302\321P\343?\256L\224~H]\351?0\243.\031\266\256\342?\200B\250\257\316j\301?\304\372\266\312\352\225\332?D\320\255\371\vB\327?\364\025&\270p\360\327?H:\241v\n\272\312?0-} \201!\323?l2\v\247;|\331?\320r}\'\v\372\341?\006u\364aS@\347?P6<\211\270\337\273?\024\340\006`\267\262\322?t\222JV\310\'\325?\264/\177\232\"]\322?\242\037L\006\215=\345?\270\376\302\274\332\310\323?R\2100\375\212\314\355?\300;\026W\321\210\230?\260?t#\314\203\322?\366=D:{\005\342?p_\v\2335\315\356?\344U\370\034\372\317\332? l1N\344\252\330?\354\327\003\223\206\215\321?^C\326M\353\234\345?\326\333u\025\2449\354?\264\000\334Z\177\326\353?\244\341\253\326!\272\345?\320\336\312`\255\005\311?\244u\220o\266\033\324?V\b\201\267\276\271\351?$\253\324_o3\356?\264:\260\357i\003\335?\300[\'7L8\256?02UU\205\214\265?\240\255\276\263r+\257? \303\207\022\3446\334?\220\032\360l\364(\324?\030\374\036.\217W\355?\340!;ay\034\257?\312\255)\371\227\333\346?F\233`e\300\335\343?\264>\261\354\324\345\331?\212v\025\026\265o\342?|\036\036F\364}\330?\034\203\363\362\364\204\322?\000:\260e\"\372\230?F\316\345\330\2577\343?\356uN6<b\355?6\216\263\234\243N\351?\306b\204\305\260\200\342?\210\213\262X/J\331?(\352\313\203=\217\312?\310\3327\231\362\a\345?\210\234<\035}\241\313?\224\'0\201\363V\344?\240\260-3\3655\345?"
} }
} }
} }
attr { attr {
key: "dtype" key: "dtype"
value { value {
type: DT_FLOAT type: DT_DOUBLE
} }
} }
} }
@ -32,6 +32,12 @@ node {
name: "in_0/read" name: "in_0/read"
op: "Identity" op: "Identity"
input: "in_0" input: "in_0"
attr {
key: "T"
value {
type: DT_DOUBLE
}
}
attr { attr {
key: "_class" key: "_class"
value { value {
@ -40,94 +46,71 @@ node {
} }
} }
} }
}
node {
name: "Roll/shift"
op: "Const"
attr { attr {
key: "T" key: "dtype"
value { value {
type: DT_FLOAT type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 2
}
} }
} }
} }
node { node {
name: "eye/ones" name: "Roll/axis"
op: "Const" op: "Const"
attr { attr {
key: "value" key: "value"
value { value {
tensor { tensor {
dtype: DT_FLOAT dtype: DT_INT32
tensor_shape { tensor_shape {
dim {
size: 2
} }
dim { int_val: 1
size: 3
}
}
float_val: 1.0
} }
} }
} }
attr { attr {
key: "dtype" key: "dtype"
value { value {
type: DT_FLOAT type: DT_INT32
} }
} }
} }
node { node {
name: "eye/diag" name: "Roll"
op: "MatrixDiag" op: "Roll"
input: "eye/ones"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "Add"
op: "Add"
input: "in_0/read" input: "in_0/read"
input: "eye/diag" input: "Roll/shift"
input: "Roll/axis"
attr { attr {
key: "T" key: "Taxis"
value { value {
type: DT_FLOAT type: DT_INT32
}
}
}
node {
name: "Svd"
op: "Svd"
input: "Add"
attr {
key: "full_matrices"
value {
b: true
} }
} }
attr { attr {
key: "compute_uv" key: "Tshift"
value { value {
b: false type: DT_INT32
} }
} }
attr { attr {
key: "T" key: "T"
value { value {
type: DT_FLOAT type: DT_DOUBLE
}
}
}
node {
name: "Abs"
op: "Abs"
input: "Svd"
attr {
key: "T"
value {
type: DT_FLOAT
} }
} }
} }

View File

@ -1,4 +1,2 @@
output,output in_0/read,in_0/read
output:1,output Roll,Roll
out0,out0
out1,out1

View File

@ -1,5 +1,2 @@
in_0/read,in_0/read in_0/read,in_0/read
eye/diag,eye/diag Roll,Roll
Add,Add
Svd,Svd
Abs,Abs