Merge pull request #9213 from eclipse/ag_param_server_fixes

Fix compilation isssues with nd4j-parameter-server
master
Adam Gibson 2021-03-07 19:32:35 +09:00 committed by GitHub
commit ab87cae409
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 92 additions and 431 deletions

View File

@ -58,6 +58,7 @@ fi
unameOut="$(uname)" unameOut="$(uname)"
echo "$OSTYPE" echo "$OSTYPE"
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests ../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion) # Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion)
[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/ [ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/

View File

@ -1,285 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.api.ops.custom;
import lombok.NonNull;
import lombok.val;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
public class ScatterUpdate extends DynamicCustomOp {
protected CustomOp op;
// update operation: 0 - add; 1 - sub; 2 - mul; 3 - div; 4 - rsub; 5 - rdiv; 6 - assign
public enum UpdateOp {
ADD,
SUBTRACT,
MULTIPLY,
DIVIDE,
RSUBTRACT,
RDIVIDE,
ASSIGN,
}
public ScatterUpdate(){ }
public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) {
this(original, updates, null, indices, dimension, op);
}
public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, INDArray result, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) {
List<Integer> iargs = new ArrayList<>();
iargs.add(op.ordinal());
iargs.add(dimension.length);
for (val v: dimension)
iargs.add(v);
iargs.add(indices.length);
for (val v: indices)
iargs.add(v);
if (updates.tensorAlongDimension(0, dimension).length() != original.tensorAlongDimension(0, dimension).length())
throw new ND4JIllegalStateException("ScatterUpdate requires equal shaped tensors for operation along given dimension(s)");
long numTensors = original.tensorsAlongDimension(dimension);
for (val idx: indices)
if (idx >= numTensors)
throw new ND4JIllegalStateException("Can't update index higher then num tensors");
this.op = DynamicCustomOp.builder("scatter_update")
.addInputs(original, updates)
.callInplace(true)
.addIntegerArguments(iargs)
.build();
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) op;
return dynamicCustomOp.calculateOutputDataTypes(dataTypes);
}
/**
* This method returns op opName as string
*
* @return
*/
@Override
public String opName() {
return "scatter_update";
}
/**
* This method returns LongHash of the opName()
*
* @return
*/
@Override
public long opHash() {
return op.opHash();
}
/**
* This method returns true if op is supposed to be executed inplace
*
* @return
*/
@Override
public boolean isInplaceCall() {
return op.isInplaceCall();
}
@Override
public List<INDArray> outputArguments() {
return op.outputArguments();
}
@Override
public List<INDArray> inputArguments() {
return op.inputArguments();
}
@Override
public long[] iArgs() {
return op.iArgs();
}
@Override
public double[] tArgs() {
return op.tArgs();
}
@Override
public boolean[] bArgs() {
return op.bArgs();
}
@Override
public void addIArgument(int... arg) {
op.addIArgument(arg);
}
@Override
public void addIArgument(long... arg) {
op.addIArgument(arg);
}
@Override
public void addBArgument(boolean... arg) {
op.addBArgument(arg);
}
@Override
public void removeIArgument(Integer arg) {
op.removeIArgument(arg);
}
@Override
public Boolean getBArgument(int index) {
return op.getBArgument(index);
}
@Override
public Long getIArgument(int index) {
return op.getIArgument(index);
}
@Override
public int numIArguments() {
return op.numIArguments();
}
@Override
public void addTArgument(double... arg) {
op.addTArgument(arg);
}
@Override
public void removeTArgument(Double arg) {
op.removeTArgument(arg);
}
@Override
public Double getTArgument(int index) {
return op.getTArgument(index);
}
@Override
public int numTArguments() {
return op.numTArguments();
}
@Override
public int numBArguments() {
return 0;
}
@Override
public void addInputArgument(INDArray... arg) {
op.addInputArgument(arg);
}
@Override
public void removeInputArgument(INDArray arg) {
op.removeInputArgument(arg);
}
@Override
public INDArray getInputArgument(int index) {
return op.getInputArgument(index);
}
@Override
public int numInputArguments() {
return op.numInputArguments();
}
@Override
public void addOutputArgument(INDArray... arg) {
op.addOutputArgument(arg);
}
@Override
public void removeOutputArgument(INDArray arg) {
}
@Override
public INDArray getOutputArgument(int index) {
return op.getOutputArgument(index);
}
@Override
public int numOutputArguments() {
return op.numOutputArguments();
}
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
return Nd4j.getExecutioner().calculateOutputShape(this);
}
@Override
public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
return Nd4j.getExecutioner().calculateOutputShape(this, opContext);
}
@Override
public CustomOpDescriptor getDescriptor() {
return op.getDescriptor();
}
@Override
public void assertValidForExecution() {
}
@Override
public DataType[] dArgs() {
return new DataType[0];
}
@Override
public void addDArgument(DataType... arg) {
}
@Override
public int numDArguments() {
return 0;
}
@Override
public void clearArrays() {
op.clearArrays();
}
}

