* SameDiff execution memory management improvements, round 1
* Round 2
* Round 3
* Clear node outputs closed array references; Slight change to OpValidation internals to not rely on cached op outputs
* Next steps
* Next step
* Cleanup
* More polish
* Fix
* Add WeakIdentityHashmap
* Session fixes for control ops and next steps
* More fixes
* Fixes
* First steps for training session + in-line updating
* Next steps
* Fix losses and history during training
* More fixes
* BiasAdd and other fixes
* Don't use SDVariable.getArr() in TFGraphTestAllHelper (import tests)
* Small fix
* Cleanup
* First steps for new dependency tracking approach
* Start integrating dependency tracking for memory management
* Non-control op dependency tracking works/passes
* Switch/merge
* Next steps
* Cleanup and next steps
* Small fixes
* Fix issue dependency tracking for initial variables/constants
* Add check for aliases when determining if safe to close array
* First pass on new TF graph import class
* Import fixes, op fixes
* Cleanup and fixes for new TF import mapper
* More cleanup
* Partial implementation of new dependency tracker
* Next steps
* AbstractDependencyTracker for shared code
* Overhaul SameDiff graph execution (dependency tracking)
* More fixes, cleanup, next steps
* Add no-op memory manager, cleanup, fixes
* Fix switch dependency tracking
* INDArray.toString: no exception on closed arrays, just note closed
* Fix enter and exit dependency tracking
* TensorArray memory management fixes
* Add unique ID for INDArray instances
* Fix memory management for NextIteration outputs in multi-iteration loops
* Remove (now unnecessary) special case handling for nested enters
* Cleanup
* Handle control dependencies during execution; javadoc for memory managers
* Cleanup, polish, code comments, javadoc
* Cleanup and more javadoc
* Add memory validation for all TF import tests - ensure all arrays (except outputs) are released
* Clean up arrays waiting on unexecuted ops at the end of execution
* Fixes for enter op memory management in the context of multiple non-nested loops/frames
* Fix order of operation issues for dependency tracker
* Always clear op fields after execution to avoid leaks or unintended array reuse
* Small fixes
* Re-implement dtype conversion
* Fix for control dependencies execution (dependency tracking)
* Fix TF import overrides and filtering
* Fix for constant enter array dependency tracking
* DL4J Fixes
* More DL4J fixes
* Cleanup and polish
* More polish and javadoc
* More logging level tweaks, small DL4J fix
* Small fix to DL4J SameDiffLayer
* Fix empty array deserialization, add extra deserialization checks
* FlatBuffers control dep serialization fixes; test serialization as part of all TF import tests
* Variable control dependencies serialization fix
* Fix issue with removing inputs for ops
* FlatBuffers NDArray deserialization fix
* FlatBuffers NDArray deserialization fix
* Small fix
* Small fix
* Final cleanup/polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>
208 lines
7.3 KiB
Java
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.graph;

import com.google.flatbuffers.FlatBufferBuilder;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.input.OperandsAdapter;
import org.nd4j.autodiff.execution.input.Operands;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.graph.grpc.GraphInferenceServerGrpc;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

import java.util.ArrayList;
import java.util.concurrent.TimeUnit;

/**
 * Client-side wrapper around the GraphServer gRPC service.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class GraphInferenceGrpcClient {
    private final ManagedChannel channel;
    private final GraphInferenceServerGrpc.GraphInferenceServerBlockingStub blockingStub;

    /**
     * Creates a new GraphInferenceGrpcClient that uses a plaintext connection.
     *
     * @param host host name or address of the GraphServer instance
     * @param port port of the GraphServer instance
     */
    public GraphInferenceGrpcClient(@NonNull String host, int port) {
        this(host, port, false);
    }

    /**
     * Creates a new GraphInferenceGrpcClient with optional TLS support.
     *
     * @param host   host name or address of the GraphServer instance
     * @param port   port of the GraphServer instance
     * @param useTLS if false, the channel is built with usePlaintext(); if true, the builder's
     *               default transport settings are kept
     */
    public GraphInferenceGrpcClient(@NonNull String host, int port, boolean useTLS) {
        this(useTLS ? ManagedChannelBuilder.forAddress(host, port).build()
                : ManagedChannelBuilder.forAddress(host, port).usePlaintext().build());
    }

    /**
     * Creates a new GraphInferenceGrpcClient over the given ManagedChannel.
     *
     * @param channel channel used for all requests to the GraphServer instance
     */
    public GraphInferenceGrpcClient(@NonNull ManagedChannel channel) {
        this.channel = channel;
        this.blockingStub = GraphInferenceServerGrpc.newBlockingStub(this.channel);
    }

    /**
     * Shuts down the gRPC channel and waits up to 10 seconds for it to terminate.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void shutdown() throws InterruptedException {
        this.channel.shutdown().awaitTermination(10, TimeUnit.SECONDS);
    }

    /**
     * Adds the given graph to the GraphServer storage.
     *
     * @param graph graph to register
     */
    public void registerGraph(@NonNull SameDiff graph) {
        blockingStub.registerGraph(graph.asFlatGraph(false));
    }

    /**
     * Adds the given graph to the GraphServer storage under the given id.
     *
     * PLEASE NOTE: You don't need to register a graph more than once.
     * PLEASE NOTE: You don't need to register a graph if GraphServer was used with the -f argument.
     *
     * @param graphId       id of the graph; if not 0, it should be used in subsequent output() requests
     * @param graph         graph to register
     * @param configuration executor configuration that is serialized together with the graph
     * @throws ND4JIllegalStateException if the server responds with a non-zero status
     */
    public void registerGraph(long graphId, @NonNull SameDiff graph, ExecutorConfiguration configuration) {
        val g = graph.asFlatGraph(graphId, configuration, false);
        val v = blockingStub.registerGraph(g);
        if (v.status() != 0)
            throw new ND4JIllegalStateException("registerGraph() gRPC call failed");
    }

    /**
     * Sends an inference request to the GraphServer instance and returns the result as an array of INDArrays.
     *
     * PLEASE NOTE: This call will be routed to the default graph with id 0.
     *
     * @param inputs graph inputs, each paired with its string id
     * @return graph outputs as an array of INDArrays
     */
    public INDArray[] output(Pair<String, INDArray>... inputs) {
        return output(0, inputs);
    }

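    /**
     * Sends an inference request using a custom OperandsAdapter, which converts the given value
     * to Operands before the call and converts the resulting Operands back afterwards.
     *
     * @param graphId id of the graph
     * @param value   input value, converted to Operands by the adapter
     * @param adapter adapter used to convert the input and the output
     * @param <T>     type handled by the adapter
     * @return inference result, converted from Operands by the adapter
     */
    public <T> T output(long graphId, T value, OperandsAdapter<T> adapter) {
        return adapter.output(this.output(graphId, adapter.input(value)));
    }

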
    /**
     * Sends an inference request to the GraphServer instance and returns the result as Operands.
     *
     * @param graphId  id of the graph
     * @param operands graph inputs
     * @return graph outputs; each array is registered by name, by id pair, and by both
     */
    public Operands output(long graphId, @NonNull Operands operands) {
        val result = new ArrayList<INDArray>();
        val builder = new FlatBufferBuilder(1024);

        // offsets of the serialized input variables
        val ins = new int[operands.size()];

        val col = operands.asCollection();

        int cnt = 0;
        for (val input: col) {
            val id = input.getFirst();
            val array = input.getSecond();

            val idPair = IntPair.createIntPair(builder, id.getId(), id.getIndex());
            val nameOff = id.getName() != null ? builder.createString(id.getName()) : 0;

            val arrOff = array.toFlatArray(builder);
            byte variableType = 0; //TODO is this OK here?
            val varOff = FlatVariable.createFlatVariable(builder, idPair, nameOff, FlatBuffersMapper.getDataTypeAsByte(array.dataType()), 0, arrOff, -1, variableType, 0, 0, 0);
            ins[cnt++] = varOff;
        }

        val varsOff = FlatInferenceRequest.createVariablesVector(builder, ins);

        val off = FlatInferenceRequest.createFlatInferenceRequest(builder, graphId, varsOff, 0);
        builder.finish(off);

        val req = FlatInferenceRequest.getRootAsFlatInferenceRequest(builder.dataBuffer());

        // blocking call: returns once the server has sent its response
        val fr = blockingStub.inferenceRequest(req);

        val res = new Operands();

        for (int e = 0; e < fr.variablesLength(); e++) {
            val v = fr.variables(e);

            val array = Nd4j.createFromFlatArray(v.ndarray());
            res.addArgument(v.name(), array);
            res.addArgument(v.id().first(), v.id().second(), array);
            res.addArgument(v.name(), v.id().first(), v.id().second(), array);
        }

        return res;
    }

    /**
     * Sends an inference request to the GraphServer instance and returns the result as an array of INDArrays.
     *
     * @param graphId id of the graph
     * @param inputs  graph inputs, each paired with its string id
     * @return graph outputs as an array of INDArrays
     */
    public INDArray[] output(long graphId, Pair<String, INDArray>... inputs) {
        val operands = new Operands();
        for (val in: inputs)
            operands.addArgument(in.getFirst(), in.getSecond());

        return output(graphId, operands).asArray();
    }

    /**
     * Removes the graph with the given id from the GraphServer instance.
     *
     * @param graphId id of the graph to remove
     * @throws ND4JIllegalStateException if the server responds with a non-zero status
     */
    public void dropGraph(long graphId) {
        val builder = new FlatBufferBuilder(128);

        val off = FlatDropRequest.createFlatDropRequest(builder, graphId);
        builder.finish(off);

        val req = FlatDropRequest.getRootAsFlatDropRequest(builder.dataBuffer());

        val v = blockingStub.forgetGraph(req);
        if (v.status() != 0)
            throw new ND4JIllegalStateException("dropGraph() gRPC call failed");
    }
}
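
The adapter-based `output(graphId, value, adapter)` overload composes three steps: convert the caller's value to operands, run the request, and convert the returned operands back. The sketch below illustrates that composition using only the JDK. `SketchOperands`, `SketchAdapter`, `AdapterSketch`, and `fakeServer` are hypothetical stand-ins invented for this illustration. They are not the nd4j `Operands` or `OperandsAdapter` types, and the remote call is replaced by a local function.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-in for Operands: a simple name -> values map.
final class SketchOperands {
    final Map<String, double[]> byName = new LinkedHashMap<>();

    SketchOperands add(String name, double... values) {
        byName.put(name, values);
        return this;
    }
}

// Hypothetical adapter with the two calls the client makes: input(value) and output(operands).
interface SketchAdapter<T> {
    SketchOperands input(T value);

    T output(SketchOperands operands);
}

final class AdapterSketch {
    // Same composition as the adapter overload in the client:
    // adapter.output(remoteCall(adapter.input(value))).
    static <T> T output(T value, SketchAdapter<T> adapter,
                        Function<SketchOperands, SketchOperands> remote) {
        return adapter.output(remote.apply(adapter.input(value)));
    }

    // Adapter that lets callers pass and receive a plain Double.
    static final SketchAdapter<Double> SCALAR = new SketchAdapter<Double>() {
        @Override
        public SketchOperands input(Double value) {
            return new SketchOperands().add("in", value.doubleValue());
        }

        @Override
        public Double output(SketchOperands operands) {
            return operands.byName.get("out")[0];
        }
    };

    // Local replacement for the server: doubles the single input value.
    static SketchOperands fakeServer(SketchOperands request) {
        double x = request.byName.get("in")[0];
        return new SketchOperands().add("out", 2 * x);
    }

    public static void main(String[] args) {
        Double y = output(21.0, SCALAR, AdapterSketch::fakeServer);
        System.out.println(y); // prints 42.0
    }
}
```

Keeping the conversion in the adapter means calling code works with its own types, while the request and response handling stays in one place.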