diff --git a/libnd4j/tests_cpu/run_tests.sh b/libnd4j/tests_cpu/run_tests.sh index 270a329c1..0fd823876 100755 --- a/libnd4j/tests_cpu/run_tests.sh +++ b/libnd4j/tests_cpu/run_tests.sh @@ -58,6 +58,7 @@ fi unameOut="$(uname)" echo "$OSTYPE" + ../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests # Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion) [ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java deleted file mode 100644 index 170741c4e..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/ScatterUpdate.java +++ /dev/null @@ -1,285 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.nd4j.linalg.api.ops.custom; - -import lombok.NonNull; -import lombok.val; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.CustomOp; -import org.nd4j.linalg.api.ops.CustomOpDescriptor; -import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.OpContext; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.ArrayList; -import java.util.List; - -public class ScatterUpdate extends DynamicCustomOp { - protected CustomOp op; - - // update operation: 0 - add; 1 - sub; 2 - mul; 3 - div; 4 - rsub; 5 - rdiv; 6 - assign - public enum UpdateOp { - ADD, - SUBTRACT, - MULTIPLY, - DIVIDE, - RSUBTRACT, - RDIVIDE, - ASSIGN, - } - - public ScatterUpdate(){ } - - public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) { - this(original, updates, null, indices, dimension, op); - } - - public ScatterUpdate(@NonNull INDArray original, @NonNull INDArray updates, INDArray result, @NonNull int[] indices, int[] dimension, @NonNull UpdateOp op) { - - List iargs = new ArrayList<>(); - iargs.add(op.ordinal()); - iargs.add(dimension.length); - for (val v: dimension) - iargs.add(v); - - iargs.add(indices.length); - for (val v: indices) - iargs.add(v); - - if (updates.tensorAlongDimension(0, dimension).length() != original.tensorAlongDimension(0, dimension).length()) - throw new ND4JIllegalStateException("ScatterUpdate requires equal shaped tensors for operation along given dimension(s)"); - - long numTensors = original.tensorsAlongDimension(dimension); - for (val idx: indices) - if (idx >= numTensors) - throw new ND4JIllegalStateException("Can't update index higher then num tensors"); - - this.op = DynamicCustomOp.builder("scatter_update") - .addInputs(original, updates) - .callInplace(true) - .addIntegerArguments(iargs) - .build(); - } - - @Override - public List calculateOutputDataTypes(List dataTypes) { - DynamicCustomOp dynamicCustomOp = (DynamicCustomOp) op; - return dynamicCustomOp.calculateOutputDataTypes(dataTypes); - } - - /** - * This method returns op opName as string - * - * @return - */ - @Override - public String opName() { - return "scatter_update"; - } - - /** - * This method returns LongHash of the opName() - * - * @return - */ - @Override - public long opHash() { - return op.opHash(); - } - - /** - * This method returns true if op is supposed to be executed inplace - * - * @return - */ - @Override - public boolean isInplaceCall() { - return op.isInplaceCall(); - } - - @Override - public List outputArguments() { - return op.outputArguments(); - } - - @Override - public List inputArguments() { - return op.inputArguments(); - } - - @Override - public long[] iArgs() { - return op.iArgs(); - } - - @Override - public double[] tArgs() { - return op.tArgs(); - } - - @Override - public boolean[] bArgs() { - return op.bArgs(); - } - - @Override - public void addIArgument(int... arg) { - op.addIArgument(arg); - } - - @Override - public void addIArgument(long... arg) { - op.addIArgument(arg); - } - - @Override - public void addBArgument(boolean... arg) { - op.addBArgument(arg); - } - - @Override - public void removeIArgument(Integer arg) { - op.removeIArgument(arg); - } - - @Override - public Boolean getBArgument(int index) { - return op.getBArgument(index); - } - - @Override - public Long getIArgument(int index) { - return op.getIArgument(index); - } - - @Override - public int numIArguments() { - return op.numIArguments(); - } - - @Override - public void addTArgument(double... arg) { - op.addTArgument(arg); - } - - @Override - public void removeTArgument(Double arg) { - op.removeTArgument(arg); - } - - @Override - public Double getTArgument(int index) { - return op.getTArgument(index); - } - - @Override - public int numTArguments() { - return op.numTArguments(); - } - - @Override - public int numBArguments() { - return 0; - } - - @Override - public void addInputArgument(INDArray... arg) { - op.addInputArgument(arg); - } - - @Override - public void removeInputArgument(INDArray arg) { - op.removeInputArgument(arg); - } - - @Override - public INDArray getInputArgument(int index) { - return op.getInputArgument(index); - } - - @Override - public int numInputArguments() { - return op.numInputArguments(); - } - - @Override - public void addOutputArgument(INDArray... arg) { - op.addOutputArgument(arg); - } - - @Override - public void removeOutputArgument(INDArray arg) { - - } - - @Override - public INDArray getOutputArgument(int index) { - return op.getOutputArgument(index); - } - - @Override - public int numOutputArguments() { - return op.numOutputArguments(); - } - - @Override - public List calculateOutputShape() { - return Nd4j.getExecutioner().calculateOutputShape(this); - } - - @Override - public List calculateOutputShape(OpContext opContext) { - return Nd4j.getExecutioner().calculateOutputShape(this, opContext); - } - - @Override - public CustomOpDescriptor getDescriptor() { - return op.getDescriptor(); - } - - @Override - public void assertValidForExecution() { - - } - - @Override - public DataType[] dArgs() { - return new DataType[0]; - } - - @Override - public void addDArgument(DataType... arg) { - - } - - @Override - public int numDArguments() { - return 0; - } - - @Override - public void clearArrays() { - op.clearArrays(); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java index a7673f74f..546cb5055 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java @@ -79,7 +79,7 @@ public class ScatterNdSub extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List inputDataTypes){ + public List calculateOutputDataTypes(List inputDataTypes) { Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s", inputDataTypes.get(0), inputDataTypes.get(2)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java index aec522998..825965fb0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java @@ -79,7 +79,7 @@ public class ScatterNdUpdate extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List inputDataTypes){ + public List calculateOutputDataTypes(List inputDataTypes) { Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes); Preconditions.checkState(inputDataTypes.get(0) == inputDataTypes.get(2), "Reference (input 0) and updates (input 2) must have exactly same data types, got %s and %s", inputDataTypes.get(0), inputDataTypes.get(2)); diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index 591c71c29..eca5164d5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -362,10 +362,10 @@ - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/:${nd4j.native.basedir}/target/classes + ${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes - ${env.PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/:${nd4j.native.basedir}/target/classes + ${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes src/test/java @@ -391,7 +391,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Ddtype=float -Xmx8g + -Dfile.encoding=UTF-8 -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes" @@ -444,8 +444,11 @@ - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ + ${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes + + ${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes + src/test/java @@ -468,7 +471,7 @@ Maximum heap size was set to 6g, as a minimum required value for tests run. Depending on a build machine, default value is not always enough. --> - -Ddtype=float -Dfile.encoding=UTF-8 -Xmx6g + -Dfile.encoding=UTF-8 -Dorg.bytedeco.javacpp.logger.debug=true -Djava.library.path="${nd4j.basedir}/nd4j-backends/nd4j-backend-impls/nd4j-cuda/target/classes" diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 5e8dcde70..0833b05f6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -76,20 +76,7 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a "layers_dropout/rank3_d05_train_mask1", "layers_dropout/rank2_d09_train", "layers_dropout/rank2_d05_train",*/ - "reductions/scatter_update_vector", - "reductions/scatter_update_scalar", - "random_poisson/rank1_float16", - "random_poisson/rank1_float16", - "matrix_band_part/float64", - "emptyArrayTests/scatter_update/rank1_emptyIndices_emptyUpdates", - "bincount/rank0_weights", - "bincount/rank2_weights", - "scatter_nd_add/locking/rank1shape_1indices", - "scatter_nd_add/locking/rank2shape_1indices", - "scatter_nd_add/locking/rank3shape_1indices", - "scatter_nd_sub/locking/rank1shape_1indices", - "scatter_nd_sub/locking/rank2shape_1indices", - "scatter_nd_sub/locking/rank3shape_1indices" + ); @@ -97,10 +84,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965 // Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance //Invalid test cases. Verified by running graph against actual TF. - "compare_and_bitpack/.*", + "scatter_nd_sub/locking/rank1shape_1indices", + "reductions/scatter_update_vector", + "reductions/scatter_update_scalar", + "emptyArrayTests/scatter_update/rank1_emptyIndices_emptyUpdates", + "bincount/rank2_weights", "slogdet/.*", - //IGNORE THIS: the TF results from comparing against an actual TF java run compared to this seem to be different. - "fused_batch_norm/float16_nhwc", //Don't bother to test RNG. We can test subsets of ops with dropout to make sure they are consistent //These tests have random uniform and other RNG in them that don't need to be perfectly compatible to be acceptable. //We need different test cases here. diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 9fa906693..a7fe39bec 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -56,7 +56,6 @@ import org.nd4j.linalg.api.ops.custom.RgbToHsv; import org.nd4j.linalg.api.ops.custom.RgbToYiq; import org.nd4j.linalg.api.ops.custom.RgbToYuv; import org.nd4j.linalg.api.ops.custom.Roll; -import org.nd4j.linalg.api.ops.custom.ScatterUpdate; import org.nd4j.linalg.api.ops.custom.ToggleBits; import org.nd4j.linalg.api.ops.custom.TriangularSolve; import org.nd4j.linalg.api.ops.custom.YiqToRgb; @@ -64,12 +63,11 @@ import org.nd4j.linalg.api.ops.custom.YuvToRgb; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.impl.controlflow.Where; -import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; import org.nd4j.linalg.api.ops.impl.image.ResizeArea; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; -import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; +import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.shape.Linspace; import org.nd4j.linalg.api.ops.impl.shape.OnesLike; @@ -390,51 +388,6 @@ public class CustomOpsTests extends BaseNd4jTest { } - @Test - public void testScatterUpdate1() { - val matrix = Nd4j.create(5, 5); - val updates = Nd4j.create(2, 5).assign(1.0); - int[] dims = new int[]{1}; - int[] indices = new int[]{1, 3}; - - val exp0 = Nd4j.create(5).assign(0); - val exp1 = Nd4j.create(5).assign(1); - - ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD); - Nd4j.getExecutioner().exec(op); - - assertEquals(exp0, matrix.getRow(0)); - assertEquals(exp1, matrix.getRow(1)); - assertEquals(exp0, matrix.getRow(2)); - assertEquals(exp1, matrix.getRow(3)); - assertEquals(exp0, matrix.getRow(4)); - } - - @Test(expected = ND4JIllegalStateException.class) - public void testScatterUpdate2() { - val matrix = Nd4j.create(5, 5); - val updates = Nd4j.create(2, 5).assign(1.0); - int[] dims = new int[]{0}; - int[] indices = new int[]{0, 1}; - - val exp0 = Nd4j.create(1, 5).assign(0); - val exp1 = Nd4j.create(1, 5).assign(1); - - ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD); - } - - @Test(expected = ND4JIllegalStateException.class) - public void testScatterUpdate3() { - val matrix = Nd4j.create(5, 5); - val updates = Nd4j.create(2, 5).assign(1.0); - int[] dims = new int[]{1}; - int[] indices = new int[]{0, 6}; - - val exp0 = Nd4j.create(1, 5).assign(0); - val exp1 = Nd4j.create(1, 5).assign(1); - - ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD); - } @Test public void testOpStatus1() { @@ -1005,17 +958,7 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, output); } - @Test - public void testCompareAndBitpack() { - INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, - -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}).reshape( 2,3,4); - INDArray out = Nd4j.createUninitialized(DataType.UBYTE, 2,3,4); - INDArray expected = Nd4j.createFromArray(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}). - reshape(2,3,4); - Nd4j.exec(new CompareAndBitpack(in ,2.0, out)); - assertArrayEquals(new long[]{2,3,4}, out.shape()); - } @Test public void testDivideNoNan() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java index d76d0c8dd..0e642dcf6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/factory/ops/NDBaseTest.java @@ -730,20 +730,7 @@ public class NDBaseTest extends BaseNd4jTest { assertEquals(y_exp, y); } - @Test - public void testScatterUpdate() { - NDBase base = new NDBase(); - //from testScatterOpGradients. - INDArray x = Nd4j.ones(DataType.DOUBLE, 10, 10); - INDArray indices = Nd4j.create(new double[]{3, 4, 5, 8, 9}).castTo(DataType.INT32); - INDArray updates = Nd4j.ones(DataType.DOUBLE, 5, 10).add(1.0); - INDArray y = base.scatterUpdate(x,indices, updates); - - y = y.getColumn(0); - INDArray y_exp = Nd4j.createFromArray(1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0); - assertEquals(y_exp, y); - } @Test public void testSegmentMax() { diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java index 6fd129fbc..4f5a15cf6 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java @@ -52,7 +52,7 @@ public class RemoteParameterServerClientTests extends BaseND4JTest { @Before public void before() throws Exception { final MediaDriver.Context ctx = - new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirsDeleteOnStart(true) + new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true) .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy()) .senderIdleStrategy(new BusySpinIdleStrategy()); @@ -150,10 +150,10 @@ public class RemoteParameterServerClientTests extends BaseND4JTest { private Aeron.Context getContext() { if (ctx == null) - ctx = new Aeron.Context().publicationConnectionTimeout(-1) + ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(1000) .errorHandler(e -> log.error(e.toString(), e)); return ctx; } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java index 44ddd8793..d2faa3982 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java @@ -51,7 +51,7 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { @BeforeClass public static void beforeClass() throws Exception { final MediaDriver.Context ctx = - new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true) + new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirDeleteOnStart(true) .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy()) .senderIdleStrategy(new BusySpinIdleStrategy()); @@ -136,10 +136,10 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { private static Aeron.Context getContext() { if (ctx == null) - ctx = new Aeron.Context().publicationConnectionTimeout(-1) + ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(10000) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(10000) .errorHandler(e -> log.error(e.toString(), e)); return ctx; } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java index 8e7e12128..6c8564d1c 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java @@ -119,10 +119,10 @@ public class ParameterServerClientTest extends BaseND4JTest { private static Aeron.Context getContext() { - return new Aeron.Context().publicationConnectionTimeout(-1) + return new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000) .errorHandler(e -> log.error(e.toString(), e)); } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java index 2c78b9ed8..bad3a3fb4 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java @@ -424,7 +424,6 @@ public abstract class BaseTransport implements Transport { CloseHelper.quietClose(subscriptionForShards); CloseHelper.quietClose(subscriptionForClients); CloseHelper.quietClose(aeron); - CloseHelper.quietClose(context); CloseHelper.quietClose(driver); } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java index def534a9f..6ab7f1544 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java @@ -91,7 +91,7 @@ public class RoutedTransport extends BaseTransport { context = new Aeron.Context().driverTimeoutMs(30000) - .keepAliveInterval(100000000); + .keepAliveIntervalNs(100000000); AeronUtil.setDaemonizedThreadFactories(context); MediaDriver.Context ctx = new MediaDriver.Context(); @@ -120,7 +120,6 @@ public class RoutedTransport extends BaseTransport { Runtime.getRuntime().addShutdownHook(new Thread(() -> { CloseHelper.quietClose(aeron); CloseHelper.quietClose(driver); - CloseHelper.quietClose(context); CloseHelper.quietClose(subscriptionForClients); })); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java index 45d3f40ee..30b001340 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java @@ -131,7 +131,7 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable { splitter = MessageSplitter.getInstance(); context = new Aeron.Context().driverTimeoutMs(30000) - .keepAliveInterval(100000000); + .keepAliveIntervalNs(100000000); AeronUtil.setDaemonizedThreadFactories(context); final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/node/ParameterServerNode.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/node/ParameterServerNode.java index c166fa6cc..3b65ee134 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/node/ParameterServerNode.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/node/ParameterServerNode.java @@ -184,7 +184,7 @@ public class ParameterServerNode implements AutoCloseable { return new Aeron.Context() .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000) .errorHandler(e -> log.error(e.toString(), e)); } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java index be22cdb84..d5aed6eae 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java @@ -118,10 +118,10 @@ public class ParameterServerNodeTest extends BaseND4JTest { private static Aeron.Context getContext() { - return new Aeron.Context().publicationConnectionTimeout(-1) + return new Aeron.Context().driverTimeoutMs(10000) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(100000) .errorHandler(e -> log.error(e.toString(), e)); } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java index 8cd60049d..258b765ed 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java @@ -74,7 +74,6 @@ import java.util.concurrent.locks.LockSupport; @NoArgsConstructor @Data @Parameters(separators = ",") -@Slf4j public class ParameterServerSubscriber implements AutoCloseable { private static Logger log = LoggerFactory.getLogger(ParameterServerSubscriber.class); @@ -255,9 +254,9 @@ public class ParameterServerSubscriber implements AutoCloseable { //Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window. System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength)); final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED) - .dirsDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false) + .dirDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false) .ipcTermBufferLength(ipcLength).publicationTermBufferLength(ipcLength) - .maxTermBufferLength(ipcLength).conductorIdleStrategy(new BusySpinIdleStrategy()) + .conductorIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy()) .senderIdleStrategy(new BusySpinIdleStrategy()); AeronUtil.setDaemonizedThreadFactories(mediaDriverCtx); @@ -380,10 +379,10 @@ public class ParameterServerSubscriber implements AutoCloseable { //get a context public Aeron.Context getContext() { - Aeron.Context ctx = new Aeron.Context().publicationConnectionTimeout(-1) + Aeron.Context ctx = new Aeron.Context().driverTimeoutMs(Long.MAX_VALUE) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriverDirectoryName).keepAliveInterval(100000) + .aeronDirectoryName(mediaDriverDirectoryName).keepAliveIntervalNs(1000000) .errorHandler(e -> log.error(e.toString(), e)); AeronUtil.setDaemonizedThreadFactories(ctx); return ctx; diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt index 7e1da4029..a8a13be46 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt @@ -1860,13 +1860,13 @@ val scatterSub = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterSub" ,tensorflowOpRegistry = tensorflowOpRegistry) //TODO: note: TF expects indices, we don't support them? -val scatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterUpdate"),opName = "scatter_upd", - attributeMappingRules = listOf(), - tensorNames = mutableMapOf("input" to "ref","updates" to "updates","indices" to "indices"),tensorflowOpRegistry = tensorflowOpRegistry) +val scatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("ScatterUpdate"),opName = "scatter_update", + attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("indices" to "indices"))), + tensorNames = mutableMapOf("operand" to "ref","updates" to "updates"),tensorflowOpRegistry = tensorflowOpRegistry) -val tensorScatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("TensorScatterUpdate"),opName = "scatter_upd", - attributeMappingRules = listOf(), - tensorNames = mutableMapOf("input" to "tensor","updates" to "updates","indices" to "indices"),tensorflowOpRegistry = tensorflowOpRegistry) +val tensorScatterUpdate = multipleNameMapping(inputFrameworkOpNames = listOf("TensorScatterUpdate"),opName = "scatter_update", + attributeMappingRules = listOf(ndarrayToIntList(mutableMapOf("indices" to "indices"))), + tensorNames = mutableMapOf("operand" to "tensor","updates" to "updates"),tensorflowOpRegistry = tensorflowOpRegistry) //L2Loss val l2Loss = multipleNameMapping(inputFrameworkOpNames = listOf("L2Loss"),opName = "l2_loss", attributeMappingRules = listOf(valueMapping(mutableMapOf("dtype" to "T"))), diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/opdefs/TensorflowOpDescriptorLoader.kt b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/opdefs/TensorflowOpDescriptorLoader.kt index 9038fcf22..0c269630c 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/opdefs/TensorflowOpDescriptorLoader.kt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/opdefs/TensorflowOpDescriptorLoader.kt @@ -32,6 +32,7 @@ import org.nd4j.shade.protobuf.GeneratedMessageV3 import org.nd4j.shade.protobuf.ProtocolMessageEnum import org.nd4j.shade.protobuf.TextFormat import org.tensorflow.framework.* +import java.lang.Exception import java.nio.charset.Charset class TensorflowOpDescriptorLoader: OpDescriptorLoader { @@ -87,7 +88,12 @@ class TensorflowOpDescriptorLoader: OpDescriptorLoader { val fileName = System.getProperty(tensorflowRulesetSpecifierProperty, tensorflowMappingRulSetDefaultFile) val string = IOUtils.toString(ClassPathResource(fileName).inputStream, Charset.defaultCharset()) val declarationBuilder = MapperNamespace.MappingDefinitionSet.newBuilder() - TextFormat.merge(string,declarationBuilder) + try { + TextFormat.merge(string,declarationBuilder) + } catch(e: Exception) { + println("Unable to parse mapper definitions for file file $fileName") + } + return declarationBuilder.build() } diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt b/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt index ecfd3b869..400b9233a 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt @@ -15384,30 +15384,35 @@ mappings { } mappings { frameworkName: "tensorflow" - opName: "scatter_upd" + opName: "scatter_update" inputFrameworkOpName: "ScatterUpdate" rule { ruleName: "ndarraymapping" functionName: "ndarraymapping" inputTensorName: "ref" inputTensorName: "updates" - inputTensorName: "indices" - outputTensorName: "input" + outputTensorName: "operand" outputTensorName: "updates" - outputTensorName: "indices" inputToOutput { - key: "input" + key: "operand" value: "ref" } inputToOutput { key: "updates" value: "updates" } + ruleType: "tensor" + inputFrameworkOpName: "ScatterUpdate" + } + rule { + ruleName: "ndarraytointattributevalue" + functionName: "ndarraytointattributevalue" + outputIntName: "indices" inputToOutput { key: "indices" value: "indices" } - ruleType: "tensor" + ruleType: "attribute" inputFrameworkOpName: "ScatterUpdate" } } @@ -15527,30 +15532,35 @@ mappings { } mappings { frameworkName: "tensorflow" - opName: "scatter_upd" + opName: "scatter_update" inputFrameworkOpName: "TensorScatterUpdate" rule { ruleName: "ndarraymapping" functionName: "ndarraymapping" inputTensorName: "tensor" inputTensorName: "updates" - inputTensorName: "indices" - outputTensorName: "input" + outputTensorName: "operand" outputTensorName: "updates" - outputTensorName: "indices" inputToOutput { - key: "input" + key: "operand" value: "tensor" } inputToOutput { key: "updates" value: "updates" } + ruleType: "tensor" + inputFrameworkOpName: "TensorScatterUpdate" + } + rule { + ruleName: "ndarraytointattributevalue" + functionName: "ndarraytointattributevalue" + outputIntName: "indices" inputToOutput { key: "indices" value: "indices" } - ruleType: "tensor" + ruleType: "attribute" inputFrameworkOpName: "TensorScatterUpdate" } } diff --git a/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt b/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt index ecfd3b869..400b9233a 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt +++ b/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt @@ -15384,30 +15384,35 @@ mappings { } mappings { frameworkName: "tensorflow" - opName: "scatter_upd" + opName: "scatter_update" inputFrameworkOpName: "ScatterUpdate" rule { ruleName: "ndarraymapping" functionName: "ndarraymapping" inputTensorName: "ref" inputTensorName: "updates" - inputTensorName: "indices" - outputTensorName: "input" + outputTensorName: "operand" outputTensorName: "updates" - outputTensorName: "indices" inputToOutput { - key: "input" + key: "operand" value: "ref" } inputToOutput { key: "updates" value: "updates" } + ruleType: "tensor" + inputFrameworkOpName: "ScatterUpdate" + } + rule { + ruleName: "ndarraytointattributevalue" + functionName: "ndarraytointattributevalue" + outputIntName: "indices" inputToOutput { key: "indices" value: "indices" } - ruleType: "tensor" + ruleType: "attribute" inputFrameworkOpName: "ScatterUpdate" } } @@ -15527,30 +15532,35 @@ mappings { } mappings { frameworkName: "tensorflow" - opName: "scatter_upd" + opName: "scatter_update" inputFrameworkOpName: "TensorScatterUpdate" rule { ruleName: "ndarraymapping" functionName: "ndarraymapping" inputTensorName: "tensor" inputTensorName: "updates" - inputTensorName: "indices" - outputTensorName: "input" + outputTensorName: "operand" outputTensorName: "updates" - outputTensorName: "indices" inputToOutput { - key: "input" + key: "operand" value: "tensor" } inputToOutput { key: "updates" value: "updates" } + ruleType: "tensor" + inputFrameworkOpName: "TensorScatterUpdate" + } + rule { + ruleName: "ndarraytointattributevalue" + functionName: "ndarraytointattributevalue" + outputIntName: "indices" inputToOutput { key: "indices" value: "indices" } - ruleType: "tensor" + ruleType: "attribute" inputFrameworkOpName: "TensorScatterUpdate" } }