211 lines
7.7 KiB
Java
Raw Normal View History

2019-06-06 15:21:15 +03:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.workspace;
Version upgrades (#199) * DataVec fixes for Jackson version upgrade Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J jackson updates + databind version 2.9.9.3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Shade snakeyaml along with jackson Signed-off-by: AlexDBlack <blacka101@gmail.com> * Version fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Switch DataVec legacy JSON format handling to mixins Signed-off-by: AlexDBlack <blacka101@gmail.com> * Next set of fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Cleanup for legacy JSON mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * Upgrade commons compress to 1.18; small test fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * New Jackson backward compatibility for DL4J - Round 1 Signed-off-by: AlexDBlack <blacka101@gmail.com> * New Jackson backward compatibility for DL4J - Round 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * More fixes, all but legacy custom passing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Provide an upgrade path for custom layers for models in pre-1.0.0-beta JSON format Signed-off-by: AlexDBlack <blacka101@gmail.com> * Legacy deserialization cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small amount of polish - legacy JSON Signed-off-by: AlexDBlack <blacka101@gmail.com> * Upgrade guava version Signed-off-by: AlexDBlack <blacka101@gmail.com> * IEvaluation legacy format deserialization fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Upgrade play version to 2.7.3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update nd4j-parameter-server-status to new Play API Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update DL4J UI for new play version Signed-off-by: AlexDBlack <blacka101@gmail.com> * More play framework updates Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove Spark 1/2 adapter code from DataVec Signed-off-by: AlexDBlack <blacka101@gmail.com> * datavec-spark 
dependency cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J spark updates, pt 1 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J spark updates, pt 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J spark updates, pt 3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J spark updates, pt 4 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Another fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Breeze upgrade, dependency cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add Scala 2.12 version to pom.xml Signed-off-by: AlexDBlack <blacka101@gmail.com> * change-scala-versions.sh - add scala 2.12, remove 2.10 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Move Spark version properties to parent pom (now that only one spark version is supported) Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataVec Play fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * datavec play dependency fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up old spark/jackson stuff Signed-off-by: AlexDBlack <blacka101@gmail.com> * Cleanup jackson unused dependencies Signed-off-by: AlexDBlack <blacka101@gmail.com> * Dropping redundant dependency Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Removed scalaxy dependency Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * DataVec fixes for Jackson version upgrade Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J jackson updates + databind version 2.9.9.3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Shade snakeyaml along with jackson Signed-off-by: AlexDBlack <blacka101@gmail.com> * Version fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Switch DataVec legacy JSON format handling to mixins Signed-off-by: AlexDBlack <blacka101@gmail.com> * Next set of fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Cleanup for legacy JSON mapping Signed-off-by: AlexDBlack 
<blacka101@gmail.com> * Upgrade commons compress to 1.18; small test fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * New Jackson backward compatibility for DL4J - Round 1 Signed-off-by: AlexDBlack <blacka101@gmail.com> * New Jackson backward compatibility for DL4J - Round 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * More fixes, all but legacy custom passing Signed-off-by: AlexDBlack <blacka101@gmail.com> * Provide an upgrade path for custom layers for models in pre-1.0.0-beta JSON format Signed-off-by: AlexDBlack <blacka101@gmail.com> * Legacy deserialization cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small amount of polish - legacy JSON Signed-off-by: AlexDBlack <blacka101@gmail.com> * Upgrade guava version Signed-off-by: AlexDBlack <blacka101@gmail.com> * IEvaluation legacy format deserialization fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Upgrade play version to 2.7.3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update nd4j-parameter-server-status to new Play API Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update DL4J UI for new play version Signed-off-by: AlexDBlack <blacka101@gmail.com> * More play framework updates Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove Spark 1/2 adapter code from DataVec Signed-off-by: AlexDBlack <blacka101@gmail.com> * datavec-spark dependency cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J spark updates, pt 1 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J spark updates, pt 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J spark updates, pt 3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J spark updates, pt 4 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Another fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Breeze upgrade, dependency cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add Scala 2.12 
version to pom.xml Signed-off-by: AlexDBlack <blacka101@gmail.com> * change-scala-versions.sh - add scala 2.12, remove 2.10 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Move Spark version properties to parent pom (now that only one spark version is supported) Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataVec Play fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * datavec play dependency fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up old spark/jackson stuff Signed-off-by: AlexDBlack <blacka101@gmail.com> * Cleanup jackson unused dependencies Signed-off-by: AlexDBlack <blacka101@gmail.com> * Add shaded guava Signed-off-by: AlexDBlack <blacka101@gmail.com> * Dropping redundant dependency Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Removed scalaxy dependency Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Ensure not possible to import pre-shaded classes, and remove direct guava dependencies in favor of shaded Signed-off-by: AlexDBlack <blacka101@gmail.com> * ND4J Shaded guava import fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataVec and DL4J guava shading Signed-off-by: AlexDBlack <blacka101@gmail.com> * Arbiter, RL4J fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Build fixed Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Fix dependency Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Fix bad merge Signed-off-by: AlexDBlack <blacka101@gmail.com> * Jackson shading fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Set play secret, datavec-spark-inference-server Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for datavec-spark-inference-server Signed-off-by: AlexDBlack <blacka101@gmail.com> * Arbiter fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Arbiter fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small test fix Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-30 14:35:27 +10:00
import org.nd4j.shade.guava.base.Preconditions;
2019-06-06 15:21:15 +03:00
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.workspace.BaseWorkspaceMgr;
import org.nd4j.linalg.workspace.WorkspaceMgr;
import java.util.*;
/**
 * {@link WorkspaceMgr} for DL4J layers.
 * Used to flexibly specify which workspaces a given array type (defined by {@link ArrayType}) should be placed in.
 * <p>
 * In addition to the per-array-type configuration inherited from {@link BaseWorkspaceMgr}, this class supports:
 * <ul>
 *     <li>"No leverage override" workspace names (see {@link #setNoLeverageOverride(String)}): arrays attached to one
 *     of these workspaces are returned unchanged by {@link #leverageTo(ArrayType, INDArray)} and
 *     {@link #validateArrayLocation(ArrayType, INDArray, boolean, boolean)}</li>
 *     <li>A map of named helper workspace pointers (see {@link #getHelperWorkspace(String)})</li>
 * </ul>
 *
 * @author Alex Black
 */
public class LayerWorkspaceMgr extends BaseWorkspaceMgr<ArrayType> {

    /**
     * Helper workspace key, presumably used by CuDNN helpers together with {@link #getHelperWorkspace(String)} /
     * {@link #setHelperWorkspace(String, Pointer)}.
     */
    public static final String CUDNN_WORKSPACE_KEY = "CUDNN_WORKSPACE";

    /**
     * Shared manager that scopes out every {@link ArrayType}; returned by {@link #noWorkspacesImmutable()}.
     * NOTE: "immutable" is a usage contract only - this class does not prevent mutation of the shared instance
     * (e.g. via {@link #setHelperWorkspace(String, Pointer)} or {@link #setNoLeverageOverride(String)}), so callers
     * must not modify it.
     */
    private static final LayerWorkspaceMgr NO_WS_IMMUTABLE = new LayerWorkspaceMgr(
            EnumSet.allOf(ArrayType.class),
            Collections.<ArrayType, WorkspaceConfiguration>emptyMap(),
            Collections.<ArrayType, String>emptyMap());

    /** Workspace IDs whose attached arrays bypass leverage/validation. Created lazily; null means "none". */
    protected Set<String> noLeverageOverride;

    /** Named helper workspace pointers; may be null (no map set) or shared by reference with the caller. */
    @Setter @Getter
    protected Map<String,Pointer> helperWorkspacePointers;

    private LayerWorkspaceMgr(){
        //Used only by the Builder, which configures the instance through the base class setters
    }

    /**
     * @param scopeOutOfWs   Array types that should be scoped out of any workspace
     * @param configMap      Workspace configuration per array type (may be null)
     * @param workspaceNames Workspace name per array type; when configMap is non-null this must be non-null and
     *                       have exactly the same keys as configMap
     * @throws IllegalArgumentException if configMap is non-null and workspaceNames is null or has different keys
     */
    public LayerWorkspaceMgr(Set<ArrayType> scopeOutOfWs, Map<ArrayType, WorkspaceConfiguration> configMap,
                             Map<ArrayType, String> workspaceNames){
        super(scopeOutOfWs, configMap, workspaceNames);
        if(configMap != null){
            //Null workspaceNames would otherwise surface as an uninformative NPE
            Preconditions.checkArgument(workspaceNames != null && configMap.keySet().equals(workspaceNames.keySet()),
                    "Keys for config map and workspace names must match: config keys %s, workspace name keys %s",
                    configMap.keySet(), workspaceNames == null ? null : workspaceNames.keySet());
        }
    }

    /**
     * Register a workspace ID whose attached arrays should be left where they are by
     * {@link #leverageTo(ArrayType, INDArray)} and accepted as-is by
     * {@link #validateArrayLocation(ArrayType, INDArray, boolean, boolean)}.
     *
     * @param wsName Workspace ID to add to the override set
     */
    public void setNoLeverageOverride(String wsName){
        if(noLeverageOverride == null){
            noLeverageOverride = new HashSet<>();
        }
        noLeverageOverride.add(wsName);
    }

    /**
     * Leverage the array to the workspace configured for the array type, unless the array is attached to a
     * workspace registered via {@link #setNoLeverageOverride(String)}, in which case it is returned unchanged.
     */
    @Override
    public INDArray leverageTo(ArrayType arrayType, INDArray array){
        if(isLeverageOverride(array)){
            return array;
        }
        return super.leverageTo(arrayType, array);
    }

    /**
     * Validate the array's workspace location for the array type. Arrays attached to a workspace registered via
     * {@link #setNoLeverageOverride(String)} are accepted without further checks.
     */
    @Override
    public INDArray validateArrayLocation(@NonNull ArrayType arrayType, @NonNull INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) {
        if(isLeverageOverride(array)){
            return array; //OK - leverage override
        }
        return super.validateArrayLocation(arrayType, array, migrateIfInvalid, exceptionIfDetached);
    }

    /**
     * Returns true if the array is attached to a workspace whose ID is in the no-leverage override set.
     * Null or detached arrays never match, so the decision is left to the base class.
     */
    private boolean isLeverageOverride(INDArray array){
        if(noLeverageOverride == null || array == null || !array.isAttached()){
            return false;
        }
        return noLeverageOverride.contains(array.data().getParentWorkspace().getId());
    }

    /**
     * Get the pointer to the helper memory. Usually used for CuDNN workspace memory sharing.
     * NOTE: Don't use this method unless you are fully aware of how it is used to manage CuDNN memory!
     * Returns null if no helper map has been set (the map is normally set from MultiLayerNetwork or
     * ComputationGraph).
     * <p>
     * The cast to {@code T} is unchecked: if the stored pointer is of a different type, a
     * {@link ClassCastException} is thrown at the call site.
     *
     * @param key Key for the helper workspace pointer
     * @param <T> Pointer type
     * @return Pointer for that key, or null if none exists (or no helper map has been set)
     */
    @SuppressWarnings("unchecked")
    public <T extends Pointer> T getHelperWorkspace(String key){
        return helperWorkspacePointers == null ? null : (T)helperWorkspacePointers.get(key);
    }

    /**
     * Set the pointer to the helper memory. Usually used for CuDNN workspace memory sharing.
     * NOTE: Don't use this method unless you are fully aware of how it is used to manage CuDNN memory!
     * If no helper map has been set, a new one is created. Otherwise the pointer is stored in the existing map,
     * which may be shared with the code that supplied it (see {@link #noWorkspaces(Map)}).
     *
     * @param key   Key for the helper workspace pointer
     * @param value Pointer
     */
    public void setHelperWorkspace(@NonNull String key, Pointer value){
        if(helperWorkspacePointers == null){
            helperWorkspacePointers = new HashMap<>();
        }
        helperWorkspacePointers.put(key, value);
    }

    /**
     * @return A new {@link Builder} for configuring a LayerWorkspaceMgr
     */
    public static Builder builder(){
        return new Builder();
    }

    /**
     * Create a manager that scopes out all array types, using the given helper pointer map.
     * The map is stored by reference (not copied), so pointers later added via
     * {@link #setHelperWorkspace(String, Pointer)} are visible to the owner of the map.
     *
     * @param helperWorkspacePointers Helper pointers - see {@link #getHelperWorkspace(String)} for details
     * @return Workspace manager
     */
    public static LayerWorkspaceMgr noWorkspaces(Map<String,Pointer> helperWorkspacePointers){
        LayerWorkspaceMgr wsm = noWorkspaces();
        wsm.setHelperWorkspacePointers(helperWorkspacePointers);
        return wsm;
    }

    /**
     * @return A new manager that scopes out all array types
     */
    public static LayerWorkspaceMgr noWorkspaces(){
        return builder().defaultNoWorkspace().build();
    }

    /**
     * @return The shared manager that scopes out all array types. Callers must not modify it, because the
     *         same instance is returned on every call.
     */
    public static LayerWorkspaceMgr noWorkspacesImmutable(){
        return NO_WS_IMMUTABLE;
    }

    /**
     * Builder for {@link LayerWorkspaceMgr}. Note that {@link #build()} returns the builder's internal instance,
     * so the builder should not be reused after calling it.
     */
    public static class Builder {

        private LayerWorkspaceMgr mgr;

        public Builder(){
            mgr = new LayerWorkspaceMgr();
        }

        /**
         * Set the default to be scoped out for all array types.
         * NOTE: Will not override the configuration for any array types that have already been configured
         * @return Builder
         */
        public Builder defaultNoWorkspace(){
            for(ArrayType t : ArrayType.values()){
                if(!mgr.configMap.containsKey(t)){
                    mgr.setScopedOutFor(t);
                }
            }
            return this;
        }

        /**
         * Specify that no workspace should be used for array of the specified type - i.e., these arrays should all
         * be scoped out.
         *
         * @param type Array type to set scoped out for
         * @return Builder
         */
        public Builder noWorkspaceFor(ArrayType type){
            mgr.setScopedOutFor(type);
            return this;
        }

        /**
         * Set the default workspace for all array types to the specified workspace name/configuration
         * NOTE: This will NOT override any settings previously set (configured or scoped-out array types are skipped).
         *
         * @param workspaceName Name of the workspace to use for all (not set) array types
         * @param configuration Configuration to use for all (not set) array types
         * @return Builder
         */
        public Builder defaultWorkspace(String workspaceName, WorkspaceConfiguration configuration){
            for(ArrayType t : ArrayType.values()){
                if(!mgr.configMap.containsKey(t) && !mgr.isScopedOut(t)){
                    with(t, workspaceName, configuration);
                }
            }
            return this;
        }

        /**
         * Configure the workspace (name, configuration) for the specified array type
         *
         * @param type          Array type
         * @param workspaceName Workspace name for the specified array type
         * @param configuration Configuration for the specified array type
         * @return Builder
         */
        public Builder with(ArrayType type, String workspaceName, WorkspaceConfiguration configuration){
            mgr.setConfiguration(type, configuration);
            mgr.setWorkspaceName(type, workspaceName);
            return this;
        }

        /**
         * @return The configured workspace manager (the builder's internal instance, not a copy)
         */
        public LayerWorkspaceMgr build(){
            return mgr;
        }
    }
}