Merge pull request #9213 from eclipse/ag_param_server_fixes
Fix compilation isssues with nd4j-parameter-servermaster
commit
ab87cae409
|
@ -58,6 +58,7 @@ fi
|
||||||
|
|
||||||
unameOut="$(uname)"
|
unameOut="$(uname)"
|
||||||
echo "$OSTYPE"
|
echo "$OSTYPE"
|
||||||
|
|
||||||
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
|
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
|
||||||
# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
|
# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
|
||||||
[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
|
[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
|
||||||
|
|
|
@ -1,285 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
|
||||||
|
|
||||||
import lombok.NonNull;
|
|
||||||
import lombok.val;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|
||||||
import org.nd4j.linalg.api.ops.OpContext;
|
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class ScatterUpdate extends DynamicCustomOp {
|
|
||||||
protected CustomOp op;
|
|
||||||
|
|
||||||
// update operation: 0 - add; 1 - sub; 2 - mul; 3 - div; 4 - rsub; 5 - rdiv; 6 - assign
|
|
||||||
public enum UpdateOp {
|
|
||||||
ADD,
|
|
||||||
SUBTRACT,
|
|
||||||
MULTIPLY,
|
|
||||||
DIVIDE,
|
|
||||||
RSUBTRACT,
|
|
||||||
RDIVIDE,
|
|
||||||
ASSIGN,
|
|
||||||
}
|
|
||||||
|
|
||||||
public ScatterUpdate(){ }
|
|
||||||
|
|
||||||
public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) {
|
|
||||||
this(original, updates, null, indices, dimension, op);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, INDArray result, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) {
|
|
||||||
|
|
||||||
List<Integer> iargs = new ArrayList<>();
|
|
||||||
iargs.add(op.ordinal());
|
|
||||||
iargs.add(dimension.length);
|
|
||||||
for (val v: dimension)
|
|
||||||
iargs.add(v);
|
|
||||||
|
|
||||||
iargs.add(indices.length);
|
|
||||||
for (val v: indices)
|
|
||||||
iargs.add(v);
|
|
||||||
|
|
||||||
if (updates.tensorAlongDimension(0, dimension).length() != original.tensorAlongDimension(0, dimension).length())
|
|
||||||
throw new ND4JIllegalStateException("ScatterUpdate requires equal shaped tensors for operation along given dimension(s)");
|
|
||||||
|
|
||||||
long numTensors = original.tensorsAlongDimension(dimension);
|
|
||||||
for (val idx: indices)
|
|
||||||
if (idx >= numTensors)
|
|
||||||
throw new ND4JIllegalStateException("Can't update index higher then num tensors");
|
|
||||||
|
|
||||||
this.op = DynamicCustomOp.builder("scatter_update")
|
|
||||||
.addInputs(original, updates)
|
|
||||||
.callInplace(true)
|
|
||||||
.addIntegerArguments(iargs)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
|
||||||
DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) op;
|
|
||||||
return dynamicCustomOp.calculateOutputDataTypes(dataTypes);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method returns op opName as string
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public String opName() {
|
|
||||||
return "scatter_update";
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method returns LongHash of the opName()
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public long opHash() {
|
|
||||||
return op.opHash();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method returns true if op is supposed to be executed inplace
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public boolean isInplaceCall() {
|
|
||||||
return op.isInplaceCall();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<INDArray> outputArguments() {
|
|
||||||
return op.outputArguments();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<INDArray> inputArguments() {
|
|
||||||
return op.inputArguments();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public long[] iArgs() {
|
|
||||||
return op.iArgs();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] tArgs() {
|
|
||||||
return op.tArgs();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean[] bArgs() {
|
|
||||||
return op.bArgs();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addIArgument(int... arg) {
|
|
||||||
op.addIArgument(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addIArgument(long... arg) {
|
|
||||||
op.addIArgument(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addBArgument(boolean... arg) {
|
|
||||||
op.addBArgument(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void removeIArgument(Integer arg) {
|
|
||||||
op.removeIArgument(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Boolean getBArgument(int index) {
|
|
||||||
return op.getBArgument(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Long getIArgument(int index) {
|
|
||||||
return op.getIArgument(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numIArguments() {
|
|
||||||
return op.numIArguments();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addTArgument(double... arg) {
|
|
||||||
op.addTArgument(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void removeTArgument(Double arg) {
|
|
||||||
op.removeTArgument(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Double getTArgument(int index) {
|
|
||||||
return op.getTArgument(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numTArguments() {
|
|
||||||
return op.numTArguments();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numBArguments() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addInputArgument(INDArray... arg) {
|
|
||||||
op.addInputArgument(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void removeInputArgument(INDArray arg) {
|
|
||||||
op.removeInputArgument(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray getInputArgument(int index) {
|
|
||||||
return op.getInputArgument(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numInputArguments() {
|
|
||||||
return op.numInputArguments();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addOutputArgument(INDArray... arg) {
|
|
||||||
op.addOutputArgument(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void removeOutputArgument(INDArray arg) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray getOutputArgument(int index) {
|
|
||||||
return op.getOutputArgument(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numOutputArguments() {
|
|
||||||
return op.numOutputArguments();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<LongShapeDescriptor> calculateOutputShape() {
|
|
||||||
return Nd4j.getExecutioner().calculateOutputShape(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
|
|
||||||
return Nd4j.getExecutioner().calculateOutputShape(this, opContext);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public CustomOpDescriptor getDescriptor() {
|
|
||||||
return op.getDescriptor();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void assertValidForExecution() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public DataType[] dArgs() {
|
|
||||||
return new DataType[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addDArgument(DataType... arg) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numDArguments() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void clearArrays() {
|
|
||||||
op.clearArrays();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -79,7 +79,7 @@ public class ScatterNdSub extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||||
Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s",
|
Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s",
|
||||||
inputDataTypes.get(0), inputDataTypes.get(2));
|
inputDataTypes.get(0), inputDataTypes.get(2));
|
||||||
|
|
|
@ -79,7 +79,7 @@ public class ScatterNdUpdate extends DynamicCustomOp {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||||
Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s",
|
Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s",
|
||||||
inputDataTypes.get(0), inputDataTypes.get(2));
|
inputDataTypes.get(0), inputDataTypes.get(2));
|
||||||
|
|
|
@ -362,10 +362,10 @@
|
||||||
<configuration>
|
<configuration>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
<LD_LIBRARY_PATH>
|
<LD_LIBRARY_PATH>
|
||||||
${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/:${nd4j.native.basedir}/target/classes
|
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes
|
||||||
</LD_LIBRARY_PATH>
|
</LD_LIBRARY_PATH>
|
||||||
<PATH>
|
<PATH>
|
||||||
${env.PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/:${nd4j.native.basedir}/target/classes
|
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes
|
||||||
</PATH>
|
</PATH>
|
||||||
</environmentVariables>
|
</environmentVariables>
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||||
|
@ -391,7 +391,7 @@
|
||||||
|
|
||||||
For testing large zoo models, this may not be enough (so comment it out).
|
For testing large zoo models, this may not be enough (so comment it out).
|
||||||
-->
|
-->
|
||||||
<argLine>-Ddtype=float -Xmx8g</argLine>
|
<argLine>-Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
@ -444,8 +444,11 @@
|
||||||
<configuration>
|
<configuration>
|
||||||
<environmentVariables>
|
<environmentVariables>
|
||||||
<LD_LIBRARY_PATH>
|
<LD_LIBRARY_PATH>
|
||||||
${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/
|
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes
|
||||||
</LD_LIBRARY_PATH>
|
</LD_LIBRARY_PATH>
|
||||||
|
<PATH>
|
||||||
|
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes
|
||||||
|
</PATH>
|
||||||
</environmentVariables>
|
</environmentVariables>
|
||||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||||
<includes>
|
<includes>
|
||||||
|
@ -468,7 +471,7 @@
|
||||||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
||||||
Depending on a build machine, default value is not always enough.
|
Depending on a build machine, default value is not always enough.
|
||||||
-->
|
-->
|
||||||
<argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx6g</argLine>
|
<argLine> -Dfile.encoding=UTF-8 -Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
|
|
|
@ -76,20 +76,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
"layers_dropout/rank3_d05_train_mask1",
|
"layers_dropout/rank3_d05_train_mask1",
|
||||||
"layers_dropout/rank2_d09_train",
|
"layers_dropout/rank2_d09_train",
|
||||||
"layers_dropout/rank2_d05_train",*/
|
"layers_dropout/rank2_d05_train",*/
|
||||||
"reductions/scatter_update_vector",
|
|
||||||
"reductions/scatter_update_scalar",
|
|
||||||
"random_poisson/rank1_float16",
|
|
||||||
"random_poisson/rank1_float16",
|
|
||||||
"matrix_band_part/float64",
|
|
||||||
"emptyArrayTests/scatter_update/rank1_emptyIndices_emptyUpdates",
|
|
||||||
"bincount/rank0_weights",
|
|
||||||
"bincount/rank2_weights",
|
|
||||||
"scatter_nd_add/locking/rank1shape_1indices",
|
|
||||||
"scatter_nd_add/locking/rank2shape_1indices",
|
|
||||||
"scatter_nd_add/locking/rank3shape_1indices",
|
|
||||||
"scatter_nd_sub/locking/rank1shape_1indices",
|
|
||||||
"scatter_nd_sub/locking/rank2shape_1indices",
|
|
||||||
"scatter_nd_sub/locking/rank3shape_1indices"
|
|
||||||
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -97,10 +84,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
||||||
// Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance
|
// Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance
|
||||||
//Invalid test cases. Verified by running graph against actual TF.
|
//Invalid test cases. Verified by running graph against actual TF.
|
||||||
"compare_and_bitpack/.*",
|
"scatter_nd_sub/locking/rank1shape_1indices",
|
||||||
|
"reductions/scatter_update_vector",
|
||||||
|
"reductions/scatter_update_scalar",
|
||||||
|
"emptyArrayTests/scatter_update/rank1_emptyIndices_emptyUpdates",
|
||||||
|
"bincount/rank2_weights",
|
||||||
"slogdet/.*",
|
"slogdet/.*",
|
||||||
//IGNORE THIS: the TF results from comparing against an actual TF java run compared to this seem to be different.
|
|
||||||
"fused_batch_norm/float16_nhwc",
|
|
||||||
//Don't bother to test RNG. We can test subsets of ops with dropout to make sure they are consistent
|
//Don't bother to test RNG. We can test subsets of ops with dropout to make sure they are consistent
|
||||||
//These tests have random uniform and other RNG in them that don't need to be perfectly compatible to be acceptable.
|
//These tests have random uniform and other RNG in them that don't need to be perfectly compatible to be acceptable.
|
||||||
//We need different test cases here.
|
//We need different test cases here.
|
||||||
|
|
|
@ -56,7 +56,6 @@ import org.nd4j.linalg.api.ops.custom.RgbToHsv;
|
||||||
import org.nd4j.linalg.api.ops.custom.RgbToYiq;
|
import org.nd4j.linalg.api.ops.custom.RgbToYiq;
|
||||||
import org.nd4j.linalg.api.ops.custom.RgbToYuv;
|
import org.nd4j.linalg.api.ops.custom.RgbToYuv;
|
||||||
import org.nd4j.linalg.api.ops.custom.Roll;
|
import org.nd4j.linalg.api.ops.custom.Roll;
|
||||||
import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
|
|
||||||
import org.nd4j.linalg.api.ops.custom.ToggleBits;
|
import org.nd4j.linalg.api.ops.custom.ToggleBits;
|
||||||
import org.nd4j.linalg.api.ops.custom.TriangularSolve;
|
import org.nd4j.linalg.api.ops.custom.TriangularSolve;
|
||||||
import org.nd4j.linalg.api.ops.custom.YiqToRgb;
|
import org.nd4j.linalg.api.ops.custom.YiqToRgb;
|
||||||
|
@ -64,12 +63,11 @@ import org.nd4j.linalg.api.ops.custom.YuvToRgb;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
||||||
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
|
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
|
||||||
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
||||||
import org.nd4j.linalg.api.ops.impl.image.ResizeArea;
|
import org.nd4j.linalg.api.ops.impl.image.ResizeArea;
|
||||||
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Linspace;
|
import org.nd4j.linalg.api.ops.impl.shape.Linspace;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
||||||
|
@ -390,51 +388,6 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testScatterUpdate1() {
|
|
||||||
val matrix = Nd4j.create(5, 5);
|
|
||||||
val updates = Nd4j.create(2, 5).assign(1.0);
|
|
||||||
int[] dims = new int[]{1};
|
|
||||||
int[] indices = new int[]{1, 3};
|
|
||||||
|
|
||||||
val exp0 = Nd4j.create(5).assign(0);
|
|
||||||
val exp1 = Nd4j.create(5).assign(1);
|
|
||||||
|
|
||||||
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
|
|
||||||
Nd4j.getExecutioner().exec(op);
|
|
||||||
|
|
||||||
assertEquals(exp0, matrix.getRow(0));
|
|
||||||
assertEquals(exp1, matrix.getRow(1));
|
|
||||||
assertEquals(exp0, matrix.getRow(2));
|
|
||||||
assertEquals(exp1, matrix.getRow(3));
|
|
||||||
assertEquals(exp0, matrix.getRow(4));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(expected = ND4JIllegalStateException.class)
|
|
||||||
public void testScatterUpdate2() {
|
|
||||||
val matrix = Nd4j.create(5, 5);
|
|
||||||
val updates = Nd4j.create(2, 5).assign(1.0);
|
|
||||||
int[] dims = new int[]{0};
|
|
||||||
int[] indices = new int[]{0, 1};
|
|
||||||
|
|
||||||
val exp0 = Nd4j.create(1, 5).assign(0);
|
|
||||||
val exp1 = Nd4j.create(1, 5).assign(1);
|
|
||||||
|
|
||||||
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test(expected = ND4JIllegalStateException.class)
|
|
||||||
public void testScatterUpdate3() {
|
|
||||||
val matrix = Nd4j.create(5, 5);
|
|
||||||
val updates = Nd4j.create(2, 5).assign(1.0);
|
|
||||||
int[] dims = new int[]{1};
|
|
||||||
int[] indices = new int[]{0, 6};
|
|
||||||
|
|
||||||
val exp0 = Nd4j.create(1, 5).assign(0);
|
|
||||||
val exp1 = Nd4j.create(1, 5).assign(1);
|
|
||||||
|
|
||||||
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOpStatus1() {
|
public void testOpStatus1() {
|
||||||
|
@ -1005,17 +958,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(expected, output);
|
assertEquals(expected, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCompareAndBitpack() {
|
|
||||||
INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f,
|
|
||||||
-2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}).reshape( 2,3,4);
|
|
||||||
INDArray out = Nd4j.createUninitialized(DataType.UBYTE, 2,3,4);
|
|
||||||
INDArray expected = Nd4j.createFromArray(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}).
|
|
||||||
reshape(2,3,4);
|
|
||||||
|
|
||||||
Nd4j.exec(new CompareAndBitpack(in ,2.0, out));
|
|
||||||
assertArrayEquals(new long[]{2,3,4}, out.shape());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDivideNoNan() {
|
public void testDivideNoNan() {
|
||||||
|
|
|
@ -730,20 +730,7 @@ public class NDBaseTest extends BaseNd4jTest {
|
||||||
assertEquals(y_exp, y);
|
assertEquals(y_exp, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testScatterUpdate() {
|
|
||||||
NDBase base = new NDBase();
|
|
||||||
|
|
||||||
//from testScatterOpGradients.
|
|
||||||
INDArray x = Nd4j.ones(DataType.DOUBLE, 10, 10);
|
|
||||||
INDArray indices = Nd4j.create(new double[]{3, 4, 5, 8, 9}).castTo(DataType.INT32);
|
|
||||||
INDArray updates = Nd4j.ones(DataType.DOUBLE, 5, 10).add(1.0);
|
|
||||||
INDArray y = base.scatterUpdate(x,indices, updates);
|
|
||||||
|
|
||||||
y = y.getColumn(0);
|
|
||||||
INDArray y_exp = Nd4j.createFromArray(1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0);
|
|
||||||
assertEquals(y_exp, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSegmentMax() {
|
public void testSegmentMax() {
|
||||||
|
|
|
@ -52,7 +52,7 @@ public class RemoteParameterServerClientTests extends BaseND4JTest {
|
||||||
@Before
|
@Before
|
||||||
public void before() throws Exception {
|
public void before() throws Exception {
|
||||||
final MediaDriver.Context ctx =
|
final MediaDriver.Context ctx =
|
||||||
new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirsDeleteOnStart(true)
|
new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true)
|
||||||
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
|
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
|
||||||
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
||||||
.senderIdleStrategy(new BusySpinIdleStrategy());
|
.senderIdleStrategy(new BusySpinIdleStrategy());
|
||||||
|
@ -150,10 +150,10 @@ public class RemoteParameterServerClientTests extends BaseND4JTest {
|
||||||
|
|
||||||
private Aeron.Context getContext() {
|
private Aeron.Context getContext() {
|
||||||
if (ctx == null)
|
if (ctx == null)
|
||||||
ctx = new Aeron.Context().publicationConnectionTimeout(-1)
|
ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
|
||||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
|
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(1000)
|
||||||
.errorHandler(e -> log.error(e.toString(), e));
|
.errorHandler(e -> log.error(e.toString(), e));
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,7 +51,7 @@ public class ParameterServerClientPartialTest extends BaseND4JTest {
|
||||||
@BeforeClass
|
@BeforeClass
|
||||||
public static void beforeClass() throws Exception {
|
public static void beforeClass() throws Exception {
|
||||||
final MediaDriver.Context ctx =
|
final MediaDriver.Context ctx =
|
||||||
new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true)
|
new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirDeleteOnStart(true)
|
||||||
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
|
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
|
||||||
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
||||||
.senderIdleStrategy(new BusySpinIdleStrategy());
|
.senderIdleStrategy(new BusySpinIdleStrategy());
|
||||||
|
@ -136,10 +136,10 @@ public class ParameterServerClientPartialTest extends BaseND4JTest {
|
||||||
|
|
||||||
private static Aeron.Context getContext() {
|
private static Aeron.Context getContext() {
|
||||||
if (ctx == null)
|
if (ctx == null)
|
||||||
ctx = new Aeron.Context().publicationConnectionTimeout(-1)
|
ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
|
||||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(10000)
|
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(10000)
|
||||||
.errorHandler(e -> log.error(e.toString(), e));
|
.errorHandler(e -> log.error(e.toString(), e));
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
|
@ -119,10 +119,10 @@ public class ParameterServerClientTest extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
||||||
private static Aeron.Context getContext() {
|
private static Aeron.Context getContext() {
|
||||||
return new Aeron.Context().publicationConnectionTimeout(-1)
|
return new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
|
||||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
|
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000)
|
||||||
.errorHandler(e -> log.error(e.toString(), e));
|
.errorHandler(e -> log.error(e.toString(), e));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -424,7 +424,6 @@ public abstract class BaseTransport implements Transport {
|
||||||
CloseHelper.quietClose(subscriptionForShards);
|
CloseHelper.quietClose(subscriptionForShards);
|
||||||
CloseHelper.quietClose(subscriptionForClients);
|
CloseHelper.quietClose(subscriptionForClients);
|
||||||
CloseHelper.quietClose(aeron);
|
CloseHelper.quietClose(aeron);
|
||||||
CloseHelper.quietClose(context);
|
|
||||||
CloseHelper.quietClose(driver);
|
CloseHelper.quietClose(driver);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -91,7 +91,7 @@ public class RoutedTransport extends BaseTransport {
|
||||||
|
|
||||||
|
|
||||||
context = new Aeron.Context().driverTimeoutMs(30000)
|
context = new Aeron.Context().driverTimeoutMs(30000)
|
||||||
.keepAliveInterval(100000000);
|
.keepAliveIntervalNs(100000000);
|
||||||
AeronUtil.setDaemonizedThreadFactories(context);
|
AeronUtil.setDaemonizedThreadFactories(context);
|
||||||
|
|
||||||
MediaDriver.Context ctx = new MediaDriver.Context();
|
MediaDriver.Context ctx = new MediaDriver.Context();
|
||||||
|
@ -120,7 +120,6 @@ public class RoutedTransport extends BaseTransport {
|
||||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||||
CloseHelper.quietClose(aeron);
|
CloseHelper.quietClose(aeron);
|
||||||
CloseHelper.quietClose(driver);
|
CloseHelper.quietClose(driver);
|
||||||
CloseHelper.quietClose(context);
|
|
||||||
CloseHelper.quietClose(subscriptionForClients);
|
CloseHelper.quietClose(subscriptionForClients);
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|
|
@ -131,7 +131,7 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable {
|
||||||
splitter = MessageSplitter.getInstance();
|
splitter = MessageSplitter.getInstance();
|
||||||
|
|
||||||
context = new Aeron.Context().driverTimeoutMs(30000)
|
context = new Aeron.Context().driverTimeoutMs(30000)
|
||||||
.keepAliveInterval(100000000);
|
.keepAliveIntervalNs(100000000);
|
||||||
AeronUtil.setDaemonizedThreadFactories(context);
|
AeronUtil.setDaemonizedThreadFactories(context);
|
||||||
|
|
||||||
final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context();
|
final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context();
|
||||||
|
|
|
@ -184,7 +184,7 @@ public class ParameterServerNode implements AutoCloseable {
|
||||||
return new Aeron.Context()
|
return new Aeron.Context()
|
||||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
|
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000)
|
||||||
.errorHandler(e -> log.error(e.toString(), e));
|
.errorHandler(e -> log.error(e.toString(), e));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -118,10 +118,10 @@ public class ParameterServerNodeTest extends BaseND4JTest {
|
||||||
|
|
||||||
|
|
||||||
private static Aeron.Context getContext() {
|
private static Aeron.Context getContext() {
|
||||||
return new Aeron.Context().publicationConnectionTimeout(-1)
|
return new Aeron.Context().driverTimeoutMs(10000)
|
||||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
|
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000)
|
||||||
.errorHandler(e -> log.error(e.toString(), e));
|
.errorHandler(e -> log.error(e.toString(), e));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,6 @@ import java.util.concurrent.locks.LockSupport;
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data
|
@Data
|
||||||
@Parameters(separators = ",")
|
@Parameters(separators = ",")
|
||||||
@Slf4j
|
|
||||||
public class ParameterServerSubscriber implements AutoCloseable {
|
public class ParameterServerSubscriber implements AutoCloseable {
|
||||||
|
|
||||||
private static Logger log = LoggerFactory.getLogger(ParameterServerSubscriber.class);
|
private static Logger log = LoggerFactory.getLogger(ParameterServerSubscriber.class);
|
||||||
|
@ -255,9 +254,9 @@ public class ParameterServerSubscriber implements AutoCloseable {
|
||||||
//Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window.
|
//Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window.
|
||||||
System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength));
|
System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength));
|
||||||
final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED)
|
final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED)
|
||||||
.dirsDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false)
|
.dirDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false)
|
||||||
.ipcTermBufferLength(ipcLength).publicationTermBufferLength(ipcLength)
|
.ipcTermBufferLength(ipcLength).publicationTermBufferLength(ipcLength)
|
||||||
.maxTermBufferLength(ipcLength).conductorIdleStrategy(new BusySpinIdleStrategy())
|
.conductorIdleStrategy(new BusySpinIdleStrategy())
|
||||||
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
||||||
.senderIdleStrategy(new BusySpinIdleStrategy());
|
.senderIdleStrategy(new BusySpinIdleStrategy());
|
||||||
AeronUtil.setDaemonizedThreadFactories(mediaDriverCtx);
|
AeronUtil.setDaemonizedThreadFactories(mediaDriverCtx);
|
||||||
|
@ -380,10 +379,10 @@ public class ParameterServerSubscriber implements AutoCloseable {
|
||||||
|
|
||||||
//get a context
|
//get a context
|
||||||
public Aeron.Context getContext() {
|
public Aeron.Context getContext() {
|
||||||
Aeron.Context ctx = new Aeron.Context().publicationConnectionTimeout(-1)
|
Aeron.Context ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
|
||||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||||
.aeronDirectoryName(mediaDriverDirectoryName).keepAliveInterval(100000)
|
.aeronDirectoryName(mediaDriverDirectoryName).keepAliveIntervalNs(1000000)
|
||||||
.errorHandler(e -> log.error(e.toString(), e));
|
.errorHandler(e -> log.error(e.toString(), e));
|
||||||
AeronUtil.setDaemonizedThreadFactories(ctx);
|
AeronUtil.setDaemonizedThreadFactories(ctx);
|
||||||
return ctx;
|
return ctx;
|
||||||
|
|
|
@ -1860,13 +1860,13 @@ val scatterSub = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterSub"
|
||||||
,tensorflowOpRegistry = tensorflowOpRegistry)
|
,tensorflowOpRegistry = tensorflowOpRegistry)
|
||||||
|
|
||||||
//TODO: note: TF expects indices, we don't support them?
|
//TODO: note: TF expects indices, we don't support them?
|
||||||
val scatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterUpdate"),opName = "scatter_upd",
|
val scatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterUpdate"),opName = "scatter_update",
|
||||||
attributeMappingRules = listOf(),
|
attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("indices" to "indices"))),
|
||||||
tensorNames = mutableMapOf("input" to "ref","updates" to "updates","indices" to "indices"),tensorflowOpRegistry = tensorflowOpRegistry)
|
tensorNames = mutableMapOf("operand" to "ref","updates" to "updates"),tensorflowOpRegistry = tensorflowOpRegistry)
|
||||||
|
|
||||||
val tensorScatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("TensorScatterUpdate"),opName = "scatter_upd",
|
val tensorScatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("TensorScatterUpdate"),opName = "scatter_update",
|
||||||
attributeMappingRules = listOf(),
|
attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("indices" to "indices"))),
|
||||||
tensorNames = mutableMapOf("input" to "tensor","updates" to "updates","indices" to "indices"),tensorflowOpRegistry = tensorflowOpRegistry)
|
tensorNames = mutableMapOf("operand" to "tensor","updates" to "updates"),tensorflowOpRegistry = tensorflowOpRegistry)
|
||||||
//L2Loss
|
//L2Loss
|
||||||
val l2Loss = multipleNameMapping(inputFrameworkOpNames = listOf("L2Loss"),opName = "l2_loss",
|
val l2Loss = multipleNameMapping(inputFrameworkOpNames = listOf("L2Loss"),opName = "l2_loss",
|
||||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))),
|
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))),
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.nd4j.shade.protobuf.GeneratedMessageV3
|
||||||
import org.nd4j.shade.protobuf.ProtocolMessageEnum
|
import org.nd4j.shade.protobuf.ProtocolMessageEnum
|
||||||
import org.nd4j.shade.protobuf.TextFormat
|
import org.nd4j.shade.protobuf.TextFormat
|
||||||
import org.tensorflow.framework.*
|
import org.tensorflow.framework.*
|
||||||
|
import java.lang.Exception
|
||||||
import java.nio.charset.Charset
|
import java.nio.charset.Charset
|
||||||
|
|
||||||
class TensorflowOpDescriptorLoader: OpDescriptorLoader<OpDef> {
|
class TensorflowOpDescriptorLoader: OpDescriptorLoader<OpDef> {
|
||||||
|
@ -87,7 +88,12 @@ class TensorflowOpDescriptorLoader: OpDescriptorLoader<OpDef> {
|
||||||
val fileName = System.getProperty(tensorflowRulesetSpecifierProperty, tensorflowMappingRulSetDefaultFile)
|
val fileName = System.getProperty(tensorflowRulesetSpecifierProperty, tensorflowMappingRulSetDefaultFile)
|
||||||
val string = IOUtils.toString(ClassPathResource(fileName).inputStream, Charset.defaultCharset())
|
val string = IOUtils.toString(ClassPathResource(fileName).inputStream, Charset.defaultCharset())
|
||||||
val declarationBuilder = MapperNamespace.MappingDefinitionSet.newBuilder()
|
val declarationBuilder = MapperNamespace.MappingDefinitionSet.newBuilder()
|
||||||
|
try {
|
||||||
TextFormat.merge(string,declarationBuilder)
|
TextFormat.merge(string,declarationBuilder)
|
||||||
|
} catch(e: Exception) {
|
||||||
|
println("Unable to parse mapper definitions for file file $fileName")
|
||||||
|
}
|
||||||
|
|
||||||
return declarationBuilder.build()
|
return declarationBuilder.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15384,30 +15384,35 @@ mappings {
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
opName: "scatter_upd"
|
opName: "scatter_update"
|
||||||
inputFrameworkOpName: "ScatterUpdate"
|
inputFrameworkOpName: "ScatterUpdate"
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarraymapping"
|
ruleName: "ndarraymapping"
|
||||||
functionName: "ndarraymapping"
|
functionName: "ndarraymapping"
|
||||||
inputTensorName: "ref"
|
inputTensorName: "ref"
|
||||||
inputTensorName: "updates"
|
inputTensorName: "updates"
|
||||||
inputTensorName: "indices"
|
outputTensorName: "operand"
|
||||||
outputTensorName: "input"
|
|
||||||
outputTensorName: "updates"
|
outputTensorName: "updates"
|
||||||
outputTensorName: "indices"
|
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "input"
|
key: "operand"
|
||||||
value: "ref"
|
value: "ref"
|
||||||
}
|
}
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "updates"
|
key: "updates"
|
||||||
value: "updates"
|
value: "updates"
|
||||||
}
|
}
|
||||||
|
ruleType: "tensor"
|
||||||
|
inputFrameworkOpName: "ScatterUpdate"
|
||||||
|
}
|
||||||
|
rule {
|
||||||
|
ruleName: "ndarraytointattributevalue"
|
||||||
|
functionName: "ndarraytointattributevalue"
|
||||||
|
outputIntName: "indices"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "indices"
|
key: "indices"
|
||||||
value: "indices"
|
value: "indices"
|
||||||
}
|
}
|
||||||
ruleType: "tensor"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "ScatterUpdate"
|
inputFrameworkOpName: "ScatterUpdate"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15527,30 +15532,35 @@ mappings {
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
opName: "scatter_upd"
|
opName: "scatter_update"
|
||||||
inputFrameworkOpName: "TensorScatterUpdate"
|
inputFrameworkOpName: "TensorScatterUpdate"
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarraymapping"
|
ruleName: "ndarraymapping"
|
||||||
functionName: "ndarraymapping"
|
functionName: "ndarraymapping"
|
||||||
inputTensorName: "tensor"
|
inputTensorName: "tensor"
|
||||||
inputTensorName: "updates"
|
inputTensorName: "updates"
|
||||||
inputTensorName: "indices"
|
outputTensorName: "operand"
|
||||||
outputTensorName: "input"
|
|
||||||
outputTensorName: "updates"
|
outputTensorName: "updates"
|
||||||
outputTensorName: "indices"
|
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "input"
|
key: "operand"
|
||||||
value: "tensor"
|
value: "tensor"
|
||||||
}
|
}
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "updates"
|
key: "updates"
|
||||||
value: "updates"
|
value: "updates"
|
||||||
}
|
}
|
||||||
|
ruleType: "tensor"
|
||||||
|
inputFrameworkOpName: "TensorScatterUpdate"
|
||||||
|
}
|
||||||
|
rule {
|
||||||
|
ruleName: "ndarraytointattributevalue"
|
||||||
|
functionName: "ndarraytointattributevalue"
|
||||||
|
outputIntName: "indices"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "indices"
|
key: "indices"
|
||||||
value: "indices"
|
value: "indices"
|
||||||
}
|
}
|
||||||
ruleType: "tensor"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "TensorScatterUpdate"
|
inputFrameworkOpName: "TensorScatterUpdate"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15384,30 +15384,35 @@ mappings {
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
opName: "scatter_upd"
|
opName: "scatter_update"
|
||||||
inputFrameworkOpName: "ScatterUpdate"
|
inputFrameworkOpName: "ScatterUpdate"
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarraymapping"
|
ruleName: "ndarraymapping"
|
||||||
functionName: "ndarraymapping"
|
functionName: "ndarraymapping"
|
||||||
inputTensorName: "ref"
|
inputTensorName: "ref"
|
||||||
inputTensorName: "updates"
|
inputTensorName: "updates"
|
||||||
inputTensorName: "indices"
|
outputTensorName: "operand"
|
||||||
outputTensorName: "input"
|
|
||||||
outputTensorName: "updates"
|
outputTensorName: "updates"
|
||||||
outputTensorName: "indices"
|
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "input"
|
key: "operand"
|
||||||
value: "ref"
|
value: "ref"
|
||||||
}
|
}
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "updates"
|
key: "updates"
|
||||||
value: "updates"
|
value: "updates"
|
||||||
}
|
}
|
||||||
|
ruleType: "tensor"
|
||||||
|
inputFrameworkOpName: "ScatterUpdate"
|
||||||
|
}
|
||||||
|
rule {
|
||||||
|
ruleName: "ndarraytointattributevalue"
|
||||||
|
functionName: "ndarraytointattributevalue"
|
||||||
|
outputIntName: "indices"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "indices"
|
key: "indices"
|
||||||
value: "indices"
|
value: "indices"
|
||||||
}
|
}
|
||||||
ruleType: "tensor"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "ScatterUpdate"
|
inputFrameworkOpName: "ScatterUpdate"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15527,30 +15532,35 @@ mappings {
|
||||||
}
|
}
|
||||||
mappings {
|
mappings {
|
||||||
frameworkName: "tensorflow"
|
frameworkName: "tensorflow"
|
||||||
opName: "scatter_upd"
|
opName: "scatter_update"
|
||||||
inputFrameworkOpName: "TensorScatterUpdate"
|
inputFrameworkOpName: "TensorScatterUpdate"
|
||||||
rule {
|
rule {
|
||||||
ruleName: "ndarraymapping"
|
ruleName: "ndarraymapping"
|
||||||
functionName: "ndarraymapping"
|
functionName: "ndarraymapping"
|
||||||
inputTensorName: "tensor"
|
inputTensorName: "tensor"
|
||||||
inputTensorName: "updates"
|
inputTensorName: "updates"
|
||||||
inputTensorName: "indices"
|
outputTensorName: "operand"
|
||||||
outputTensorName: "input"
|
|
||||||
outputTensorName: "updates"
|
outputTensorName: "updates"
|
||||||
outputTensorName: "indices"
|
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "input"
|
key: "operand"
|
||||||
value: "tensor"
|
value: "tensor"
|
||||||
}
|
}
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "updates"
|
key: "updates"
|
||||||
value: "updates"
|
value: "updates"
|
||||||
}
|
}
|
||||||
|
ruleType: "tensor"
|
||||||
|
inputFrameworkOpName: "TensorScatterUpdate"
|
||||||
|
}
|
||||||
|
rule {
|
||||||
|
ruleName: "ndarraytointattributevalue"
|
||||||
|
functionName: "ndarraytointattributevalue"
|
||||||
|
outputIntName: "indices"
|
||||||
inputToOutput {
|
inputToOutput {
|
||||||
key: "indices"
|
key: "indices"
|
||||||
value: "indices"
|
value: "indices"
|
||||||
}
|
}
|
||||||
ruleType: "tensor"
|
ruleType: "attribute"
|
||||||
inputFrameworkOpName: "TensorScatterUpdate"
|
inputFrameworkOpName: "TensorScatterUpdate"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue