Fix compilation isssues with nd4j-parameter-server
parent
cbf1fad16c
commit
52f65d8511
|
@ -58,6 +58,7 @@ fi
|
|||
|
||||
unameOut="$(uname)"
|
||||
echo "$OSTYPE"
|
||||
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
|
||||
|
||||
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests.exe
|
||||
# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
|
||||
[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
|
||||
#[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/
|
||||
|
|
|
@ -1,285 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOp;
|
||||
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.OpContext;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class ScatterUpdate extends DynamicCustomOp {
|
||||
protected CustomOp op;
|
||||
|
||||
// update operation: 0 - add; 1 - sub; 2 - mul; 3 - div; 4 - rsub; 5 - rdiv; 6 - assign
|
||||
public enum UpdateOp {
|
||||
ADD,
|
||||
SUBTRACT,
|
||||
MULTIPLY,
|
||||
DIVIDE,
|
||||
RSUBTRACT,
|
||||
RDIVIDE,
|
||||
ASSIGN,
|
||||
}
|
||||
|
||||
public ScatterUpdate(){ }
|
||||
|
||||
public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) {
|
||||
this(original, updates, null, indices, dimension, op);
|
||||
}
|
||||
|
||||
public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, INDArray result, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) {
|
||||
|
||||
List<Integer> iargs = new ArrayList<>();
|
||||
iargs.add(op.ordinal());
|
||||
iargs.add(dimension.length);
|
||||
for (val v: dimension)
|
||||
iargs.add(v);
|
||||
|
||||
iargs.add(indices.length);
|
||||
for (val v: indices)
|
||||
iargs.add(v);
|
||||
|
||||
if (updates.tensorAlongDimension(0, dimension).length() != original.tensorAlongDimension(0, dimension).length())
|
||||
throw new ND4JIllegalStateException("ScatterUpdate requires equal shaped tensors for operation along given dimension(s)");
|
||||
|
||||
long numTensors = original.tensorsAlongDimension(dimension);
|
||||
for (val idx: indices)
|
||||
if (idx >= numTensors)
|
||||
throw new ND4JIllegalStateException("Can't update index higher then num tensors");
|
||||
|
||||
this.op = DynamicCustomOp.builder("scatter_update")
|
||||
.addInputs(original, updates)
|
||||
.callInplace(true)
|
||||
.addIntegerArguments(iargs)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||
DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) op;
|
||||
return dynamicCustomOp.calculateOutputDataTypes(dataTypes);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns op opName as string
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public String opName() {
|
||||
return "scatter_update";
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns LongHash of the opName()
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public long opHash() {
|
||||
return op.opHash();
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns true if op is supposed to be executed inplace
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public boolean isInplaceCall() {
|
||||
return op.isInplaceCall();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<INDArray> outputArguments() {
|
||||
return op.outputArguments();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<INDArray> inputArguments() {
|
||||
return op.inputArguments();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long[] iArgs() {
|
||||
return op.iArgs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] tArgs() {
|
||||
return op.tArgs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean[] bArgs() {
|
||||
return op.bArgs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addIArgument(int... arg) {
|
||||
op.addIArgument(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addIArgument(long... arg) {
|
||||
op.addIArgument(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addBArgument(boolean... arg) {
|
||||
op.addBArgument(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeIArgument(Integer arg) {
|
||||
op.removeIArgument(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean getBArgument(int index) {
|
||||
return op.getBArgument(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getIArgument(int index) {
|
||||
return op.getIArgument(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numIArguments() {
|
||||
return op.numIArguments();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addTArgument(double... arg) {
|
||||
op.addTArgument(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeTArgument(Double arg) {
|
||||
op.removeTArgument(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double getTArgument(int index) {
|
||||
return op.getTArgument(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numTArguments() {
|
||||
return op.numTArguments();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numBArguments() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addInputArgument(INDArray... arg) {
|
||||
op.addInputArgument(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeInputArgument(INDArray arg) {
|
||||
op.removeInputArgument(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getInputArgument(int index) {
|
||||
return op.getInputArgument(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numInputArguments() {
|
||||
return op.numInputArguments();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addOutputArgument(INDArray... arg) {
|
||||
op.addOutputArgument(arg);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void removeOutputArgument(INDArray arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getOutputArgument(int index) {
|
||||
return op.getOutputArgument(index);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numOutputArguments() {
|
||||
return op.numOutputArguments();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<LongShapeDescriptor> calculateOutputShape() {
|
||||
return Nd4j.getExecutioner().calculateOutputShape(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
|
||||
return Nd4j.getExecutioner().calculateOutputShape(this, opContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CustomOpDescriptor getDescriptor() {
|
||||
return op.getDescriptor();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void assertValidForExecution() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType[] dArgs() {
|
||||
return new DataType[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addDArgument(DataType... arg) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numDArguments() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clearArrays() {
|
||||
op.clearArrays();
|
||||
}
|
||||
}
|
|
@ -79,7 +79,7 @@ public class ScatterNdSub extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s",
|
||||
inputDataTypes.get(0), inputDataTypes.get(2));
|
||||
|
|
|
@ -79,7 +79,7 @@ public class ScatterNdUpdate extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s",
|
||||
inputDataTypes.get(0), inputDataTypes.get(2));
|
||||
|
|
|
@ -362,10 +362,10 @@
|
|||
<configuration>
|
||||
<environmentVariables>
|
||||
<LD_LIBRARY_PATH>
|
||||
${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/:${nd4j.native.basedir}/target/classes
|
||||
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes
|
||||
</LD_LIBRARY_PATH>
|
||||
<PATH>
|
||||
${env.PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/:${nd4j.native.basedir}/target/classes
|
||||
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes
|
||||
</PATH>
|
||||
</environmentVariables>
|
||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||
|
@ -391,7 +391,7 @@
|
|||
|
||||
For testing large zoo models, this may not be enough (so comment it out).
|
||||
-->
|
||||
<argLine>-Ddtype=float -Xmx8g</argLine>
|
||||
<argLine>-Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
@ -444,8 +444,11 @@
|
|||
<configuration>
|
||||
<environmentVariables>
|
||||
<LD_LIBRARY_PATH>
|
||||
${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/
|
||||
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes
|
||||
</LD_LIBRARY_PATH>
|
||||
<PATH>
|
||||
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes
|
||||
</PATH>
|
||||
</environmentVariables>
|
||||
<testSourceDirectory>src/test/java</testSourceDirectory>
|
||||
<includes>
|
||||
|
@ -468,7 +471,7 @@
|
|||
Maximum heap size was set to 6g, as a minimum required value for tests run.
|
||||
Depending on a build machine, default value is not always enough.
|
||||
-->
|
||||
<argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx6g</argLine>
|
||||
<argLine> -Dfile.encoding=UTF-8 -Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
|
|
|
@ -76,20 +76,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
"layers_dropout/rank3_d05_train_mask1",
|
||||
"layers_dropout/rank2_d09_train",
|
||||
"layers_dropout/rank2_d05_train",*/
|
||||
"reductions/scatter_update_vector",
|
||||
"reductions/scatter_update_scalar",
|
||||
"random_poisson/rank1_float16",
|
||||
"random_poisson/rank1_float16",
|
||||
"matrix_band_part/float64",
|
||||
"emptyArrayTests/scatter_update/rank1_emptyIndices_emptyUpdates",
|
||||
"bincount/rank0_weights",
|
||||
"bincount/rank2_weights",
|
||||
"scatter_nd_add/locking/rank1shape_1indices",
|
||||
"scatter_nd_add/locking/rank2shape_1indices",
|
||||
"scatter_nd_add/locking/rank3shape_1indices",
|
||||
"scatter_nd_sub/locking/rank1shape_1indices",
|
||||
"scatter_nd_sub/locking/rank2shape_1indices",
|
||||
"scatter_nd_sub/locking/rank3shape_1indices"
|
||||
|
||||
|
||||
);
|
||||
|
||||
|
@ -97,10 +84,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
|
||||
// Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance
|
||||
//Invalid test cases. Verified by running graph against actual TF.
|
||||
"compare_and_bitpack/.*",
|
||||
"scatter_nd_sub/locking/rank1shape_1indices",
|
||||
"reductions/scatter_update_vector",
|
||||
"reductions/scatter_update_scalar",
|
||||
"emptyArrayTests/scatter_update/rank1_emptyIndices_emptyUpdates",
|
||||
"bincount/rank2_weights",
|
||||
"slogdet/.*",
|
||||
//IGNORE THIS: the TF results from comparing against an actual TF java run compared to this seem to be different.
|
||||
"fused_batch_norm/float16_nhwc",
|
||||
//Don't bother to test RNG. We can test subsets of ops with dropout to make sure they are consistent
|
||||
//These tests have random uniform and other RNG in them that don't need to be perfectly compatible to be acceptable.
|
||||
//We need different test cases here.
|
||||
|
|
|
@ -56,7 +56,6 @@ import org.nd4j.linalg.api.ops.custom.RgbToHsv;
|
|||
import org.nd4j.linalg.api.ops.custom.RgbToYiq;
|
||||
import org.nd4j.linalg.api.ops.custom.RgbToYuv;
|
||||
import org.nd4j.linalg.api.ops.custom.Roll;
|
||||
import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
|
||||
import org.nd4j.linalg.api.ops.custom.ToggleBits;
|
||||
import org.nd4j.linalg.api.ops.custom.TriangularSolve;
|
||||
import org.nd4j.linalg.api.ops.custom.YiqToRgb;
|
||||
|
@ -64,12 +63,11 @@ import org.nd4j.linalg.api.ops.custom.YuvToRgb;
|
|||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
|
||||
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
||||
import org.nd4j.linalg.api.ops.impl.image.ResizeArea;
|
||||
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Linspace;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
||||
|
@ -390,51 +388,6 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testScatterUpdate1() {
|
||||
val matrix = Nd4j.create(5, 5);
|
||||
val updates = Nd4j.create(2, 5).assign(1.0);
|
||||
int[] dims = new int[]{1};
|
||||
int[] indices = new int[]{1, 3};
|
||||
|
||||
val exp0 = Nd4j.create(5).assign(0);
|
||||
val exp1 = Nd4j.create(5).assign(1);
|
||||
|
||||
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
|
||||
assertEquals(exp0, matrix.getRow(0));
|
||||
assertEquals(exp1, matrix.getRow(1));
|
||||
assertEquals(exp0, matrix.getRow(2));
|
||||
assertEquals(exp1, matrix.getRow(3));
|
||||
assertEquals(exp0, matrix.getRow(4));
|
||||
}
|
||||
|
||||
@Test(expected = ND4JIllegalStateException.class)
|
||||
public void testScatterUpdate2() {
|
||||
val matrix = Nd4j.create(5, 5);
|
||||
val updates = Nd4j.create(2, 5).assign(1.0);
|
||||
int[] dims = new int[]{0};
|
||||
int[] indices = new int[]{0, 1};
|
||||
|
||||
val exp0 = Nd4j.create(1, 5).assign(0);
|
||||
val exp1 = Nd4j.create(1, 5).assign(1);
|
||||
|
||||
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
|
||||
}
|
||||
|
||||
@Test(expected = ND4JIllegalStateException.class)
|
||||
public void testScatterUpdate3() {
|
||||
val matrix = Nd4j.create(5, 5);
|
||||
val updates = Nd4j.create(2, 5).assign(1.0);
|
||||
int[] dims = new int[]{1};
|
||||
int[] indices = new int[]{0, 6};
|
||||
|
||||
val exp0 = Nd4j.create(1, 5).assign(0);
|
||||
val exp1 = Nd4j.create(1, 5).assign(1);
|
||||
|
||||
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOpStatus1() {
|
||||
|
@ -1005,17 +958,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
assertEquals(expected, output);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCompareAndBitpack() {
|
||||
INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f,
|
||||
-2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}).reshape( 2,3,4);
|
||||
INDArray out = Nd4j.createUninitialized(DataType.UBYTE, 2,3,4);
|
||||
INDArray expected = Nd4j.createFromArray(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}).
|
||||
reshape(2,3,4);
|
||||
|
||||
Nd4j.exec(new CompareAndBitpack(in ,2.0, out));
|
||||
assertArrayEquals(new long[]{2,3,4}, out.shape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDivideNoNan() {
|
||||
|
|
|
@ -730,20 +730,7 @@ public class NDBaseTest extends BaseNd4jTest {
|
|||
assertEquals(y_exp, y);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScatterUpdate() {
|
||||
NDBase base = new NDBase();
|
||||
|
||||
//from testScatterOpGradients.
|
||||
INDArray x = Nd4j.ones(DataType.DOUBLE, 10, 10);
|
||||
INDArray indices = Nd4j.create(new double[]{3, 4, 5, 8, 9}).castTo(DataType.INT32);
|
||||
INDArray updates = Nd4j.ones(DataType.DOUBLE, 5, 10).add(1.0);
|
||||
INDArray y = base.scatterUpdate(x,indices, updates);
|
||||
|
||||
y = y.getColumn(0);
|
||||
INDArray y_exp = Nd4j.createFromArray(1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0);
|
||||
assertEquals(y_exp, y);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSegmentMax() {
|
||||
|
|
|
@ -52,7 +52,7 @@ public class RemoteParameterServerClientTests extends BaseND4JTest {
|
|||
@Before
|
||||
public void before() throws Exception {
|
||||
final MediaDriver.Context ctx =
|
||||
new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirsDeleteOnStart(true)
|
||||
new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true)
|
||||
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
|
||||
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
||||
.senderIdleStrategy(new BusySpinIdleStrategy());
|
||||
|
@ -150,10 +150,10 @@ public class RemoteParameterServerClientTests extends BaseND4JTest {
|
|||
|
||||
private Aeron.Context getContext() {
|
||||
if (ctx == null)
|
||||
ctx = new Aeron.Context().publicationConnectionTimeout(-1)
|
||||
ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
|
||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(1000)
|
||||
.errorHandler(e -> log.error(e.toString(), e));
|
||||
return ctx;
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ public class ParameterServerClientPartialTest extends BaseND4JTest {
|
|||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
final MediaDriver.Context ctx =
|
||||
new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true)
|
||||
new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirDeleteOnStart(true)
|
||||
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
|
||||
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
||||
.senderIdleStrategy(new BusySpinIdleStrategy());
|
||||
|
@ -136,10 +136,10 @@ public class ParameterServerClientPartialTest extends BaseND4JTest {
|
|||
|
||||
private static Aeron.Context getContext() {
|
||||
if (ctx == null)
|
||||
ctx = new Aeron.Context().publicationConnectionTimeout(-1)
|
||||
ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
|
||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(10000)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(10000)
|
||||
.errorHandler(e -> log.error(e.toString(), e));
|
||||
return ctx;
|
||||
}
|
||||
|
|
|
@ -119,10 +119,10 @@ public class ParameterServerClientTest extends BaseND4JTest {
|
|||
|
||||
|
||||
private static Aeron.Context getContext() {
|
||||
return new Aeron.Context().publicationConnectionTimeout(-1)
|
||||
return new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
|
||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000)
|
||||
.errorHandler(e -> log.error(e.toString(), e));
|
||||
}
|
||||
|
||||
|
|
|
@ -424,7 +424,6 @@ public abstract class BaseTransport implements Transport {
|
|||
CloseHelper.quietClose(subscriptionForShards);
|
||||
CloseHelper.quietClose(subscriptionForClients);
|
||||
CloseHelper.quietClose(aeron);
|
||||
CloseHelper.quietClose(context);
|
||||
CloseHelper.quietClose(driver);
|
||||
}
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ public class RoutedTransport extends BaseTransport {
|
|||
|
||||
|
||||
context = new Aeron.Context().driverTimeoutMs(30000)
|
||||
.keepAliveInterval(100000000);
|
||||
.keepAliveIntervalNs(100000000);
|
||||
AeronUtil.setDaemonizedThreadFactories(context);
|
||||
|
||||
MediaDriver.Context ctx = new MediaDriver.Context();
|
||||
|
@ -120,7 +120,6 @@ public class RoutedTransport extends BaseTransport {
|
|||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
CloseHelper.quietClose(aeron);
|
||||
CloseHelper.quietClose(driver);
|
||||
CloseHelper.quietClose(context);
|
||||
CloseHelper.quietClose(subscriptionForClients);
|
||||
}));
|
||||
|
||||
|
|
|
@ -131,7 +131,7 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable {
|
|||
splitter = MessageSplitter.getInstance();
|
||||
|
||||
context = new Aeron.Context().driverTimeoutMs(30000)
|
||||
.keepAliveInterval(100000000);
|
||||
.keepAliveIntervalNs(100000000);
|
||||
AeronUtil.setDaemonizedThreadFactories(context);
|
||||
|
||||
final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context();
|
||||
|
|
|
@ -184,7 +184,7 @@ public class ParameterServerNode implements AutoCloseable {
|
|||
return new Aeron.Context()
|
||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000)
|
||||
.errorHandler(e -> log.error(e.toString(), e));
|
||||
}
|
||||
|
||||
|
|
|
@ -118,10 +118,10 @@ public class ParameterServerNodeTest extends BaseND4JTest {
|
|||
|
||||
|
||||
private static Aeron.Context getContext() {
|
||||
return new Aeron.Context().publicationConnectionTimeout(-1)
|
||||
return new Aeron.Context().driverTimeoutMs(10000)
|
||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000)
|
||||
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000)
|
||||
.errorHandler(e -> log.error(e.toString(), e));
|
||||
}
|
||||
|
||||
|
|
|
@ -74,7 +74,6 @@ import java.util.concurrent.locks.LockSupport;
|
|||
@NoArgsConstructor
|
||||
@Data
|
||||
@Parameters(separators = ",")
|
||||
@Slf4j
|
||||
public class ParameterServerSubscriber implements AutoCloseable {
|
||||
|
||||
private static Logger log = LoggerFactory.getLogger(ParameterServerSubscriber.class);
|
||||
|
@ -255,9 +254,9 @@ public class ParameterServerSubscriber implements AutoCloseable {
|
|||
//Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window.
|
||||
System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength));
|
||||
final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED)
|
||||
.dirsDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false)
|
||||
.dirDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false)
|
||||
.ipcTermBufferLength(ipcLength).publicationTermBufferLength(ipcLength)
|
||||
.maxTermBufferLength(ipcLength).conductorIdleStrategy(new BusySpinIdleStrategy())
|
||||
.conductorIdleStrategy(new BusySpinIdleStrategy())
|
||||
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
||||
.senderIdleStrategy(new BusySpinIdleStrategy());
|
||||
AeronUtil.setDaemonizedThreadFactories(mediaDriverCtx);
|
||||
|
@ -380,10 +379,10 @@ public class ParameterServerSubscriber implements AutoCloseable {
|
|||
|
||||
//get a context
|
||||
public Aeron.Context getContext() {
|
||||
Aeron.Context ctx = new Aeron.Context().publicationConnectionTimeout(-1)
|
||||
Aeron.Context ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
|
||||
.availableImageHandler(AeronUtil::printAvailableImage)
|
||||
.unavailableImageHandler(AeronUtil::printUnavailableImage)
|
||||
.aeronDirectoryName(mediaDriverDirectoryName).keepAliveInterval(100000)
|
||||
.aeronDirectoryName(mediaDriverDirectoryName).keepAliveIntervalNs(1000000)
|
||||
.errorHandler(e -> log.error(e.toString(), e));
|
||||
AeronUtil.setDaemonizedThreadFactories(ctx);
|
||||
return ctx;
|
||||
|
|
|
@ -1860,13 +1860,13 @@ val scatterSub = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterSub"
|
|||
,tensorflowOpRegistry = tensorflowOpRegistry)
|
||||
|
||||
//TODO: note: TF expects indices, we don't support them?
|
||||
val scatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterUpdate"),opName = "scatter_upd",
|
||||
attributeMappingRules = listOf(),
|
||||
tensorNames = mutableMapOf("input" to "ref","updates" to "updates","indices" to "indices"),tensorflowOpRegistry = tensorflowOpRegistry)
|
||||
val scatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterUpdate"),opName = "scatter_update",
|
||||
attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("indices" to "indices"))),
|
||||
tensorNames = mutableMapOf("operand" to "ref","updates" to "updates"),tensorflowOpRegistry = tensorflowOpRegistry)
|
||||
|
||||
val tensorScatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("TensorScatterUpdate"),opName = "scatter_upd",
|
||||
attributeMappingRules = listOf(),
|
||||
tensorNames = mutableMapOf("input" to "tensor","updates" to "updates","indices" to "indices"),tensorflowOpRegistry = tensorflowOpRegistry)
|
||||
val tensorScatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("TensorScatterUpdate"),opName = "scatter_update",
|
||||
attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("indices" to "indices"))),
|
||||
tensorNames = mutableMapOf("operand" to "tensor","updates" to "updates"),tensorflowOpRegistry = tensorflowOpRegistry)
|
||||
//L2Loss
|
||||
val l2Loss = multipleNameMapping(inputFrameworkOpNames = listOf("L2Loss"),opName = "l2_loss",
|
||||
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))),
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.nd4j.shade.protobuf.GeneratedMessageV3
|
|||
import org.nd4j.shade.protobuf.ProtocolMessageEnum
|
||||
import org.nd4j.shade.protobuf.TextFormat
|
||||
import org.tensorflow.framework.*
|
||||
import java.lang.Exception
|
||||
import java.nio.charset.Charset
|
||||
|
||||
class TensorflowOpDescriptorLoader: OpDescriptorLoader<OpDef> {
|
||||
|
@ -87,7 +88,12 @@ class TensorflowOpDescriptorLoader: OpDescriptorLoader<OpDef> {
|
|||
val fileName = System.getProperty(tensorflowRulesetSpecifierProperty, tensorflowMappingRulSetDefaultFile)
|
||||
val string = IOUtils.toString(ClassPathResource(fileName).inputStream, Charset.defaultCharset())
|
||||
val declarationBuilder = MapperNamespace.MappingDefinitionSet.newBuilder()
|
||||
TextFormat.merge(string,declarationBuilder)
|
||||
try {
|
||||
TextFormat.merge(string,declarationBuilder)
|
||||
} catch(e: Exception) {
|
||||
println("Unable to parse mapper definitions for file file $fileName")
|
||||
}
|
||||
|
||||
return declarationBuilder.build()
|
||||
}
|
||||
|
||||
|
|
|
@ -15384,30 +15384,35 @@ mappings {
|
|||
}
|
||||
mappings {
|
||||
frameworkName: "tensorflow"
|
||||
opName: "scatter_upd"
|
||||
opName: "scatter_update"
|
||||
inputFrameworkOpName: "ScatterUpdate"
|
||||
rule {
|
||||
ruleName: "ndarraymapping"
|
||||
functionName: "ndarraymapping"
|
||||
inputTensorName: "ref"
|
||||
inputTensorName: "updates"
|
||||
inputTensorName: "indices"
|
||||
outputTensorName: "input"
|
||||
outputTensorName: "operand"
|
||||
outputTensorName: "updates"
|
||||
outputTensorName: "indices"
|
||||
inputToOutput {
|
||||
key: "input"
|
||||
key: "operand"
|
||||
value: "ref"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "updates"
|
||||
value: "updates"
|
||||
}
|
||||
ruleType: "tensor"
|
||||
inputFrameworkOpName: "ScatterUpdate"
|
||||
}
|
||||
rule {
|
||||
ruleName: "ndarraytointattributevalue"
|
||||
functionName: "ndarraytointattributevalue"
|
||||
outputIntName: "indices"
|
||||
inputToOutput {
|
||||
key: "indices"
|
||||
value: "indices"
|
||||
}
|
||||
ruleType: "tensor"
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "ScatterUpdate"
|
||||
}
|
||||
}
|
||||
|
@ -15527,30 +15532,35 @@ mappings {
|
|||
}
|
||||
mappings {
|
||||
frameworkName: "tensorflow"
|
||||
opName: "scatter_upd"
|
||||
opName: "scatter_update"
|
||||
inputFrameworkOpName: "TensorScatterUpdate"
|
||||
rule {
|
||||
ruleName: "ndarraymapping"
|
||||
functionName: "ndarraymapping"
|
||||
inputTensorName: "tensor"
|
||||
inputTensorName: "updates"
|
||||
inputTensorName: "indices"
|
||||
outputTensorName: "input"
|
||||
outputTensorName: "operand"
|
||||
outputTensorName: "updates"
|
||||
outputTensorName: "indices"
|
||||
inputToOutput {
|
||||
key: "input"
|
||||
key: "operand"
|
||||
value: "tensor"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "updates"
|
||||
value: "updates"
|
||||
}
|
||||
ruleType: "tensor"
|
||||
inputFrameworkOpName: "TensorScatterUpdate"
|
||||
}
|
||||
rule {
|
||||
ruleName: "ndarraytointattributevalue"
|
||||
functionName: "ndarraytointattributevalue"
|
||||
outputIntName: "indices"
|
||||
inputToOutput {
|
||||
key: "indices"
|
||||
value: "indices"
|
||||
}
|
||||
ruleType: "tensor"
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "TensorScatterUpdate"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15384,30 +15384,35 @@ mappings {
|
|||
}
|
||||
mappings {
|
||||
frameworkName: "tensorflow"
|
||||
opName: "scatter_upd"
|
||||
opName: "scatter_update"
|
||||
inputFrameworkOpName: "ScatterUpdate"
|
||||
rule {
|
||||
ruleName: "ndarraymapping"
|
||||
functionName: "ndarraymapping"
|
||||
inputTensorName: "ref"
|
||||
inputTensorName: "updates"
|
||||
inputTensorName: "indices"
|
||||
outputTensorName: "input"
|
||||
outputTensorName: "operand"
|
||||
outputTensorName: "updates"
|
||||
outputTensorName: "indices"
|
||||
inputToOutput {
|
||||
key: "input"
|
||||
key: "operand"
|
||||
value: "ref"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "updates"
|
||||
value: "updates"
|
||||
}
|
||||
ruleType: "tensor"
|
||||
inputFrameworkOpName: "ScatterUpdate"
|
||||
}
|
||||
rule {
|
||||
ruleName: "ndarraytointattributevalue"
|
||||
functionName: "ndarraytointattributevalue"
|
||||
outputIntName: "indices"
|
||||
inputToOutput {
|
||||
key: "indices"
|
||||
value: "indices"
|
||||
}
|
||||
ruleType: "tensor"
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "ScatterUpdate"
|
||||
}
|
||||
}
|
||||
|
@ -15527,30 +15532,35 @@ mappings {
|
|||
}
|
||||
mappings {
|
||||
frameworkName: "tensorflow"
|
||||
opName: "scatter_upd"
|
||||
opName: "scatter_update"
|
||||
inputFrameworkOpName: "TensorScatterUpdate"
|
||||
rule {
|
||||
ruleName: "ndarraymapping"
|
||||
functionName: "ndarraymapping"
|
||||
inputTensorName: "tensor"
|
||||
inputTensorName: "updates"
|
||||
inputTensorName: "indices"
|
||||
outputTensorName: "input"
|
||||
outputTensorName: "operand"
|
||||
outputTensorName: "updates"
|
||||
outputTensorName: "indices"
|
||||
inputToOutput {
|
||||
key: "input"
|
||||
key: "operand"
|
||||
value: "tensor"
|
||||
}
|
||||
inputToOutput {
|
||||
key: "updates"
|
||||
value: "updates"
|
||||
}
|
||||
ruleType: "tensor"
|
||||
inputFrameworkOpName: "TensorScatterUpdate"
|
||||
}
|
||||
rule {
|
||||
ruleName: "ndarraytointattributevalue"
|
||||
functionName: "ndarraytointattributevalue"
|
||||
outputIntName: "indices"
|
||||
inputToOutput {
|
||||
key: "indices"
|
||||
value: "indices"
|
||||
}
|
||||
ruleType: "tensor"
|
||||
ruleType: "attribute"
|
||||
inputFrameworkOpName: "TensorScatterUpdate"
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue