2019-06-06 15:21:15 +03:00
|
|
|
/*******************************************************************************
|
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
|
*
|
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
|
*
|
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
|
* under the License.
|
|
|
|
|
*
|
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
|
|
package org.deeplearning4j.nn.workspace;
|
|
|
|
|
|
Version upgrades (#199)
* DataVec fixes for Jackson version upgrade
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J jackson updates + databind version 2.9.9.3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Shade snakeyaml along with jackson
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Version fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Switch DataVec legacy JSON format handling to mixins
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Next set of fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Cleanup for legacy JSON mapping
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Upgrade commons compress to 1.18; small test fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* New Jackson backward compatibility for DL4J - Round 1
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* New Jackson backward compatibility for DL4J - Round 2
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* More fixes, all but legacy custom passing
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Provide an upgrade path for custom layers for models in pre-1.0.0-beta JSON format
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Legacy deserialization cleanup
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small amount of polish - legacy JSON
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Upgrade guava version
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* IEvaluation legacy format deserialization fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Upgrade play version to 2.7.3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Update nd4j-parameter-server-status to new Play API
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Update DL4J UI for new play version
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* More play framework updates
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Remove Spark 1/2 adapter code from DataVec
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* datavec-spark dependency cleanup
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J spark updates, pt 1
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J spark updates, pt 2
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J spark updates, pt 3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J spark updates, pt 4
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Test fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Another fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Breeze upgrade, dependency cleanup
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Add Scala 2.12 version to pom.xml
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* change-scala-versions.sh - add scala 2.12, remove 2.10
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Move Spark version properties to parent pom (now that only one spark version is supported)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DataVec Play fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* datavec play dependency fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Clean up old spark/jackson stuff
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Cleanup jackson unused dependencies
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Dropping redundant dependency
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Removed scalaxy dependency
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* DataVec fixes for Jackson version upgrade
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J jackson updates + databind version 2.9.9.3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Shade snakeyaml along with jackson
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Version fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Switch DataVec legacy JSON format handling to mixins
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Next set of fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Cleanup for legacy JSON mapping
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Upgrade commons compress to 1.18; small test fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* New Jackson backward compatibility for DL4J - Round 1
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* New Jackson backward compatibility for DL4J - Round 2
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* More fixes, all but legacy custom passing
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Provide an upgrade path for custom layers for models in pre-1.0.0-beta JSON format
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Legacy deserialization cleanup
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small amount of polish - legacy JSON
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Upgrade guava version
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* IEvaluation legacy format deserialization fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Upgrade play version to 2.7.3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Update nd4j-parameter-server-status to new Play API
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Update DL4J UI for new play version
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* More play framework updates
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Remove Spark 1/2 adapter code from DataVec
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* datavec-spark dependency cleanup
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J spark updates, pt 1
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J spark updates, pt 2
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J spark updates, pt 3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J spark updates, pt 4
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Test fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Another fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Breeze upgrade, dependency cleanup
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Add Scala 2.12 version to pom.xml
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* change-scala-versions.sh - add scala 2.12, remove 2.10
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Move Spark version properties to parent pom (now that only one spark version is supported)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DataVec Play fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* datavec play dependency fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Clean up old spark/jackson stuff
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Cleanup jackson unused dependencies
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Add shaded guava
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Dropping redundant dependency
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Removed scalaxy dependency
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Ensure not possible to import pre-shaded classes, and remove direct guava dependencies in favor of shaded
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* ND4J Shaded guava import fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DataVec and DL4J guava shading
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Arbiter, RL4J fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Build fixed
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Fix dependency
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* Fix bad merge
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Jackson shading fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Set play secret, datavec-spark-inference-server
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for datavec-spark-inference-server
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Arbiter fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Arbiter fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small test fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-30 14:35:27 +10:00
|
|
|
import org.nd4j.shade.guava.base.Preconditions;
|
2019-06-06 15:21:15 +03:00
|
|
|
import lombok.Getter;
|
|
|
|
|
import lombok.NonNull;
|
|
|
|
|
import lombok.Setter;
|
|
|
|
|
import org.bytedeco.javacpp.Pointer;
|
|
|
|
|
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
|
|
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
|
|
|
import org.nd4j.linalg.workspace.BaseWorkspaceMgr;
|
|
|
|
|
import org.nd4j.linalg.workspace.WorkspaceMgr;
|
|
|
|
|
|
|
|
|
|
import java.util.*;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* {@link WorkspaceMgr} for DL4J layers.
|
|
|
|
|
* Used to flexibly specify which workspaces a given array type (defined by {@link ArrayType}) should be placed in
|
|
|
|
|
*
|
|
|
|
|
* @author Alex Black
|
|
|
|
|
*/
|
|
|
|
|
public class LayerWorkspaceMgr extends BaseWorkspaceMgr<ArrayType> {
|
|
|
|
|
public static String CUDNN_WORKSPACE_KEY = "CUDNN_WORKSPACE";
|
|
|
|
|
|
|
|
|
|
private static LayerWorkspaceMgr NO_WS_IMMUTABLE;
|
|
|
|
|
static{
|
|
|
|
|
Set<ArrayType> all = new HashSet<>();
|
|
|
|
|
Collections.addAll(all, ArrayType.values());
|
|
|
|
|
NO_WS_IMMUTABLE = new LayerWorkspaceMgr(
|
|
|
|
|
all, Collections.<ArrayType, WorkspaceConfiguration>emptyMap(), Collections.<ArrayType, String>emptyMap());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected Set<String> noLeverageOverride;
|
|
|
|
|
|
|
|
|
|
@Setter @Getter
|
|
|
|
|
protected Map<String,Pointer> helperWorkspacePointers;
|
|
|
|
|
|
|
|
|
|
private LayerWorkspaceMgr(){
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public LayerWorkspaceMgr(Set<ArrayType> scopeOutOfWs, Map<ArrayType, WorkspaceConfiguration> configMap,
|
|
|
|
|
Map<ArrayType, String> workspaceNames){
|
|
|
|
|
super(scopeOutOfWs, configMap, workspaceNames);
|
|
|
|
|
if(configMap != null){
|
|
|
|
|
Preconditions.checkArgument(configMap.keySet().equals(workspaceNames.keySet()),
|
|
|
|
|
"Keys for config may and workspace names must match");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public void setNoLeverageOverride(String wsName){
|
|
|
|
|
if(noLeverageOverride == null){
|
|
|
|
|
noLeverageOverride = new HashSet<>();
|
|
|
|
|
}
|
|
|
|
|
noLeverageOverride.add(wsName);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public INDArray leverageTo(ArrayType arrayType, INDArray array){
|
|
|
|
|
if(noLeverageOverride != null && array.isAttached() && noLeverageOverride.contains(array.data().getParentWorkspace().getId())){
|
|
|
|
|
return array;
|
|
|
|
|
}
|
|
|
|
|
return super.leverageTo(arrayType, array);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public INDArray validateArrayLocation(@NonNull ArrayType arrayType, @NonNull INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) {
|
|
|
|
|
if(noLeverageOverride != null && array.isAttached() && noLeverageOverride.contains(array.data().getParentWorkspace().getId())){
|
|
|
|
|
return array; //OK - leverage override
|
|
|
|
|
}
|
|
|
|
|
return super.validateArrayLocation(arrayType, array, migrateIfInvalid, exceptionIfDetached);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Get the pointer to the helper memory. Usually used for CUDNN workspace memory sharing.
|
|
|
|
|
* NOTE: Don't use this method unless you are fully aware of how it is used to manage CuDNN memory!
|
|
|
|
|
* Will (by design) throw a NPE if the underlying map (set from MultiLayerNetwork or ComputationGraph) is not set.
|
|
|
|
|
*
|
|
|
|
|
* @param key Key for the helper workspace pointer
|
|
|
|
|
* @param <T> Pointer type
|
|
|
|
|
* @return Pointer for that key, or null if none exists
|
|
|
|
|
*/
|
|
|
|
|
public <T extends Pointer> T getHelperWorkspace(String key){
|
|
|
|
|
return helperWorkspacePointers == null ? null : (T)helperWorkspacePointers.get(key);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Set the pointer to the helper memory. Usually used for CuDNN workspace memory sharing.
|
|
|
|
|
* NOTE: Don't use this method unless you are fully aware of how it is used to manage CuDNN memory!
|
|
|
|
|
* Will (by design) throw a NPE if the underlying map (set from MultiLayerNetwork or ComputationGraph) is not set.
|
|
|
|
|
*
|
|
|
|
|
* @param key Key for the helper workspace pointer
|
|
|
|
|
* @param value Pointer
|
|
|
|
|
*/
|
|
|
|
|
public void setHelperWorkspace(@NonNull String key, Pointer value){
|
|
|
|
|
if(helperWorkspacePointers == null){
|
|
|
|
|
helperWorkspacePointers = new HashMap<>();
|
|
|
|
|
}
|
|
|
|
|
helperWorkspacePointers.put(key, value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public static Builder builder(){
|
|
|
|
|
return new Builder();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @param helperWorkspacePointers Helper pointers - see {@link #getHelperWorkspace(String)} for details
|
|
|
|
|
* @return Workspace manager
|
|
|
|
|
*/
|
|
|
|
|
public static LayerWorkspaceMgr noWorkspaces(Map<String,Pointer> helperWorkspacePointers){
|
|
|
|
|
LayerWorkspaceMgr wsm = noWorkspaces();
|
|
|
|
|
wsm.setHelperWorkspacePointers(helperWorkspacePointers);
|
|
|
|
|
return wsm;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public static LayerWorkspaceMgr noWorkspaces(){
|
|
|
|
|
return builder().defaultNoWorkspace().build();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public static LayerWorkspaceMgr noWorkspacesImmutable(){
|
|
|
|
|
return NO_WS_IMMUTABLE;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public static class Builder {
|
|
|
|
|
|
|
|
|
|
private LayerWorkspaceMgr mgr;
|
|
|
|
|
|
|
|
|
|
public Builder(){
|
|
|
|
|
mgr = new LayerWorkspaceMgr();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Set the default to be scoped out for all array types.
|
|
|
|
|
* NOTE: Will not override the configuration for any array types that have already been configured
|
|
|
|
|
* @return Builder
|
|
|
|
|
*/
|
|
|
|
|
public Builder defaultNoWorkspace(){
|
|
|
|
|
for(ArrayType t : ArrayType.values()){
|
|
|
|
|
if(!mgr.configMap.containsKey(t)){
|
|
|
|
|
mgr.setScopedOutFor(t);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Specify that no workspace should be used for array of the specified type - i.e., these arrays should all
|
|
|
|
|
* be scoped out.
|
|
|
|
|
*
|
|
|
|
|
* @param type Array type to set scoped out for
|
|
|
|
|
* @return Builder
|
|
|
|
|
*/
|
|
|
|
|
public Builder noWorkspaceFor(ArrayType type){
|
|
|
|
|
mgr.setScopedOutFor(type);
|
|
|
|
|
return this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Set the default workspace for all array types to the specified workspace name/configuration
|
|
|
|
|
* NOTE: This will NOT override any settings previously set.
|
|
|
|
|
*
|
|
|
|
|
* @param workspaceName Name of the workspace to use for all (not set) arrray types
|
|
|
|
|
* @param configuration Configuration to use for all (not set) arrray types
|
|
|
|
|
* @return Builder
|
|
|
|
|
*/
|
|
|
|
|
public Builder defaultWorkspace(String workspaceName, WorkspaceConfiguration configuration){
|
|
|
|
|
for(ArrayType t : ArrayType.values()){
|
|
|
|
|
if(!mgr.configMap.containsKey(t) && !mgr.isScopedOut(t)){
|
|
|
|
|
with(t, workspaceName, configuration);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Configure the workspace (name, configuration) for the specified array type
|
|
|
|
|
*
|
|
|
|
|
* @param type Array type
|
|
|
|
|
* @param workspaceName Workspace name for the specified array type
|
|
|
|
|
* @param configuration Configuration for the specified array type
|
|
|
|
|
* @return Builder
|
|
|
|
|
*/
|
|
|
|
|
public Builder with(ArrayType type, String workspaceName, WorkspaceConfiguration configuration){
|
|
|
|
|
mgr.setConfiguration(type, configuration);
|
|
|
|
|
mgr.setWorkspaceName(type, workspaceName);
|
|
|
|
|
return this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public LayerWorkspaceMgr build(){
|
|
|
|
|
return mgr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|