Fix data type and roll

master
agibsonccc 2021-02-07 19:27:41 +09:00
parent 04209693f5
commit 53bfdb9994
28 changed files with 3403 additions and 2919 deletions

View File

@ -51,9 +51,10 @@ namespace ops {
if (shift < 0) {
shift -= input->sizeAt(i) * (shift / inputLen - 1);
}
else {
else if(shift != 0) {
shift %= input->sizeAt(i);
}
shifts[i] = shift;
}
@ -64,7 +65,7 @@ namespace ops {
// convert shift to positive value between 1 and inputLen - 1
shift -= inputLen * (shift / inputLen - 1);
}
else
else if(shift != 0)
// cut shift to value between 1 and inputLen - 1
shift %= inputLen;
axes.resize(block.getIArguments()->size() - 1);
@ -87,6 +88,21 @@ namespace ops {
if (block.isInplace()) output = input;
shiftIsLinear = (axes.size() == 0) || (input->rankOf() == 1);
nd4j_debug("Roll: Shift is linear %d Shift is %d, first dimension is %d\n",shiftIsLinear,shifts[0],axes[0]);
bool shiftsSumZero = false;
auto shiftSum = 0;
for (auto& s: shifts) {
shiftSum += s;
nd4j_debug("Roll: Shift is %d\n",s);
}
//all zeros is no op
if(shiftSum < 1) {
nd4j_debug("Roll: No shift needed. Shift total was %d\n",shiftSum);
if(!block.isInplace()) {
output->assign(input);
}
return Status::OK();
}
if (shiftIsLinear) {
helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace());

View File

@ -73,7 +73,7 @@ namespace helpers {
}
}
// stage 3) swap remainer of items.
// stage 3) swap remainder of items.
if (remainShift && shiftCount)
for (int i = actualShift; i < 2 * actualShift; ++i) {
auto _e0 = output->e<T>(i);
@ -95,10 +95,11 @@ namespace helpers {
auto source = output; //input;
for (size_t i = 0; i < axes.size(); i++) {
int axe = axes[i];
if (axe == source->rankOf() - 1) {// last dimension
// if (axe == source->rankOf() - 1) {// last dimension
ResultSet listOfTensors = source->allTensorsAlongDimension({axe});
ResultSet listOfOutTensors = output->allTensorsAlongDimension({axe});
int fullLen = listOfTensors.size();
nd4j_debug("Roll: fullLen at last dimension is %d\n",fullLen);
int theShift = shifts[i];
if (theShift > 0) {
theShift %= fullLen;
@ -109,7 +110,7 @@ namespace helpers {
for (int k = 0; k < fullLen; k++) {
rollFunctorLinear(context, listOfTensors.at(k), listOfOutTensors.at(k), theShift, true);
}
}
/* }
else {
std::vector<int> dims(source->rankOf() - axe - 1);
for (size_t i = 0; i < dims.size(); ++i)
@ -120,6 +121,7 @@ namespace helpers {
//
int fullLen = listOfTensors.size();
int sizeAt = input->sizeAt(axe);
nd4j_debug("Roll: fullLen at dimension %d is %d\n",i,fullLen);
int theShift = shifts[i];
@ -131,14 +133,14 @@ namespace helpers {
}
if (theShift) {
for (int dim = 0; dim < fullLen / sizeAt; ++dim) {
for (int e = theShift; e < sizeAt - theShift; ++e) {
for (size_t dim = 0; dim < fullLen / sizeAt; ++dim) {
for (size_t e = theShift; e < sizeAt - theShift; ++e) {
auto sourceM = listOfTensors.at(dim * sizeAt + e - theShift);
auto targetM = listOfOutTensors.at(dim * sizeAt + e);
sourceM->swapUnsafe(*targetM);
}
for (int e = 0; e < theShift; ++e) {
for (size_t e = 0; e < theShift; ++e) {
int sourceIndex = dim * sizeAt + sizeAt - theShift + e;
auto sourceM = listOfTensors.at(sourceIndex);
auto targetM = listOfOutTensors.at(dim * sizeAt + e);
@ -149,7 +151,7 @@ namespace helpers {
}
}
// if (!inplace)
// source = output;
// source = output;*/
}
}

View File

@ -34,10 +34,10 @@ public class Roll extends DynamicCustomOp {
public Roll() {}
public Roll(@NonNull INDArray input, @NonNull INDArray axes, @NonNull INDArray shifts) {
public Roll(@NonNull INDArray input, @NonNull INDArray shifts, @NonNull INDArray axes) {
Preconditions.checkArgument(axes.rank() == shifts.rank(), "Roll: shifts and axes should be the same rank");
Preconditions.checkArgument(axes.length() == shifts.length(), "Roll: shifts and axes should be the same length");
addInputArgument(input, axes, shifts);
addInputArgument(input, shifts, axes);
}
public Roll(@NonNull INDArray input, int shift) {
@ -49,8 +49,8 @@ public class Roll extends DynamicCustomOp {
super("", sameDiff, new SDVariable[]{input,shift});
}
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable axes, @NonNull SDVariable shift) {
super("", sameDiff, new SDVariable[]{input,axes,shift});
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable shift, @NonNull SDVariable axes) {
super("", sameDiff, new SDVariable[]{input,shift,axes});
}
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) {

View File

@ -284,7 +284,10 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
List<DataType> result = new ArrayList<>();
result.add(inputDataTypes.get(0));
result.add(outputType == null ? DataType.INT : outputType);
if(dArguments.isEmpty())
result.add(outputType == null ? DataType.INT : outputType);
else
result.add(dArguments.get(0));
return result;
}
}

View File

@ -20,8 +20,8 @@
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
@ -36,6 +36,17 @@
<name>nd4j-tests</name>
<properties>
<kotlin.version>1.4.30-M1</kotlin.version>
<kotlin.compiler.jvmTarget>1.8</kotlin.compiler.jvmTarget>
<kotlin.compiler.incremental>true</kotlin.compiler.incremental>
<junit.version>4.12</junit.version>
<junit-jupiter.version>5.4.2</junit-jupiter.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<java.version>1.8</java.version>
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<scala.binary.version>2.11</scala.binary.version>
@ -43,7 +54,213 @@
<maven.compiler.testSource>1.8</maven.compiler.testSource>
</properties>
<build>
<plugins>
<plugin>
<groupId>org.projectlombok</groupId>
<artifactId>lombok-maven-plugin</artifactId>
<version>1.18.12.0</version>
<executions>
<execution>
<id>delombok</id>
<phase>generate-sources</phase>
<goals>
<goal>delombok</goal>
</goals>
<configuration>
<formatPreferences>
<javaLangAsFQN>skip</javaLangAsFQN>
</formatPreferences>
<verbose>true</verbose>
</configuration>
</execution>
<execution>
<id>test-delombok</id>
<phase>generate-test-sources</phase>
<goals>
<goal>testDelombok</goal>
</goals>
<configuration>
<verbose>true</verbose>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<version>3.0.0</version>
<executions>
<execution>
<id>add-source</id>
<phase>generate-sources</phase>
<goals><goal>add-source</goal></goals>
<configuration>
<sources>
<source>src/main/stubs</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>${maven-shade-plugin.version}</version>
<configuration>
<shadedArtifactAttached>true</shadedArtifactAttached>
<createDependencyReducedPom>false</createDependencyReducedPom>
<filters>
<filter>
<artifact>*:*</artifact>
<excludes>
<exclude>org/datanucleus/**</exclude>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</filter>
</filters>
</configuration>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer" />
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-plugin</artifactId>
<version>1.4.30-M1</version>
<configuration>
<args>
<arg>-Xjsr305=strict</arg>
</args>
<compilerPlugins>
<plugin>spring</plugin>
<plugin>jpa</plugin>
</compilerPlugins>
</configuration>
<dependencies>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-allopen</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-noarg</artifactId>
<version>${kotlin.version}</version>
</dependency>
</dependencies>
<executions>
<execution>
<id>compile</id>
<goals> <goal>compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/main/java</sourceDir>
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
<execution>
<id>test-compile</id>
<goals> <goal>test-compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/test/java</sourceDir>
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
</executions>
</plugin>
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<executions>
<!-- Replacing default-compile as it is treated specially by maven -->
<execution>
<id>default-compile</id>
<phase>none</phase>
</execution>
<!-- Replacing default-testCompile as it is treated specially by maven -->
<execution>
<id>default-testCompile</id>
<phase>none</phase>
</execution>
<execution>
<id>java-compile</id>
<phase>compile</phase>
<goals> <goal>compile</goal> </goals>
</execution>
<execution>
<id>java-test-compile</id>
<phase>test-compile</phase>
<goals> <goal>testCompile</goal> </goals>
</execution>
</executions>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins>
</build>
<dependencies>
<!-- Test Dependencies -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>${junit-jupiter.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${junit-jupiter.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib-jdk8</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-test</artifactId>
<version>${kotlin.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>samediff-import-tensorflow</artifactId>

View File

@ -68,13 +68,22 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
* all tests will trigger an assumeFalse(..) that indicates
* the status of the test failing. No tests will run.
*/
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList();
public final static List<String> EXECUTE_ONLY_MODELS = Arrays.asList(
// "max_pool_with_argmax/int32_int64_padding_SAME",
// "fused_batch_norm/float32_nhwc",
// "max_pool_with_argmax/int64_int64_padding_SAME",
// "fused_batch_norm/float16_nhwc",
"roll/rank3_int32_axis",
"roll/rank3_int32_axis",
"roll/rank2_float32_zeroshift",
"roll/rank3_float64_axis"
);
public static final String[] IGNORE_REGEXES = new String[]{
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
// Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance
"bincount/.*",
//Invalid test cases. Verified by running graph against actual TF.
"slogdet/.*",
//TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. Need to check implementations too
// Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod

View File

@ -881,7 +881,7 @@ public class CustomOpsTests extends BaseNd4jTest {
BitCast op = new BitCast(Nd4j.zeros(1,10), DataType.FLOAT.toInt(), out);
List<LongShapeDescriptor> lsd = op.calculateOutputShape();
assertEquals(1, lsd.size());
assertArrayEquals(new long[]{1,10}, lsd.get(0).getShape());
assertArrayEquals(new long[]{1,10,2}, lsd.get(0).getShape());
}
@Test
@ -942,21 +942,21 @@ public class CustomOpsTests extends BaseNd4jTest {
}
@Test
@Ignore("Failing with results that are close")
public void testFakeQuantAgainstTF_1() {
INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
INDArray x = Nd4j.createFromArray(new double[]{ 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
INDArray min = Nd4j.createFromArray(new float[]{ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
INDArray max = Nd4j.createFromArray(new float[]{ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
INDArray min = Nd4j.createFromArray(new double[]{ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
INDArray max = Nd4j.createFromArray(new double[]{ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
INDArray expected = Nd4j.createFromArray(new double[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
INDArray out = Nd4j.createUninitialized(x.shape());
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max);
Nd4j.exec(op);
assertEquals(expected, out);
INDArray[] output = Nd4j.exec(op);
assertEquals(expected, output[0]);
}
@Test
@ -971,8 +971,7 @@ public class CustomOpsTests extends BaseNd4jTest {
@Test
public void testResizeBilinear1() {
INDArray x = Nd4j.rand(1, 2,3,4);
INDArray x = Nd4j.rand(1, 10,10,4);
INDArray z = Nd4j.createUninitialized(x.shape());
boolean align = false;
val op = new ResizeBilinear(x, z, 10, 10, align, false);
@ -1082,24 +1081,8 @@ public class CustomOpsTests extends BaseNd4jTest {
INDArray distance = Nd4j.scalar(0.f);
Nd4j.exec(new KnnMinDistance(point, lowest, highest, distance));
// System.out.println(distance);
}
@Ignore("2019/11/15 AS - https://github.com/eclipse/deeplearning4j/issues/8399")
@Test
public void testCropAndResize() {
INDArray image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1);
INDArray boxes = Nd4j.createFromArray(new float[]{1,2,3,4}).reshape(1,4);
INDArray box_indices = Nd4j.createFromArray(new int[]{1});
INDArray crop_size = Nd4j.createFromArray(new int[]{1,2}).reshape(1,2);
//Output shape mismatch - TF [2, 2, 1, 1] vs SD: [1, 2, 1, 1]
INDArray output = Nd4j.create(DataType.FLOAT, 2,2,1,1);
Nd4j.exec(new CropAndResize(image, boxes, box_indices, crop_size, CropAndResize.Method.BILINEAR, 0.5,
output));
}
@Test
public void testLayersDropoutFail() {
@ -1338,7 +1321,6 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(expected, ret[0]);
}
@Ignore("Failure AS 11.28.2019 - https://github.com/eclipse/deeplearning4j/issues/8453")
@Test
public void testRoll1() {
INDArray a = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f, 0.2309f});
@ -1346,6 +1328,10 @@ public class CustomOpsTests extends BaseNd4jTest {
INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.7244f, 0.2309f, 0.7788f, 0.8012f});
assertEquals(expected, ret[0]);
INDArray matrix = Nd4j.create(new double[]{0.7788,0.8012,0.7244,0.2309,0.7271,0.1804,0.5056,0.8925}).reshape(2,4);
Roll roll2 = new Roll(matrix,Nd4j.scalar(0),Nd4j.scalar(1));
INDArray[] outputs = Nd4j.exec(roll2);
System.out.println(outputs[0]);
}
@Test

View File

@ -0,0 +1,118 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.custom
import junit.framework.Assert.assertEquals
import org.junit.Ignore
import org.junit.Test
import org.nd4j.linalg.api.buffer.DataType
import org.nd4j.linalg.api.ops.impl.image.CropAndResize
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.samediff.frameworkimport.tensorflow.*
import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFrameworkImporter
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
class CustomOpTensorflowInteropTests {
@Test
@Ignore("Tensorflow expects different shape")
fun testCropAndResize() {
val image = Nd4j.createUninitialized(DataType.FLOAT, 1, 2, 2, 1)
val boxes = Nd4j.createFromArray(*floatArrayOf(1f, 2f, 3f, 4f)).reshape(1, 4)
val box_indices = Nd4j.createFromArray(*intArrayOf(0))
val crop_size = Nd4j.createFromArray(*intArrayOf(1, 2)).reshape( 2)
val imageNode = NodeDef {
op = "Placeholder"
name = "image"
Attribute("dtype", AttrValue {
type = org.tensorflow.framework.DataType.DT_FLOAT
})
}
val boxesNode = NodeDef {
op = "Placeholder"
name = "boxes"
Attribute("dtype", AttrValue {
type = org.tensorflow.framework.DataType.DT_FLOAT
})
}
val boxIndicesNode = NodeDef {
op = "Placeholder"
name = "boxIndices"
Attribute("dtype", AttrValue {
type = org.tensorflow.framework.DataType.DT_INT32
})
}
val cropSizesNode = NodeDef {
op = "Placeholder"
name = "cropSize"
Attribute("dtype", AttrValue {
type = org.tensorflow.framework.DataType.DT_INT32
})
}
val opNode = NodeDef {
op = "CropAndResize"
name = "output"
Input("image")
Input("boxes")
Input("boxIndices")
Input("cropSize")
Attribute("extrapolation_value", AttrValue {
f = 0.5f
})
Attribute("T", AttrValue {
type = org.tensorflow.framework.DataType.DT_FLOAT
})
}
val graph = GraphDef {
Node(imageNode)
Node(boxesNode)
Node(boxIndicesNode)
Node(cropSizesNode)
Node(opNode)
}
val importer = TensorflowFrameworkImporter()
val irGraph = TensorflowIRGraph(graph,importer.opDefList,importer.registry)
val runner = TensorflowIRGraphRunner(irGraph,listOf("image","boxes","boxIndices","cropSize"),listOf("output"))
val tfResult = runner.run(mapOf("image" to image,"boxes" to boxes,"boxIndices" to box_indices,"cropSize" to crop_size))
val outputArr = tfResult["output"]
//Output shape mismatch - TF [2, 2, 1, 1] vs SD: [1, 2, 1, 1]
val output = Nd4j.create(DataType.FLOAT, 2, 2, 1, 1)
Nd4j.exec(
CropAndResize(
image, boxes, box_indices, crop_size, CropAndResize.Method.BILINEAR, 0.5,
output
)
)
assertEquals(outputArr,output)
}
}

View File

@ -1,5 +1,2 @@
Variable/read,Variable/read
Variable_1/read,Variable_1/read
floordiv/x,floordiv/x
floordiv/y,floordiv/y
floordiv,floordiv
in_0/read,in_0/read
Roll,Roll

View File

@ -1,18 +1 @@
in_0/read,in_0/read
while/Enter,while/Enter
while/Enter_1,while/Enter_1
while/Merge,while/Merge
while/Merge_1,while/Merge_1
while/Less,while/Less
while/LoopCond,while/LoopCond
while/Switch,while/Switch
while/Switch:1,while/Switch
while/Switch_1,while/Switch_1
while/Switch_1:1,while/Switch_1
while/Identity,while/Identity
while/Exit,while/Exit
while/Identity_1,while/Identity_1
while/Exit_1,while/Exit_1
while/add,while/add
while/NextIteration_1,while/NextIteration_1
while/NextIteration,while/NextIteration
Sum,Sum

View File

@ -207,12 +207,6 @@ open class ImportGraph <GRAPH_TYPE: GeneratedMessageV3,
OpMappingRegistry<GRAPH_TYPE, NODE_TYPE, OP_DEF_TYPE, TENSOR_TYPE,
DATA_TYPE, ATTR_DEF_TYPE, ATTR_VALUE_TYPE>): SameDiff {
/*
First, build an in-memory representation of the graph that allows us to build the graph incrementally
If we can build the graph incrementally, we can make sure that the added variables are set up with the correct
datatype and (once implemented) greedy shape inference
*/
/*
First, build an in-memory representation of the graph that allows us to build the graph incrementally
@ -442,19 +436,34 @@ open class ImportGraph <GRAPH_TYPE: GeneratedMessageV3,
val op = SameDiffOp.builder()
.name(name)
.op(df)
.inputsToOp(inNames) //.outputsOfOp(outNames) //We'll set this later
.controlDeps(controlDeps)
.build()
//take only up to the inputs that are specified in the node/
//this is for cases where node inputs is > intended number for ops
//a common example is when ops convert input ndarrays to integers or float inputs
val numInputsToTake = importInfo[name]!!.second.argDescriptorList.filter { input -> input.argType == OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR }
.size
op.inputsToOp = inNames.subList(0,numInputsToTake)
//add nodes/other pre processing in order for this node to work
var addToGraph = true
sd.ops[name] = op
defaultRunner.initAttributes(df, sd, importInfo[name]!!)
//clear out inputs for variables as well to reflect the actual graph structure
if(numInputsToTake < numInputs) {
for(i in numInputsToTake until numInputs) {
if(sd.hasVariable(nd.inputAt(i))) {
val currInputVar = sd.variables[nd.inputAt(i)]!!
currInputVar.inputsForOp.remove(op.name)
}
}
}
//cache attributes just in case we have any rules so we don't create the rules more than once
val attributes = mappingContext.nodeAttributesAsMap()
mappingContext.relevantPrehookRules().forEach { rule ->
rule.preProcess(op, sd,attributes)
}
defaultRunner.initAttributes(df, sd, importInfo[name]!!)
//add nodes/other post processing in order for this node to work

View File

@ -33,6 +33,8 @@ class OnnxIRDataType(inputDataType: Onnx.TensorProto.DataType): IRDataType<Onnx.
Onnx.TensorProto.DataType.UINT64 -> return IRDataTypeValue.DT_UINT64
Onnx.TensorProto.DataType.UINT32 -> return IRDataTypeValue.DT_UINT32
Onnx.TensorProto.DataType.UINT16 -> return IRDataTypeValue.DT_UINT16
Onnx.TensorProto.DataType.INT8 -> return IRDataTypeValue.DT_INT8
Onnx.TensorProto.DataType.UINT8 -> return IRDataTypeValue.DT_UINT8
Onnx.TensorProto.DataType.FLOAT16 -> return IRDataTypeValue.DT_HALF
Onnx.TensorProto.DataType.STRING -> return IRDataTypeValue.DT_STRING
Onnx.TensorProto.DataType.FLOAT -> return IRDataTypeValue.DT_FLOAT
@ -60,9 +62,11 @@ class OnnxIRDataType(inputDataType: Onnx.TensorProto.DataType): IRDataType<Onnx.
override fun nd4jDataType(): DataType {
when(this.dataType) {
Onnx.TensorProto.DataType.UINT64 -> return return DataType.INT64
Onnx.TensorProto.DataType.UINT32 -> return return DataType.INT32
Onnx.TensorProto.DataType.UINT16 -> return return DataType.INT16
Onnx.TensorProto.DataType.INT8 -> return DataType.INT8
Onnx.TensorProto.DataType.UINT8 -> return DataType.UINT8
Onnx.TensorProto.DataType.UINT64 -> return DataType.UINT64
Onnx.TensorProto.DataType.UINT32 -> return DataType.INT32
Onnx.TensorProto.DataType.UINT16 -> return DataType.INT16
Onnx.TensorProto.DataType.FLOAT16 -> return return DataType.FLOAT16
Onnx.TensorProto.DataType.STRING -> return return DataType.UTF8
Onnx.TensorProto.DataType.FLOAT -> return return DataType.FLOAT
@ -81,8 +85,8 @@ class OnnxIRDataType(inputDataType: Onnx.TensorProto.DataType): IRDataType<Onnx.
override fun nameSpaceDataType(): TensorNamespace.DataType {
when(this.dataType) {
Onnx.TensorProto.DataType.UINT64 -> return return TensorNamespace.DataType.INT64
Onnx.TensorProto.DataType.UINT32 -> return return TensorNamespace.DataType.INT32
Onnx.TensorProto.DataType.UINT16 -> return return TensorNamespace.DataType.INT16
Onnx.TensorProto.DataType.UINT32 -> return TensorNamespace.DataType.INT32
Onnx.TensorProto.DataType.UINT16 -> return TensorNamespace.DataType.INT16
Onnx.TensorProto.DataType.FLOAT16 -> return return TensorNamespace.DataType.FLOAT16
Onnx.TensorProto.DataType.STRING -> return return TensorNamespace.DataType.STRING
Onnx.TensorProto.DataType.FLOAT -> return TensorNamespace.DataType.FLOAT

View File

@ -1,5 +1,5 @@
Placeholder,data
Placeholder,partitions
DynamicPartition,output
Identity,out0
Identity,out1
Const,in_0
Const,Roll/shift
Const,Roll/axis
Identity,in_0/read
Roll,Roll

View File

@ -1,7 +1,5 @@
Const,in_0
Const,eye/ones
Const,Roll/shift
Const,Roll/axis
Identity,in_0/read
MatrixDiag,eye/diag
Add,Add
Svd,Svd
Abs,Abs
Roll,Roll

View File

@ -1,3 +1,2 @@
DynamicPartition,output
Identity,out0
Identity,out1
Identity,in_0/read
Roll,Roll

View File

@ -1,5 +1,2 @@
Identity,in_0/read
MatrixDiag,eye/diag
Add,Add
Svd,Svd
Abs,Abs
Roll,Roll

View File

@ -1,5 +1,5 @@
data
partitions
output
out0
out1
in_0
Roll/shift
Roll/axis
in_0/read
Roll

View File

@ -1,7 +1,5 @@
in_0
eye/ones
Roll/shift
Roll/axis
in_0/read
eye/diag
Add
Svd
Abs
Roll

View File

@ -358,7 +358,7 @@ val bitCast = TensorflowMappingProcess(
opMappingRegistry = tensorflowOpRegistry,
inputFrameworkOpName = "Bitcast",
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input"))),
attributeMappingRules = listOf(dataTypeToInt(mutableMapOf("newType" to "type")), valueMapping(mutableMapOf("dataType" to "type")))
attributeMappingRules = listOf(dataTypeToInt(mutableMapOf("newType" to "type")), valueMapping(mutableMapOf("dtype" to "type")))
)
val bitwiseAnd = TensorflowMappingProcess(
@ -1070,7 +1070,7 @@ val fill = TensorflowMappingProcess(
inputFrameworkOpName = "Fill",
opMappingRegistry = tensorflowOpRegistry,
attributeMappingRules = listOf(convertNDArrayInputToNumericalAttr(mutableMapOf("value" to "value")),
dataTypeToInt(mutableMapOf("dtype" to "T")),
dataTypeToInt(mutableMapOf("outputDataType" to "T")),
valueMapping(mutableMapOf("dtype" to "T"))),
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("shapeArray" to "dims")))
)
@ -1381,6 +1381,7 @@ val maxPoolArgmax = multipleNameMapping(
intConstant(inputName = "extraParam0",constantValue = 0 ,argumentIndex = 9)[0],
intConstant(inputName = "isNHWC",argumentIndex = 10,constantValue = 1 )[0],
intConstant(inputName = "sameMode",argumentIndex = 8,constantValue = 8 )[0],
valueMapping(mutableMapOf("dtype" to "T"))
)
,tensorflowOpRegistry = tensorflowOpRegistry
)
@ -1455,6 +1456,13 @@ val mirrorPadding = mapTensorNamesWithOp(inputFrameworkOpName = "MirrorPad",opNa
* v1 and 2 which do not have a scoreThreshold to map. V3 does.
*/
val matrixBandPart = mapTensorNamesWithOp(inputFrameworkOpName = "MatrixBandPart",opName = "matrix_band_part",
tensorNames = mutableMapOf("input" to "input","minLowerT" to "num_lower",
"maxUpperT" to "num_upper"),
attributeMappingRules = listOf()
,tensorflowOpRegistry = tensorflowOpRegistry)
val nonMaxSuppressionV1 = multipleNameMapping(inputFrameworkOpNames = listOf("NonMaxSuppression"),
opName = "non_max_suppression",
tensorNames = mutableMapOf("boxes" to "boxes","scales" to "scores",
@ -1654,7 +1662,8 @@ val randomCrop = mapTensorNamesWithOp(inputFrameworkOpName = "RandomCrop",opName
attributeMappingRules = listOf(valueMapping(mutableMapOf("seed" to "seed")))
,tensorflowOpRegistry = tensorflowOpRegistry)
val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",tensorNames = mutableMapOf(),
val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",
tensorNames = mutableMapOf(),
attributeMappingRules = listOf()
,tensorflowOpRegistry = tensorflowOpRegistry)
@ -1823,8 +1832,8 @@ val reverseSequence = multipleNameMapping(inputFrameworkOpNames = listOf("Revers
,tensorflowOpRegistry = tensorflowOpRegistry)
val roll = multipleNameMapping(inputFrameworkOpNames = listOf("Roll"),opName = "roll",
attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("shift" to "shift"))),
tensorNames = mutableMapOf("input" to "input","dimensions" to "axis","shiftsI" to "shift")
attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("shift" to "shift","dimensions" to "axis"))),
tensorNames = mutableMapOf("input" to "input")
,tensorflowOpRegistry = tensorflowOpRegistry)
//TODO: verify usingLocking property, it's not showing up in descriptors
@ -1941,6 +1950,7 @@ val size = TensorflowMappingProcess(
opMappingRegistry = tensorflowOpRegistry,
inputFrameworkOpName = "Size",
opName = "size",
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "out_type"))),
tensorMappingRules = listOf(mappingNDArrayInputs(mutableMapOf("input" to "input")))
)

View File

@ -36,6 +36,10 @@ class TensorflowIRDataType(inputDataType: DataType): IRDataType<DataType> {
DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return IRDataTypeValue.DT_DOUBLE
DataType.DT_FLOAT, DataType.DT_FLOAT_REF -> return IRDataTypeValue.DT_FLOAT
DataType.DT_HALF, DataType.DT_HALF_REF -> return IRDataTypeValue.DT_HALF
DataType.DT_INT8, DataType.DT_INT8_REF -> return IRDataTypeValue.DT_INT8
DataType.DT_UINT8, DataType.DT_UINT8_REF -> return IRDataTypeValue.DT_UINT8
DataType.DT_UINT16, DataType.DT_UINT16_REF -> return IRDataTypeValue.DT_UINT16
DataType.DT_UINT32, DataType.DT_UINT32_REF -> return IRDataTypeValue.DT_UINT32
DataType.DT_INT16, DataType.DT_INT16_REF -> return IRDataTypeValue.DT_INT16
DataType.DT_INT32, DataType.DT_INT32_REF -> return IRDataTypeValue.DT_INT32
DataType.DT_INT64, DataType.DT_INT64_REF -> return IRDataTypeValue.DT_INT64
@ -70,9 +74,11 @@ class TensorflowIRDataType(inputDataType: DataType): IRDataType<DataType> {
DataType.DT_BFLOAT16, DataType.DT_BFLOAT16_REF -> return org.nd4j.linalg.api.buffer.DataType.BFLOAT16
DataType.DT_INT64, DataType.DT_INT64_REF -> return org.nd4j.linalg.api.buffer.DataType.INT64
DataType.DT_HALF, DataType.DT_HALF_REF -> return org.nd4j.linalg.api.buffer.DataType.FLOAT16
DataType.DT_INT8, DataType.DT_INT8_REF -> return org.nd4j.linalg.api.buffer.DataType.INT8
DataType.DT_INT16, DataType.DT_INT16_REF -> return org.nd4j.linalg.api.buffer.DataType.INT16
DataType.DT_INT32, DataType.DT_INT32_REF -> return org.nd4j.linalg.api.buffer.DataType.INT32
DataType.DT_DOUBLE, DataType.DT_DOUBLE_REF -> return org.nd4j.linalg.api.buffer.DataType.DOUBLE
DataType.DT_UINT8, DataType.DT_UINT8_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT8
DataType.DT_UINT16, DataType.DT_UINT16_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT16
DataType.DT_UINT32, DataType.DT_UINT32_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT32
DataType.DT_UINT64, DataType.DT_UINT64_REF -> return org.nd4j.linalg.api.buffer.DataType.UINT64

View File

@ -134,7 +134,6 @@ class TensorflowIRGraph(graphDef: GraphDef, opDef: OpList
val node = nodeByName(varName)
val attrMap = node.attrMap
if(!attrMap.containsKey("dtype")) {
val retSet = attrMap.values.filter { attrValue -> attrValue.type != DataType.DT_INVALID }
if(retSet.isEmpty()) {
return TensorflowIRDataType(DataType.DT_INVALID)

View File

@ -1172,6 +1172,18 @@ mappings {
ruleType: "tensor"
inputFrameworkOpName: "Size"
}
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "out_type"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "out_type"
}
ruleType: "attribute"
inputFrameworkOpName: "Size"
}
}
mappings {
frameworkName: "tensorflow"
@ -2615,6 +2627,35 @@ mappings {
inputFrameworkOpName: "Polygamma"
}
}
mappings {
frameworkName: "tensorflow"
opName: "matrix_band_part"
inputFrameworkOpName: "MatrixBandPart"
rule {
ruleName: "ndarraymapping"
functionName: "ndarraymapping"
inputTensorName: "input"
inputTensorName: "num_lower"
inputTensorName: "num_upper"
outputTensorName: "input"
outputTensorName: "minLowerT"
outputTensorName: "maxUpperT"
inputToOutput {
key: "input"
value: "input"
}
inputToOutput {
key: "minLowerT"
value: "num_lower"
}
inputToOutput {
key: "maxUpperT"
value: "num_upper"
}
ruleType: "tensor"
inputFrameworkOpName: "MatrixBandPart"
}
}
mappings {
frameworkName: "tensorflow"
opName: "equals"
@ -6427,8 +6468,9 @@ mappings {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "type"
outputDataTypeName: "dtype"
inputToOutput {
key: "dataType"
key: "dtype"
value: "type"
}
ruleType: "attribute"
@ -6834,23 +6876,11 @@ mappings {
ruleName: "ndarraymapping"
functionName: "ndarraymapping"
inputTensorName: "input"
inputTensorName: "axis"
inputTensorName: "shift"
outputTensorName: "input"
outputTensorName: "dimensions"
outputTensorName: "shiftsI"
inputToOutput {
key: "input"
value: "input"
}
inputToOutput {
key: "dimensions"
value: "axis"
}
inputToOutput {
key: "shiftsI"
value: "shift"
}
ruleType: "tensor"
inputFrameworkOpName: "Roll"
}
@ -6862,6 +6892,10 @@ mappings {
key: "shift"
value: "shift"
}
inputToOutput {
key: "dimensions"
value: "axis"
}
ruleType: "attribute"
inputFrameworkOpName: "Roll"
}
@ -7251,6 +7285,7 @@ mappings {
rule {
ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "start"
outputDoubleName: "stop"
inputToOutput {
key: "start"
@ -9237,6 +9272,8 @@ mappings {
rule {
ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "on"
outputDoubleName: "off"
inputToOutput {
key: "on"
value: "on_value"
@ -10592,6 +10629,8 @@ mappings {
rule {
ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "min"
outputDoubleName: "max"
inputToOutput {
key: "min"
value: "minval"
@ -12082,10 +12121,10 @@ mappings {
rule {
ruleName: "datatypetoint"
functionName: "datatypetoint"
outputIntName: "outputDataType"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
key: "outputDataType"
value: "T"
}
ruleType: "attribute"
@ -12352,6 +12391,18 @@ mappings {
}
inputFrameworkOpName: "MaxPoolWithArgmax"
}
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "MaxPoolWithArgmax"
}
}
mappings {
frameworkName: "tensorflow"
@ -13428,6 +13479,9 @@ mappings {
rule {
ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "from"
outputDoubleName: "to"
outputDoubleName: "step"
inputToOutput {
key: "from"
value: "start"
@ -14778,6 +14832,7 @@ mappings {
functionName: "ndarrayinputtonumericalattribute"
outputIntName: "maxOutputSize"
outputDoubleName: "overlapThreshold"
outputDoubleName: "scoreThreshold"
inputToOutput {
key: "maxOutputSize"
value: "max_output_size"
@ -15157,7 +15212,7 @@ mappings {
functionName: "stringequals"
inputStringAttrName: "padding"
inputStringAttrName: "padding"
outputBooleanName: "isSameMode"
outputIntName: "isSameMode"
inputToOutput {
key: "isSameMode"
value: "padding"

View File

@ -32,6 +32,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper
import org.nd4j.ir.OpNamespace
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.DynamicCustomOp
import org.nd4j.linalg.api.ops.custom.Roll
import org.nd4j.linalg.api.ops.impl.transforms.BinCount
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt
import org.nd4j.linalg.factory.Nd4j
@ -46,6 +47,7 @@ import org.nd4j.samediff.frameworkimport.tensorflow.importer.TensorflowFramework
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraph
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRGraphRunner
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRNode
import org.nd4j.samediff.frameworkimport.tensorflow.ir.TensorflowIRTensor
import org.nd4j.shade.protobuf.ByteString
import org.nd4j.shade.protobuf.TextFormat
import org.nd4j.tensorflow.conversion.graphrunner.GraphRunner
@ -96,6 +98,14 @@ class TestTensorflowIR {
val output = graph.outputAll(inputMap)
val output2 = importedGraph.outputAll(inputMap)
val matrix =
TensorflowIRTensor(tensorflowIRGraph.nodeByName("in_0").attrMap["value"]!!.tensor).toNd4jNDArray()
val roll2 = Roll(matrix, Nd4j.scalar(2), Nd4j.scalar(1))
val outputs = Nd4j.exec(roll2)[0]
val tfOutputRoll = tfOutput["Roll"]
val nd4jOutput = output["Roll"]
//assertEquals(tfOutput.keys,outputList)
//assertEquals(tfOutput.keys,output2.keys)
val names = tensorflowIRGraph.nodeList().map { input -> input.nodeName() }
@ -116,7 +126,7 @@ class TestTensorflowIR {
println(notEquals)
// assertEquals(output,output2)
// assertEquals(output,output2)
//assertEquals(tfOutput,output)
}

View File

@ -1172,6 +1172,18 @@ mappings {
ruleType: "tensor"
inputFrameworkOpName: "Size"
}
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "out_type"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "out_type"
}
ruleType: "attribute"
inputFrameworkOpName: "Size"
}
}
mappings {
frameworkName: "tensorflow"
@ -2615,6 +2627,35 @@ mappings {
inputFrameworkOpName: "Polygamma"
}
}
mappings {
frameworkName: "tensorflow"
opName: "matrix_band_part"
inputFrameworkOpName: "MatrixBandPart"
rule {
ruleName: "ndarraymapping"
functionName: "ndarraymapping"
inputTensorName: "input"
inputTensorName: "num_lower"
inputTensorName: "num_upper"
outputTensorName: "input"
outputTensorName: "minLowerT"
outputTensorName: "maxUpperT"
inputToOutput {
key: "input"
value: "input"
}
inputToOutput {
key: "minLowerT"
value: "num_lower"
}
inputToOutput {
key: "maxUpperT"
value: "num_upper"
}
ruleType: "tensor"
inputFrameworkOpName: "MatrixBandPart"
}
}
mappings {
frameworkName: "tensorflow"
opName: "equals"
@ -6427,8 +6468,9 @@ mappings {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "type"
outputDataTypeName: "dtype"
inputToOutput {
key: "dataType"
key: "dtype"
value: "type"
}
ruleType: "attribute"
@ -6834,23 +6876,11 @@ mappings {
ruleName: "ndarraymapping"
functionName: "ndarraymapping"
inputTensorName: "input"
inputTensorName: "axis"
inputTensorName: "shift"
outputTensorName: "input"
outputTensorName: "dimensions"
outputTensorName: "shiftsI"
inputToOutput {
key: "input"
value: "input"
}
inputToOutput {
key: "dimensions"
value: "axis"
}
inputToOutput {
key: "shiftsI"
value: "shift"
}
ruleType: "tensor"
inputFrameworkOpName: "Roll"
}
@ -6862,6 +6892,10 @@ mappings {
key: "shift"
value: "shift"
}
inputToOutput {
key: "dimensions"
value: "axis"
}
ruleType: "attribute"
inputFrameworkOpName: "Roll"
}
@ -7251,6 +7285,7 @@ mappings {
rule {
ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "start"
outputDoubleName: "stop"
inputToOutput {
key: "start"
@ -9237,6 +9272,8 @@ mappings {
rule {
ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "on"
outputDoubleName: "off"
inputToOutput {
key: "on"
value: "on_value"
@ -10592,6 +10629,8 @@ mappings {
rule {
ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "min"
outputDoubleName: "max"
inputToOutput {
key: "min"
value: "minval"
@ -12082,10 +12121,10 @@ mappings {
rule {
ruleName: "datatypetoint"
functionName: "datatypetoint"
outputIntName: "outputDataType"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
key: "outputDataType"
value: "T"
}
ruleType: "attribute"
@ -12352,6 +12391,18 @@ mappings {
}
inputFrameworkOpName: "MaxPoolWithArgmax"
}
rule {
ruleName: "valuemapping"
functionName: "valuemapping"
inputDataTypeName: "T"
outputDataTypeName: "dtype"
inputToOutput {
key: "dtype"
value: "T"
}
ruleType: "attribute"
inputFrameworkOpName: "MaxPoolWithArgmax"
}
}
mappings {
frameworkName: "tensorflow"
@ -13428,6 +13479,9 @@ mappings {
rule {
ruleName: "ndarrayinputtonumericalattribute"
functionName: "ndarrayinputtonumericalattribute"
outputDoubleName: "from"
outputDoubleName: "to"
outputDoubleName: "step"
inputToOutput {
key: "from"
value: "start"
@ -14778,6 +14832,7 @@ mappings {
functionName: "ndarrayinputtonumericalattribute"
outputIntName: "maxOutputSize"
outputDoubleName: "overlapThreshold"
outputDoubleName: "scoreThreshold"
inputToOutput {
key: "maxOutputSize"
value: "max_output_size"
@ -15157,7 +15212,7 @@ mappings {
functionName: "stringequals"
inputStringAttrName: "padding"
inputStringAttrName: "padding"
outputBooleanName: "isSameMode"
outputIntName: "isSameMode"
inputToOutput {
key: "isSameMode"
value: "padding"

View File

@ -5,26 +5,26 @@ node {
key: "value"
value {
tensor {
dtype: DT_FLOAT
dtype: DT_DOUBLE
tensor_shape {
dim {
size: 2
size: 4
}
dim {
size: 3
size: 5
}
dim {
size: 3
size: 4
}
}
tensor_content: "~^G?L\033M?\236p9?\220ol>\356%:?X\2708><q\001?b|d?\224\316\v?\314al?P@\257=,5K?\326\271(?\3566\016?`u#>0\024\236>\240{\036>\240h\360>"
tensor_content: "0m4\347\376y\315?\344\033;\004\236p\351?\026..t\357%\352? \306G>\322\023\247?\230\303\330E)\235\327?,5\313k\v\350\345?\334m\034\311\255s\321?0\024\236\222\260\272\321?@\321\340\331\240{\316?v|\234\234\223o\356?\244&\\\214\314W\332?\270\376rL\357\224\331?\030\216\353R\253\023\332?*A<\365%\024\344?h\f\326CH\250\351?0\320Y{\256\\\261?\366\226\304\200\016\350\343?@cr\031B\026\323?\300Q\006<\213\200\303?\f\377 \310\035\271\320?\300\325\277\177]\302\347?\270\337\237\302\321P\343?\256L\224~H]\351?0\243.\031\266\256\342?\200B\250\257\316j\301?\304\372\266\312\352\225\332?D\320\255\371\vB\327?\364\025&\270p\360\327?H:\241v\n\272\312?0-} \201!\323?l2\v\247;|\331?\320r}\'\v\372\341?\006u\364aS@\347?P6<\211\270\337\273?\024\340\006`\267\262\322?t\222JV\310\'\325?\264/\177\232\"]\322?\242\037L\006\215=\345?\270\376\302\274\332\310\323?R\2100\375\212\314\355?\300;\026W\321\210\230?\260?t#\314\203\322?\366=D:{\005\342?p_\v\2335\315\356?\344U\370\034\372\317\332? l1N\344\252\330?\354\327\003\223\206\215\321?^C\326M\353\234\345?\326\333u\025\2449\354?\264\000\334Z\177\326\353?\244\341\253\326!\272\345?\320\336\312`\255\005\311?\244u\220o\266\033\324?V\b\201\267\276\271\351?$\253\324_o3\356?\264:\260\357i\003\335?\300[\'7L8\256?02UU\205\214\265?\240\255\276\263r+\257? \303\207\022\3446\334?\220\032\360l\364(\324?\030\374\036.\217W\355?\340!;ay\034\257?\312\255)\371\227\333\346?F\233`e\300\335\343?\264>\261\354\324\345\331?\212v\025\026\265o\342?|\036\036F\364}\330?\034\203\363\362\364\204\322?\000:\260e\"\372\230?F\316\345\330\2577\343?\356uN6<b\355?6\216\263\234\243N\351?\306b\204\305\260\200\342?\210\213\262X/J\331?(\352\313\203=\217\312?\310\3327\231\362\a\345?\210\234<\035}\241\313?\224\'0\201\363V\344?\240\260-3\3655\345?"
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
@ -32,6 +32,12 @@ node {
name: "in_0/read"
op: "Identity"
input: "in_0"
attr {
key: "T"
value {
type: DT_DOUBLE
}
}
attr {
key: "_class"
value {
@ -40,96 +46,73 @@ node {
}
}
}
}
node {
name: "Roll/shift"
op: "Const"
attr {
key: "T"
key: "dtype"
value {
type: DT_FLOAT
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 2
}
}
}
}
node {
name: "eye/ones"
name: "Roll/axis"
op: "Const"
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
dtype: DT_INT32
tensor_shape {
dim {
size: 2
}
dim {
size: 3
}
}
float_val: 1.0
int_val: 1
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
type: DT_INT32
}
}
}
node {
name: "eye/diag"
op: "MatrixDiag"
input: "eye/ones"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "Add"
op: "Add"
name: "Roll"
op: "Roll"
input: "in_0/read"
input: "eye/diag"
input: "Roll/shift"
input: "Roll/axis"
attr {
key: "T"
key: "Taxis"
value {
type: DT_FLOAT
}
}
}
node {
name: "Svd"
op: "Svd"
input: "Add"
attr {
key: "full_matrices"
value {
b: true
type: DT_INT32
}
}
attr {
key: "compute_uv"
key: "Tshift"
value {
b: false
type: DT_INT32
}
}
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "Abs"
op: "Abs"
input: "Svd"
attr {
key: "T"
value {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
library {
}
}

View File

@ -1,4 +1,2 @@
output,output
output:1,output
out0,out0
out1,out1
in_0/read,in_0/read
Roll,Roll

View File

@ -1,5 +1,2 @@
in_0/read,in_0/read
eye/diag,eye/diag
Add,Add
Svd,Svd
Abs,Abs
Roll,Roll