View File

@ -79,7 +79,7 @@ public class ScatterNdSub extends DynamicCustomOp {
} }
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s", Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s",
inputDataTypes.get(0), inputDataTypes.get(2)); inputDataTypes.get(0), inputDataTypes.get(2));

View File

@ -79,7 +79,7 @@ public class ScatterNdUpdate extends DynamicCustomOp {
} }
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s", Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s",
inputDataTypes.get(0), inputDataTypes.get(2)); inputDataTypes.get(0), inputDataTypes.get(2));

View File

@ -362,10 +362,10 @@
<configuration> <configuration>
<environmentVariables> <environmentVariables>
<LD_LIBRARY_PATH> <LD_LIBRARY_PATH>
${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/:${nd4j.native.basedir}/target/classes ${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes
</LD_LIBRARY_PATH> </LD_LIBRARY_PATH>
<PATH> <PATH>
${env.PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/:${nd4j.native.basedir}/target/classes ${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes
</PATH> </PATH>
</environmentVariables> </environmentVariables>
<testSourceDirectory>src/test/java</testSourceDirectory> <testSourceDirectory>src/test/java</testSourceDirectory>
@ -391,7 +391,7 @@
For testing large zoo models, this may not be enough (so comment it out). For testing large zoo models, this may not be enough (so comment it out).
--> -->
<argLine>-Ddtype=float -Xmx8g</argLine> <argLine>-Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes"</argLine>
</configuration> </configuration>
</plugin> </plugin>
</plugins> </plugins>
@ -444,8 +444,11 @@
<configuration> <configuration>
<environmentVariables> <environmentVariables>
<LD_LIBRARY_PATH> <LD_LIBRARY_PATH>
${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ ${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes
</LD_LIBRARY_PATH> </LD_LIBRARY_PATH>
<PATH>
${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes
</PATH>
</environmentVariables> </environmentVariables>
<testSourceDirectory>src/test/java</testSourceDirectory> <testSourceDirectory>src/test/java</testSourceDirectory>
<includes> <includes>
@ -468,7 +471,7 @@
Maximum heap size was set to 6g, as a minimum required value for tests run. Maximum heap size was set to 6g, as a minimum required value for tests run.
Depending on a build machine, default value is not always enough. Depending on a build machine, default value is not always enough.
--> -->
<argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx6g</argLine> <argLine> -Dfile.encoding=UTF-8 -Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes"</argLine>
</configuration> </configuration>
</plugin> </plugin>
</plugins> </plugins>

View File

@ -76,20 +76,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"layers_dropout/rank3_d05_train_mask1", "layers_dropout/rank3_d05_train_mask1",
"layers_dropout/rank2_d09_train", "layers_dropout/rank2_d09_train",
"layers_dropout/rank2_d05_train",*/ "layers_dropout/rank2_d05_train",*/
"reductions/scatter_update_vector",
"reductions/scatter_update_scalar",
"random_poisson/rank1_float16",
"random_poisson/rank1_float16",
"matrix_band_part/float64",
"emptyArrayTests/scatter_update/rank1_emptyIndices_emptyUpdates",
"bincount/rank0_weights",
"bincount/rank2_weights",
"scatter_nd_add/locking/rank1shape_1indices",
"scatter_nd_add/locking/rank2shape_1indices",
"scatter_nd_add/locking/rank3shape_1indices",
"scatter_nd_sub/locking/rank1shape_1indices",
"scatter_nd_sub/locking/rank2shape_1indices",
"scatter_nd_sub/locking/rank3shape_1indices"
); );
@ -97,10 +84,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
//Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965 //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965
// Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance // Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance
//Invalid test cases. Verified by running graph against actual TF. //Invalid test cases. Verified by running graph against actual TF.
"compare_and_bitpack/.*", "scatter_nd_sub/locking/rank1shape_1indices",
"reductions/scatter_update_vector",
"reductions/scatter_update_scalar",
"emptyArrayTests/scatter_update/rank1_emptyIndices_emptyUpdates",
"bincount/rank2_weights",
"slogdet/.*", "slogdet/.*",
//IGNORE THIS: the TF results from comparing against an actual TF java run compared to this seem to be different.
"fused_batch_norm/float16_nhwc",
//Don't bother to test RNG. We can test subsets of ops with dropout to make sure they are consistent //Don't bother to test RNG. We can test subsets of ops with dropout to make sure they are consistent
//These tests have random uniform and other RNG in them that don't need to be perfectly compatible to be acceptable. //These tests have random uniform and other RNG in them that don't need to be perfectly compatible to be acceptable.
//We need different test cases here. //We need different test cases here.

View File

@ -56,7 +56,6 @@ import org.nd4j.linalg.api.ops.custom.RgbToHsv;
import org.nd4j.linalg.api.ops.custom.RgbToYiq; import org.nd4j.linalg.api.ops.custom.RgbToYiq;
import org.nd4j.linalg.api.ops.custom.RgbToYuv; import org.nd4j.linalg.api.ops.custom.RgbToYuv;
import org.nd4j.linalg.api.ops.custom.Roll; import org.nd4j.linalg.api.ops.custom.Roll;
import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
import org.nd4j.linalg.api.ops.custom.ToggleBits; import org.nd4j.linalg.api.ops.custom.ToggleBits;
import org.nd4j.linalg.api.ops.custom.TriangularSolve; import org.nd4j.linalg.api.ops.custom.TriangularSolve;
import org.nd4j.linalg.api.ops.custom.YiqToRgb; import org.nd4j.linalg.api.ops.custom.YiqToRgb;
@ -64,12 +63,11 @@ import org.nd4j.linalg.api.ops.custom.YuvToRgb;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.controlflow.Where; import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
import org.nd4j.linalg.api.ops.impl.image.ResizeArea; import org.nd4j.linalg.api.ops.impl.image.ResizeArea;
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.shape.Create;
import org.nd4j.linalg.api.ops.impl.shape.Linspace; import org.nd4j.linalg.api.ops.impl.shape.Linspace;
import org.nd4j.linalg.api.ops.impl.shape.OnesLike; import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
@ -390,51 +388,6 @@ public class CustomOpsTests extends BaseNd4jTest {
} }
@Test
public void testScatterUpdate1() {
val matrix = Nd4j.create(5, 5);
val updates = Nd4j.create(2, 5).assign(1.0);
int[] dims = new int[]{1};
int[] indices = new int[]{1, 3};
val exp0 = Nd4j.create(5).assign(0);
val exp1 = Nd4j.create(5).assign(1);
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
Nd4j.getExecutioner().exec(op);
assertEquals(exp0, matrix.getRow(0));
assertEquals(exp1, matrix.getRow(1));
assertEquals(exp0, matrix.getRow(2));
assertEquals(exp1, matrix.getRow(3));
assertEquals(exp0, matrix.getRow(4));
}
@Test(expected = ND4JIllegalStateException.class)
public void testScatterUpdate2() {
val matrix = Nd4j.create(5, 5);
val updates = Nd4j.create(2, 5).assign(1.0);
int[] dims = new int[]{0};
int[] indices = new int[]{0, 1};
val exp0 = Nd4j.create(1, 5).assign(0);
val exp1 = Nd4j.create(1, 5).assign(1);
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
}
@Test(expected = ND4JIllegalStateException.class)
public void testScatterUpdate3() {
val matrix = Nd4j.create(5, 5);
val updates = Nd4j.create(2, 5).assign(1.0);
int[] dims = new int[]{1};
int[] indices = new int[]{0, 6};
val exp0 = Nd4j.create(1, 5).assign(0);
val exp1 = Nd4j.create(1, 5).assign(1);
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
}
@Test @Test
public void testOpStatus1() { public void testOpStatus1() {
@ -1005,17 +958,7 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(expected, output); assertEquals(expected, output);
} }
@Test
public void testCompareAndBitpack() {
INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f,
-2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}).reshape( 2,3,4);
INDArray out = Nd4j.createUninitialized(DataType.UBYTE, 2,3,4);
INDArray expected = Nd4j.createFromArray(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}).
reshape(2,3,4);
Nd4j.exec(new CompareAndBitpack(in ,2.0, out));
assertArrayEquals(new long[]{2,3,4}, out.shape());
}
@Test @Test
public void testDivideNoNan() { public void testDivideNoNan() {

View File

@ -730,20 +730,7 @@ public class NDBaseTest extends BaseNd4jTest {
assertEquals(y_exp, y); assertEquals(y_exp, y);
} }
@Test
public void testScatterUpdate() {
NDBase base = new NDBase();
//from testScatterOpGradients.
INDArray x = Nd4j.ones(DataType.DOUBLE, 10, 10);
INDArray indices = Nd4j.create(new double[]{3, 4, 5, 8, 9}).castTo(DataType.INT32);
INDArray updates = Nd4j.ones(DataType.DOUBLE, 5, 10).add(1.0);
INDArray y = base.scatterUpdate(x,indices, updates);
y = y.getColumn(0);
INDArray y_exp = Nd4j.createFromArray(1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0);
assertEquals(y_exp, y);
}
@Test @Test
public void testSegmentMax() { public void testSegmentMax() {

View File

@ -52,7 +52,7 @@ public class RemoteParameterServerClientTests extends BaseND4JTest {
@Before @Before
public void before() throws Exception { public void before() throws Exception {
final MediaDriver.Context ctx = final MediaDriver.Context ctx =
new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirsDeleteOnStart(true) new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true)
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
.receiverIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy())
.senderIdleStrategy(new BusySpinIdleStrategy()); .senderIdleStrategy(new BusySpinIdleStrategy());
@ -150,10 +150,10 @@ public class RemoteParameterServerClientTests extends BaseND4JTest {
private Aeron.Context getContext() { private Aeron.Context getContext() {
if (ctx == null) if (ctx == null)
ctx = new Aeron.Context().publicationConnectionTimeout(-1) ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
.availableImageHandler(AeronUtil::printAvailableImage) .availableImageHandler(AeronUtil::printAvailableImage)
.unavailableImageHandler(AeronUtil::printUnavailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage)
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(1000)
.errorHandler(e -> log.error(e.toString(), e)); .errorHandler(e -> log.error(e.toString(), e));
return ctx; return ctx;
} }

View File

@ -51,7 +51,7 @@ public class ParameterServerClientPartialTest extends BaseND4JTest {
@BeforeClass @BeforeClass
public static void beforeClass() throws Exception { public static void beforeClass() throws Exception {
final MediaDriver.Context ctx = final MediaDriver.Context ctx =
new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true) new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirDeleteOnStart(true)
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
.receiverIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy())
.senderIdleStrategy(new BusySpinIdleStrategy()); .senderIdleStrategy(new BusySpinIdleStrategy());
@ -136,10 +136,10 @@ public class ParameterServerClientPartialTest extends BaseND4JTest {
private static Aeron.Context getContext() { private static Aeron.Context getContext() {
if (ctx == null) if (ctx == null)
ctx = new Aeron.Context().publicationConnectionTimeout(-1) ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
.availableImageHandler(AeronUtil::printAvailableImage) .availableImageHandler(AeronUtil::printAvailableImage)
.unavailableImageHandler(AeronUtil::printUnavailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage)
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(10000) .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(10000)
.errorHandler(e -> log.error(e.toString(), e)); .errorHandler(e -> log.error(e.toString(), e));
return ctx; return ctx;
} }

View File

@ -119,10 +119,10 @@ public class ParameterServerClientTest extends BaseND4JTest {
private static Aeron.Context getContext() { private static Aeron.Context getContext() {
return new Aeron.Context().publicationConnectionTimeout(-1) return new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
.availableImageHandler(AeronUtil::printAvailableImage) .availableImageHandler(AeronUtil::printAvailableImage)
.unavailableImageHandler(AeronUtil::printUnavailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage)
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000)
.errorHandler(e -> log.error(e.toString(), e)); .errorHandler(e -> log.error(e.toString(), e));
} }

View File

@ -424,7 +424,6 @@ public abstract class BaseTransport implements Transport {
CloseHelper.quietClose(subscriptionForShards); CloseHelper.quietClose(subscriptionForShards);
CloseHelper.quietClose(subscriptionForClients); CloseHelper.quietClose(subscriptionForClients);
CloseHelper.quietClose(aeron); CloseHelper.quietClose(aeron);
CloseHelper.quietClose(context);
CloseHelper.quietClose(driver); CloseHelper.quietClose(driver);
} }

View File

@ -91,7 +91,7 @@ public class RoutedTransport extends BaseTransport {
context = new Aeron.Context().driverTimeoutMs(30000) context = new Aeron.Context().driverTimeoutMs(30000)
.keepAliveInterval(100000000); .keepAliveIntervalNs(100000000);
AeronUtil.setDaemonizedThreadFactories(context); AeronUtil.setDaemonizedThreadFactories(context);
MediaDriver.Context ctx = new MediaDriver.Context(); MediaDriver.Context ctx = new MediaDriver.Context();
@ -120,7 +120,6 @@ public class RoutedTransport extends BaseTransport {
Runtime.getRuntime().addShutdownHook(new Thread(() -> { Runtime.getRuntime().addShutdownHook(new Thread(() -> {
CloseHelper.quietClose(aeron); CloseHelper.quietClose(aeron);
CloseHelper.quietClose(driver); CloseHelper.quietClose(driver);
CloseHelper.quietClose(context);
CloseHelper.quietClose(subscriptionForClients); CloseHelper.quietClose(subscriptionForClients);
})); }));

View File

@ -131,7 +131,7 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable {
splitter = MessageSplitter.getInstance(); splitter = MessageSplitter.getInstance();
context = new Aeron.Context().driverTimeoutMs(30000) context = new Aeron.Context().driverTimeoutMs(30000)
.keepAliveInterval(100000000); .keepAliveIntervalNs(100000000);
AeronUtil.setDaemonizedThreadFactories(context); AeronUtil.setDaemonizedThreadFactories(context);
final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context(); final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context();

View File

@ -184,7 +184,7 @@ public class ParameterServerNode implements AutoCloseable {
return new Aeron.Context() return new Aeron.Context()
.availableImageHandler(AeronUtil::printAvailableImage) .availableImageHandler(AeronUtil::printAvailableImage)
.unavailableImageHandler(AeronUtil::printUnavailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage)
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000)
.errorHandler(e -> log.error(e.toString(), e)); .errorHandler(e -> log.error(e.toString(), e));
} }

View File

@ -118,10 +118,10 @@ public class ParameterServerNodeTest extends BaseND4JTest {
private static Aeron.Context getContext() { private static Aeron.Context getContext() {
return new Aeron.Context().publicationConnectionTimeout(-1) return new Aeron.Context().driverTimeoutMs(10000)
.availableImageHandler(AeronUtil::printAvailableImage) .availableImageHandler(AeronUtil::printAvailableImage)
.unavailableImageHandler(AeronUtil::printUnavailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage)
.aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000)
.errorHandler(e -> log.error(e.toString(), e)); .errorHandler(e -> log.error(e.toString(), e));
} }

View File

@ -74,7 +74,6 @@ import java.util.concurrent.locks.LockSupport;
@NoArgsConstructor @NoArgsConstructor
@Data @Data
@Parameters(separators = ",") @Parameters(separators = ",")
@Slf4j
public class ParameterServerSubscriber implements AutoCloseable { public class ParameterServerSubscriber implements AutoCloseable {
private static Logger log = LoggerFactory.getLogger(ParameterServerSubscriber.class); private static Logger log = LoggerFactory.getLogger(ParameterServerSubscriber.class);
@ -255,9 +254,9 @@ public class ParameterServerSubscriber implements AutoCloseable {
//Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window. //Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window.
System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength)); System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength));
final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED) final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED)
.dirsDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false) .dirDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false)
.ipcTermBufferLength(ipcLength).publicationTermBufferLength(ipcLength) .ipcTermBufferLength(ipcLength).publicationTermBufferLength(ipcLength)
.maxTermBufferLength(ipcLength).conductorIdleStrategy(new BusySpinIdleStrategy()) .conductorIdleStrategy(new BusySpinIdleStrategy())
.receiverIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy())
.senderIdleStrategy(new BusySpinIdleStrategy()); .senderIdleStrategy(new BusySpinIdleStrategy());
AeronUtil.setDaemonizedThreadFactories(mediaDriverCtx); AeronUtil.setDaemonizedThreadFactories(mediaDriverCtx);
@ -380,10 +379,10 @@ public class ParameterServerSubscriber implements AutoCloseable {
//get a context //get a context
public Aeron.Context getContext() { public Aeron.Context getContext() {
Aeron.Context ctx = new Aeron.Context().publicationConnectionTimeout(-1) Aeron.Context ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE)
.availableImageHandler(AeronUtil::printAvailableImage) .availableImageHandler(AeronUtil::printAvailableImage)
.unavailableImageHandler(AeronUtil::printUnavailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage)
.aeronDirectoryName(mediaDriverDirectoryName).keepAliveInterval(100000) .aeronDirectoryName(mediaDriverDirectoryName).keepAliveIntervalNs(1000000)
.errorHandler(e -> log.error(e.toString(), e)); .errorHandler(e -> log.error(e.toString(), e));
AeronUtil.setDaemonizedThreadFactories(ctx); AeronUtil.setDaemonizedThreadFactories(ctx);
return ctx; return ctx;

View File

@ -1860,13 +1860,13 @@ val scatterSub = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterSub"
,tensorflowOpRegistry = tensorflowOpRegistry) ,tensorflowOpRegistry = tensorflowOpRegistry)
//TODO: note: TF expects indices, we don't support them? //TODO: note: TF expects indices, we don't support them?
val scatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterUpdate"),opName = "scatter_upd", val scatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterUpdate"),opName = "scatter_update",
attributeMappingRules = listOf(), attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("indices" to "indices"))),
tensorNames = mutableMapOf("input" to "ref","updates" to "updates","indices" to "indices"),tensorflowOpRegistry = tensorflowOpRegistry) tensorNames = mutableMapOf("operand" to "ref","updates" to "updates"),tensorflowOpRegistry = tensorflowOpRegistry)
val tensorScatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("TensorScatterUpdate"),opName = "scatter_upd", val tensorScatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("TensorScatterUpdate"),opName = "scatter_update",
attributeMappingRules = listOf(), attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("indices" to "indices"))),
tensorNames = mutableMapOf("input" to "tensor","updates" to "updates","indices" to "indices"),tensorflowOpRegistry = tensorflowOpRegistry) tensorNames = mutableMapOf("operand" to "tensor","updates" to "updates"),tensorflowOpRegistry = tensorflowOpRegistry)
//L2Loss //L2Loss
val l2Loss = multipleNameMapping(inputFrameworkOpNames = listOf("L2Loss"),opName = "l2_loss", val l2Loss = multipleNameMapping(inputFrameworkOpNames = listOf("L2Loss"),opName = "l2_loss",
attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))), attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))),

