brian 656d367812 Fix javadoc and cleanup
Signed-off-by: brian <brian@brutex.de>
2022-10-21 15:19:32 +02:00

167 lines
5.8 KiB
Java

/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.aeron.ipc;
import io.aeron.Aeron;
import io.aeron.driver.MediaDriver;
import lombok.extern.slf4j.Slf4j;
import org.agrona.CloseHelper;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.junit.jupiter.api.Assertions.assertFalse;
/**
 * Integration test that pushes a large ({@code 1e7} element) {@link INDArray} through an embedded
 * Aeron media driver using {@link AeronNDArrayPublisher} and {@link AeronNDArraySubscriber}.
 *
 * <p>The UDP port is picked at random from {@code [40123, 40252]} to reduce the chance of clashing
 * with other test runs on the same host.
 */
@Slf4j
@Timeout(120)
public class LargeNdArrayIpcTest extends BaseND4JTest {
    /** Maximum time to wait for the subscriber callback after the array has been published. */
    private static final long RECEIVE_TIMEOUT_MS = 5000L;
    /** Maximum time to wait for each subscriber thread to exit once its subscriber is closed. */
    private static final long JOIN_TIMEOUT_MS = 5000L;

    private MediaDriver mediaDriver;
    private Aeron.Context ctx;
    private final String channel = "aeron:udp?endpoint=localhost:" + (40123 + new java.util.Random().nextInt(130));
    private final int streamId = 10;
    /** Element count of the published array; also used to size the media driver context. */
    private final int length = (int) 1e7;

    @Override
    public long getTimeoutMilliseconds() {
        return 12000L;
    }

    @AfterEach
    public void after() {
        CloseHelper.quietClose(mediaDriver);
    }

    @Test
    public void testMultiThreadedIpcBig() throws Exception {
        MediaDriver.Context mctx = AeronUtil.getMediaDriverContext(length).useWindowsHighResTimer(true);
        mediaDriver = MediaDriver.launchEmbedded(mctx);
        System.out.println("Using media driver directory " + mediaDriver.aeronDirectoryName());
        System.out.println("Launched media driver");

        INDArray arr = Nd4j.ones(length);
        // Flipped to false by the subscriber callback once the array arrives; it is also the
        // subscriber's own run flag, so receiving the message stops the subscriber loop.
        final AtomicBoolean running = new AtomicBoolean(true);
        int numSubscribers = 1;
        AeronNDArraySubscriber[] subscribers = new AeronNDArraySubscriber[numSubscribers];
        Thread[] subscriberThreads = new Thread[numSubscribers];
        AeronNDArrayPublisher publisher = null;

        // Always build a fresh context bound to this test's media driver.
        ctx = newContext();
        Aeron aeron = Aeron.connect(ctx);
        try {
            for (int i = 0; i < numSubscribers; i++) {
                AeronNDArraySubscriber subscriber = AeronNDArraySubscriber.builder()
                        .streamId(streamId).ctx(getContext())
                        .channel(channel)
                        .aeron(aeron)
                        .running(running)
                        .ndArrayCallback(new NDArrayCallback() {
                            /**
                             * A listener for ndarray message
                             *
                             * @param message the message for the callback
                             */
                            @Override
                            public void onNDArrayMessage(NDArrayMessage message) {
                                running.set(false);
                            }

                            @Override
                            public void onNDArrayPartial(INDArray arr, long idx, int... dimensions) {
                            }

                            @Override
                            public void onNDArray(INDArray arr) {
                                running.set(false);
                            }
                        }).build();
                // Register the subscriber before starting its thread so the finally block
                // can always close it, even if thread start-up fails.
                subscribers[i] = subscriber;
                Thread t = new Thread(() -> {
                    try {
                        subscriber.launch();
                    } catch (Exception e) {
                        System.out.println(e.getMessage());
                        e.printStackTrace();
                    }
                }, "aeron-ndarray-subscriber-" + i);
                // Daemon so a stuck subscriber can never keep the JVM alive after the test.
                t.setDaemon(true);
                subscriberThreads[i] = t;
                t.start();
            }

            // Give the subscriber time to connect before publishing.
            Thread.sleep(1000);
            publisher = AeronNDArrayPublisher.builder()
                    .publishRetryTimeOut(3000)
                    .streamId(streamId)
                    .channel(channel)
                    .aeron(aeron)
                    .build();
            System.out.println("About to send array.");
            publisher.publish(arr);
            System.out.println("Sent array");

            // Wait (bounded) for the callback instead of sleeping for a fixed interval.
            long deadline = System.currentTimeMillis() + RECEIVE_TIMEOUT_MS;
            while (running.get() && System.currentTimeMillis() < deadline) {
                Thread.sleep(50);
            }
        } finally {
            // Release in reverse dependency order: publisher and subscribers use the client,
            // so they are closed before it. quietClose is null-safe and does not mask a
            // primary failure raised from the try block.
            CloseHelper.quietClose(publisher);
            for (AeronNDArraySubscriber subscriber : subscribers) {
                CloseHelper.quietClose(subscriber);
            }
            CloseHelper.quietClose(aeron);
            for (Thread t : subscriberThreads) {
                if (t != null) {
                    t.join(JOIN_TIMEOUT_MS);
                }
            }
        }
        assertFalse(running.get(), "Subscriber did not receive the published array within "
                + RECEIVE_TIMEOUT_MS + " ms");
    }

    /**
     * Returns the shared client context, creating it on first use.
     *
     * @return the Aeron client context bound to the embedded media driver
     */
    private Aeron.Context getContext() {
        if (ctx == null) {
            ctx = newContext();
        }
        return ctx;
    }

    /**
     * Builds a new Aeron client context pointing at the embedded media driver's directory.
     *
     * @return a freshly configured, not yet connected, client context
     */
    private Aeron.Context newContext() {
        return new Aeron.Context()
                .driverTimeoutMs(10000)
                .availableImageHandler(AeronUtil::printAvailableImage)
                .unavailableImageHandler(AeronUtil::printUnavailableImage)
                .aeronDirectoryName(mediaDriver.aeronDirectoryName())
                .errorHandler(err -> err.printStackTrace());
    }
}