cavis/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/graph/GraphInferenceGrpcClient.java
Alex Black 3f0b4a2d4c
SameDiff execution, TF and memory management overhaul (#10)
* SameDiff execution memory management improvements, round 1
* Round 2
* Round 3
* Clear node outputs closed array references; Slight change to OpValidation internals to not rely on cached op outputs
* Next steps
* Next step
* Cleanup
* More polish
* Fix
* Add WeakIdentityHashmap
* Session fixes for control ops and next steps
* More fixes
* Fixes
* First steps for training session + in-line updating
* Next steps
* Fix losses and history during training
* More fixes
* BiasAdd and other fixes
* Don't use SDVariable.getArr() in TFGraphTestAllHelper (import tests)
* Small fix
* Cleanup
* First steps for new dependency tracking approach
* Start integrating dependency tracking for memory management
* Non-control op dependency tracking works/passes
* Switch/merge
* Next steps
* Cleanup and next steps
* Small fixes
* Fix issue with dependency tracking for initial variables/constants
* Add check for aliases when determining if safe to close array
* First pass on new TF graph import class
* Import fixes, op fixes
* Cleanup and fixes for new TF import mapper
* More cleanup
* Partial implementation of new dependency tracker
* Next steps
* AbstractDependencyTracker for shared code
* Overhaul SameDiff graph execution (dependency tracking)
* More fixes, cleanup, next steps
* Add no-op memory manager, cleanup, fixes
* Fix switch dependency tracking
* INDArray.toString: no exception on closed arrays, just note closed
* Fix enter and exit dependency tracking
* TensorArray memory management fixes
* Add unique ID for INDArray instances
* Fix memory management for NextIteration outputs in multi-iteration loops
* Remove (now unnecessary) special case handling for nested enters
* Cleanup
* Handle control dependencies during execution; javadoc for memory managers
* Cleanup, polish, code comments, javadoc
* Cleanup and more javadoc
* Add memory validation for all TF import tests - ensure all arrays (except outputs) are released
* Clean up arrays waiting on unexecuted ops at the end of execution
* Fixes for enter op memory management in the context of multiple non-nested loops/frames
* Fix order of operation issues for dependency tracker
* Always clear op fields after execution to avoid leaks or unintended array reuse
* Small fixes
* Re-implement dtype conversion
* Fix for control dependencies execution (dependency tracking)
* Fix TF import overrides and filtering
* Fix for constant enter array dependency tracking
* DL4J Fixes
* More DL4J fixes
* Cleanup and polish
* More polish and javadoc
* More logging level tweaks, small DL4J fix
* Small fix to DL4J SameDiffLayer
* Fix empty array deserialization, add extra deserialization checks
* FlatBuffers control dep serialization fixes; test serialization as part of all TF import tests
* Variable control dependencies serialization fix
* Fix issue with removing inputs for ops
* FlatBuffers NDArray deserialization fix
* FlatBuffers NDArray deserialization fix
* Small fix
* Small fix
* Final cleanup/polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-10-23 21:19:50 +11:00


/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.graph;
import com.google.flatbuffers.FlatBufferBuilder;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.input.OperandsAdapter;
import org.nd4j.autodiff.execution.input.Operands;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.graph.grpc.GraphInferenceServerGrpc;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.concurrent.TimeUnit;

/**
 * Client-side wrapper over the GraphServer gRPC service.
 *
 * @author raver119@gmail.com
 */
@Slf4j
public class GraphInferenceGrpcClient {
    private final ManagedChannel channel;
    private final GraphInferenceServerGrpc.GraphInferenceServerBlockingStub blockingStub;

    /**
     * Creates a new GraphInferenceGrpcClient that connects over a plaintext (unencrypted) channel.
     *
     * @param host GraphServer host name or IP address
     * @param port GraphServer port
     */
    public GraphInferenceGrpcClient(@NonNull String host, int port) {
        this(host, port, false);
    }

    /**
     * Creates a new GraphInferenceGrpcClient, optionally using TLS.
     *
     * @param host   GraphServer host name or IP address
     * @param port   GraphServer port
     * @param useTLS if true, the channel uses TLS; otherwise it uses plaintext
     */
    public GraphInferenceGrpcClient(@NonNull String host, int port, boolean useTLS) {
        this(useTLS ? ManagedChannelBuilder.forAddress(host, port).build()
                    : ManagedChannelBuilder.forAddress(host, port).usePlaintext().build());
    }

    /**
     * Creates a new GraphInferenceGrpcClient over the given ManagedChannel.
     *
     * @param channel pre-configured gRPC channel to the GraphServer
     */
    public GraphInferenceGrpcClient(@NonNull ManagedChannel channel) {
        this.channel = channel;
        this.blockingStub = GraphInferenceServerGrpc.newBlockingStub(this.channel);
    }
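
    /**
     * Shuts down the gRPC channel, waiting up to 10 seconds for in-flight calls to complete.
     * If the channel has not terminated by then, it is shut down forcibly.
     *
     * @throws InterruptedException if interrupted while waiting for termination
     */
    public void shutdown() throws InterruptedException {
        if (!this.channel.shutdown().awaitTermination(10, TimeUnit.SECONDS)) {
            // Graceful shutdown timed out: cancel any calls still in flight
            this.channel.shutdownNow();
        }
    }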

    /**
     * Registers the given graph with the GraphServer.
     *
     * @param graph graph to register
     */
    public void registerGraph(@NonNull SameDiff graph) {
        blockingStub.registerGraph(graph.asFlatGraph(false));
    }

    /**
     * Registers the given graph with the GraphServer under the specified ID.
     *
     * PLEASE NOTE: You don't need to register a graph more than once.
     * PLEASE NOTE: You don't need to register a graph if GraphServer was started with the -f argument.
     *
     * @param graphId       ID of the graph; if not 0, it should be used in subsequent output() requests
     * @param graph         graph to register
     * @param configuration executor configuration, serialized together with the graph
     * @throws ND4JIllegalStateException if the server reports a non-zero status
     */
    public void registerGraph(long graphId, @NonNull SameDiff graph, ExecutorConfiguration configuration) {
        val g = graph.asFlatGraph(graphId, configuration, false);
        val v = blockingStub.registerGraph(g);
        if (v.status() != 0)
            throw new ND4JIllegalStateException("registerGraph() gRPC call failed");
    }

    /**
     * Sends an inference request to the GraphServer instance and returns the results as an array of INDArrays.
     *
     * PLEASE NOTE: This call is routed to the default graph, with ID 0.
     *
     * @param inputs graph inputs, keyed by their string IDs
     * @return output arrays returned by the server
     */
    public INDArray[] output(Pair<String, INDArray>... inputs) {
        return output(0, inputs);
    }
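
    /**
     * Runs inference using a custom OperandsAdapter to convert between a user-defined type and Operands.
     *
     * @param graphId ID of the graph to run
     * @param value   input value, converted to Operands via {@code adapter.input(value)}
     * @param adapter adapter used to convert both the input and the output
     * @param <T>     type handled by the adapter
     * @return the inference result, converted via {@code adapter.output(...)}
     */
    public <T> T output(long graphId, T value, OperandsAdapter<T> adapter) {
        return adapter.output(this.output(graphId, adapter.input(value)));
    }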

    /**
     * Sends an inference request for the given graph and returns the results as Operands.
     *
     * @param graphId  ID of the graph to run
     * @param operands graph inputs
     * @return graph outputs returned by the server
     */
    public Operands output(long graphId, @NonNull Operands operands) {
        // Serialize each input as a FlatVariable inside a FlatInferenceRequest
        val builder = new FlatBufferBuilder(1024);
        val ins = new int[operands.size()];
        val col = operands.asCollection();
        int cnt = 0;
        for (val input : col) {
            val id = input.getFirst();
            val array = input.getSecond();
            val idPair = IntPair.createIntPair(builder, id.getId(), id.getIndex());
            val nameOff = id.getName() != null ? builder.createString(id.getName()) : 0;
            val arrOff = array.toFlatArray(builder);
            byte variableType = 0; //TODO is this OK here?
            val varOff = FlatVariable.createFlatVariable(builder, idPair, nameOff,
                    FlatBuffersMapper.getDataTypeAsByte(array.dataType()), 0, arrOff, -1, variableType, 0, 0, 0);
            ins[cnt++] = varOff;
        }
        val varsOff = FlatInferenceRequest.createVariablesVector(builder, ins);
        val off = FlatInferenceRequest.createFlatInferenceRequest(builder, graphId, varsOff, 0);
        builder.finish(off);
        val req = FlatInferenceRequest.getRootAsFlatInferenceRequest(builder.dataBuffer());

        // Blocking gRPC call
        val fr = blockingStub.inferenceRequest(req);

        // Deserialize outputs; each array is registered by name, by (id, index), and by name + (id, index)
        val res = new Operands();
        for (int e = 0; e < fr.variablesLength(); e++) {
            val v = fr.variables(e);
            val array = Nd4j.createFromFlatArray(v.ndarray());
            res.addArgument(v.name(), array);
            res.addArgument(v.id().first(), v.id().second(), array);
            res.addArgument(v.name(), v.id().first(), v.id().second(), array);
        }
        return res;
    }

    /**
     * Sends an inference request to the GraphServer instance and returns the results as an array of INDArrays.
     *
     * @param graphId ID of the graph
     * @param inputs  graph inputs, keyed by their string IDs
     * @return output arrays returned by the server
     */
    public INDArray[] output(long graphId, Pair<String, INDArray>... inputs) {
        val operands = new Operands();
        for (val in : inputs) {
            operands.addArgument(in.getFirst(), in.getSecond());
        }
        return output(graphId, operands).asArray();
    }

    /**
     * Removes the graph with the given ID from the GraphServer instance.
     *
     * @param graphId ID of the graph to remove
     * @throws ND4JIllegalStateException if the server reports a non-zero status
     */
    public void dropGraph(long graphId) {
        val builder = new FlatBufferBuilder(128);
        val off = FlatDropRequest.createFlatDropRequest(builder, graphId);
        builder.finish(off);
        val req = FlatDropRequest.getRootAsFlatDropRequest(builder.dataBuffer());
        val v = blockingStub.forgetGraph(req);
        if (v.status() != 0)
            throw new ND4JIllegalStateException("dropGraph() gRPC call failed");
    }
}
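
The OperandsAdapter overload of `output()` composes three steps: adapt the caller's value into Operands, run the graph, and adapt the resulting Operands back. The sketch below mirrors that composition with plain JDK types so it runs without ND4J or gRPC. `NamedAdapter`, `fakeInference`, and `VECTOR_ADAPTER` are hypothetical stand-ins, not ND4J APIs, and `fakeInference` simply doubles each input locally instead of calling a GraphServer.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

public class AdapterSketch {

    // Hypothetical stand-in for OperandsAdapter<T>: turns a domain value into named
    // inputs, and named outputs back into a domain value.
    interface NamedAdapter<T> {
        Map<String, double[]> input(T value);

        T output(Map<String, double[]> outputs);
    }

    // Hypothetical stand-in for the gRPC round trip: doubles every element locally
    // and names each output "<input name>_out".
    static Map<String, double[]> fakeInference(long graphId, Map<String, double[]> inputs) {
        Map<String, double[]> result = new LinkedHashMap<>();
        for (Map.Entry<String, double[]> e : inputs.entrySet()) {
            double[] in = e.getValue();
            double[] out = new double[in.length];
            for (int i = 0; i < in.length; i++) {
                out[i] = in[i] * 2.0;
            }
            result.put(e.getKey() + "_out", out);
        }
        return result;
    }

    // Same shape as output(long graphId, T value, OperandsAdapter<T> adapter):
    // adapt the input, run inference, adapt the result.
    static <T> T output(long graphId, T value, NamedAdapter<T> adapter) {
        return adapter.output(fakeInference(graphId, adapter.input(value)));
    }

    // Example adapter that feeds a single vector to an input named "x".
    static final NamedAdapter<double[]> VECTOR_ADAPTER = new NamedAdapter<double[]>() {
        @Override
        public Map<String, double[]> input(double[] value) {
            Map<String, double[]> m = new LinkedHashMap<>();
            m.put("x", value);
            return m;
        }

        @Override
        public double[] output(Map<String, double[]> outputs) {
            return outputs.get("x_out");
        }
    };

    public static void main(String[] args) {
        double[] y = output(0L, new double[] {1.0, 2.0, 3.0}, VECTOR_ADAPTER);
        System.out.println(Arrays.toString(y)); // prints [2.0, 4.0, 6.0]
    }
}
```

Keeping the conversion logic in the adapter lets the transport method stay type-agnostic; a different adapter can target another domain type without any change to the client.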
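
`shutdown()` waits up to 10 seconds for the channel to terminate. A common graceful-close pattern is to call `shutdown()`, wait with `awaitTermination()`, and fall back to `shutdownNow()` if the wait times out. gRPC's `ManagedChannel` exposes methods with these names; the sketch below uses a JDK `ExecutorService` as an assumed stand-in so it runs without gRPC on the classpath, and `shutdownGracefully` is a hypothetical helper name.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class GracefulShutdownSketch {

    // Hypothetical helper: graceful shutdown with a forced fallback. With gRPC, the same
    // three calls would be made on a ManagedChannel instead of an ExecutorService.
    static boolean shutdownGracefully(ExecutorService service, long timeout, TimeUnit unit)
            throws InterruptedException {
        service.shutdown(); // stop accepting new work; already-submitted work may finish
        if (service.awaitTermination(timeout, unit)) {
            return true; // terminated within the timeout
        }
        service.shutdownNow(); // timed out: interrupt running work, drop queued work
        return false;
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService quick = Executors.newSingleThreadExecutor();
        quick.submit(() -> { });
        System.out.println(shutdownGracefully(quick, 1, TimeUnit.SECONDS)); // true

        ExecutorService stuck = Executors.newSingleThreadExecutor();
        stuck.submit(() -> {
            try {
                Thread.sleep(10_000);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // exit promptly when shutdownNow interrupts
            }
        });
        System.out.println(shutdownGracefully(stuck, 100, TimeUnit.MILLISECONDS)); // false
    }
}
```

The forced fallback matters because `awaitTermination` only reports a timeout; without `shutdownNow()`, resources held by unfinished work stay alive after the method returns.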