View File

@ -32,6 +32,7 @@ import org.nd4j.shade.protobuf.GeneratedMessageV3
import org.nd4j.shade.protobuf.ProtocolMessageEnum import org.nd4j.shade.protobuf.ProtocolMessageEnum
import org.nd4j.shade.protobuf.TextFormat import org.nd4j.shade.protobuf.TextFormat
import org.tensorflow.framework.* import org.tensorflow.framework.*
import java.lang.Exception
import java.nio.charset.Charset import java.nio.charset.Charset
class TensorflowOpDescriptorLoader: OpDescriptorLoader<OpDef> { class TensorflowOpDescriptorLoader: OpDescriptorLoader<OpDef> {
@ -87,7 +88,12 @@ class TensorflowOpDescriptorLoader: OpDescriptorLoader<OpDef> {
val fileName = System.getProperty(tensorflowRulesetSpecifierProperty, tensorflowMappingRulSetDefaultFile) val fileName = System.getProperty(tensorflowRulesetSpecifierProperty, tensorflowMappingRulSetDefaultFile)
val string = IOUtils.toString(ClassPathResource(fileName).inputStream, Charset.defaultCharset()) val string = IOUtils.toString(ClassPathResource(fileName).inputStream, Charset.defaultCharset())
val declarationBuilder = MapperNamespace.MappingDefinitionSet.newBuilder() val declarationBuilder = MapperNamespace.MappingDefinitionSet.newBuilder()
TextFormat.merge(string,declarationBuilder) try {
TextFormat.merge(string,declarationBuilder)
} catch(e: Exception) {
println("Unable to parse mapper definitions for file file $fileName")
}
return declarationBuilder.build() return declarationBuilder.build()
} }

View File

@ -15384,30 +15384,35 @@ mappings {
} }
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
opName: "scatter_upd" opName: "scatter_update"
inputFrameworkOpName: "ScatterUpdate" inputFrameworkOpName: "ScatterUpdate"
rule { rule {
ruleName: "ndarraymapping" ruleName: "ndarraymapping"
functionName: "ndarraymapping" functionName: "ndarraymapping"
inputTensorName: "ref" inputTensorName: "ref"
inputTensorName: "updates" inputTensorName: "updates"
inputTensorName: "indices" outputTensorName: "operand"
outputTensorName: "input"
outputTensorName: "updates" outputTensorName: "updates"
outputTensorName: "indices"
inputToOutput { inputToOutput {
key: "input" key: "operand"
value: "ref" value: "ref"
} }
inputToOutput { inputToOutput {
key: "updates" key: "updates"
value: "updates" value: "updates"
} }
ruleType: "tensor"
inputFrameworkOpName: "ScatterUpdate"
}
rule {
ruleName: "ndarraytointattributevalue"
functionName: "ndarraytointattributevalue"
outputIntName: "indices"
inputToOutput { inputToOutput {
key: "indices" key: "indices"
value: "indices" value: "indices"
} }
ruleType: "tensor" ruleType: "attribute"
inputFrameworkOpName: "ScatterUpdate" inputFrameworkOpName: "ScatterUpdate"
} }
} }
@ -15527,30 +15532,35 @@ mappings {
} }
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
opName: "scatter_upd" opName: "scatter_update"
inputFrameworkOpName: "TensorScatterUpdate" inputFrameworkOpName: "TensorScatterUpdate"
rule { rule {
ruleName: "ndarraymapping" ruleName: "ndarraymapping"
functionName: "ndarraymapping" functionName: "ndarraymapping"
inputTensorName: "tensor" inputTensorName: "tensor"
inputTensorName: "updates" inputTensorName: "updates"
inputTensorName: "indices" outputTensorName: "operand"
outputTensorName: "input"
outputTensorName: "updates" outputTensorName: "updates"
outputTensorName: "indices"
inputToOutput { inputToOutput {
key: "input" key: "operand"
value: "tensor" value: "tensor"
} }
inputToOutput { inputToOutput {
key: "updates" key: "updates"
value: "updates" value: "updates"
} }
ruleType: "tensor"
inputFrameworkOpName: "TensorScatterUpdate"
}
rule {
ruleName: "ndarraytointattributevalue"
functionName: "ndarraytointattributevalue"
outputIntName: "indices"
inputToOutput { inputToOutput {
key: "indices" key: "indices"
value: "indices" value: "indices"
} }
ruleType: "tensor" ruleType: "attribute"
inputFrameworkOpName: "TensorScatterUpdate" inputFrameworkOpName: "TensorScatterUpdate"
} }
} }

View File

@ -15384,30 +15384,35 @@ mappings {
} }
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
opName: "scatter_upd" opName: "scatter_update"
inputFrameworkOpName: "ScatterUpdate" inputFrameworkOpName: "ScatterUpdate"
rule { rule {
ruleName: "ndarraymapping" ruleName: "ndarraymapping"
functionName: "ndarraymapping" functionName: "ndarraymapping"
inputTensorName: "ref" inputTensorName: "ref"
inputTensorName: "updates" inputTensorName: "updates"
inputTensorName: "indices" outputTensorName: "operand"
outputTensorName: "input"
outputTensorName: "updates" outputTensorName: "updates"
outputTensorName: "indices"
inputToOutput { inputToOutput {
key: "input" key: "operand"
value: "ref" value: "ref"
} }
inputToOutput { inputToOutput {
key: "updates" key: "updates"
value: "updates" value: "updates"
} }
ruleType: "tensor"
inputFrameworkOpName: "ScatterUpdate"
}
rule {
ruleName: "ndarraytointattributevalue"
functionName: "ndarraytointattributevalue"
outputIntName: "indices"
inputToOutput { inputToOutput {
key: "indices" key: "indices"
value: "indices" value: "indices"
} }
ruleType: "tensor" ruleType: "attribute"
inputFrameworkOpName: "ScatterUpdate" inputFrameworkOpName: "ScatterUpdate"
} }
} }
@ -15527,30 +15532,35 @@ mappings {
} }
mappings { mappings {
frameworkName: "tensorflow" frameworkName: "tensorflow"
opName: "scatter_upd" opName: "scatter_update"
inputFrameworkOpName: "TensorScatterUpdate" inputFrameworkOpName: "TensorScatterUpdate"
rule { rule {
ruleName: "ndarraymapping" ruleName: "ndarraymapping"
functionName: "ndarraymapping" functionName: "ndarraymapping"
inputTensorName: "tensor" inputTensorName: "tensor"
inputTensorName: "updates" inputTensorName: "updates"
inputTensorName: "indices" outputTensorName: "operand"
outputTensorName: "input"
outputTensorName: "updates" outputTensorName: "updates"
outputTensorName: "indices"
inputToOutput { inputToOutput {
key: "input" key: "operand"
value: "tensor" value: "tensor"
} }
inputToOutput { inputToOutput {
key: "updates" key: "updates"
value: "updates" value: "updates"
} }
ruleType: "tensor"
inputFrameworkOpName: "TensorScatterUpdate"
}
rule {
ruleName: "ndarraytointattributevalue"
functionName: "ndarraytointattributevalue"
outputIntName: "indices"
inputToOutput { inputToOutput {
key: "indices" key: "indices"
value: "indices" value: "indices"
} }
ruleType: "tensor" ruleType: "attribute"
inputFrameworkOpName: "TensorScatterUpdate" inputFrameworkOpName: "TensorScatterUpdate"
} }
} }