commit 1a6ada0ce9

@@ -0,0 +1,151 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.nn.misc;
+
+import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.TestUtils;
+import org.deeplearning4j.nn.api.Updater;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.junit.Test;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+
+import static org.junit.Assert.assertTrue;
+
+public class CloseNetworkTests extends BaseDL4JTest {
+
+    public static MultiLayerNetwork getTestNet() {
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .updater(new Adam(1e-3))
+                .list()
+                .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(3, 3).activation(Activation.TANH).build())
+                .layer(new BatchNormalization.Builder().nOut(5).build())
+                .layer(new SubsamplingLayer.Builder().build())
+                .layer(new DenseLayer.Builder().nOut(10).activation(Activation.RELU).build())
+                .layer(new OutputLayer.Builder().nOut(10).build())
+                .setInputType(InputType.convolutional(28, 28, 1))
+                .build();
+
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+
+        return net;
+    }
+
+    @Test
+    public void testCloseMLN() {
+        for (boolean train : new boolean[]{false, true}) {
+            for (boolean test : new boolean[]{false, true}) {
+                MultiLayerNetwork net = getTestNet();
+
+                INDArray f = Nd4j.rand(DataType.FLOAT, 16, 1, 28, 28);
+                INDArray l = TestUtils.randomOneHot(16, 10);
+
+                if (train) {
+                    for (int i = 0; i < 3; i++) {
+                        net.fit(f, l);
+                    }
+                }
+
+                if (test) {
+                    for (int i = 0; i < 3; i++) {
+                        net.output(f);
+                    }
+                }
+
+                net.close();
+
+                assertTrue(net.params().wasClosed());
+                if(train) {
+                    assertTrue(net.getGradientsViewArray().wasClosed());
+                    Updater u = net.getUpdater(false);
+                    assertTrue(u.getStateViewArray().wasClosed());
+                }
+
+                //Make sure we don't get crashes etc when trying to use after closing
+                try {
+                    net.output(f);
+                } catch (IllegalStateException e) {
+                    String msg = e.getMessage();
+                    assertTrue(msg, msg.contains("released"));
+                }
+
+                try {
+                    net.fit(f, l);
+                } catch (IllegalStateException e) {
+                    String msg = e.getMessage();
+                    assertTrue(msg, msg.contains("released"));
+                }
+            }
+        }
+    }
+
+    @Test
+    public void testCloseCG() {
+        for (boolean train : new boolean[]{false, true}) {
+            for (boolean test : new boolean[]{false, true}) {
+                ComputationGraph net = getTestNet().toComputationGraph();
+
+                INDArray f = Nd4j.rand(DataType.FLOAT, 16, 1, 28, 28);
+                INDArray l = TestUtils.randomOneHot(16, 10);
+
+                if (train) {
+                    for (int i = 0; i < 3; i++) {
+                        net.fit(new INDArray[]{f}, new INDArray[]{l});
+                    }
+                }
+
+                if (test) {
+                    for (int i = 0; i < 3; i++) {
+                        net.output(f);
+                    }
+                }
+
+                net.close();
+
+                assertTrue(net.params().wasClosed());
+                if(train) {
+                    assertTrue(net.getGradientsViewArray().wasClosed());
+                    Updater u = net.getUpdater(false);
+                    assertTrue(u.getStateViewArray().wasClosed());
+                }
+
+                //Make sure we don't get crashes etc when trying to use after closing
+                try {
+                    net.output(f);
+                } catch (IllegalStateException e) {
+                    String msg = e.getMessage();
+                    assertTrue(msg, msg.contains("released"));
+                }
+
+                try {
+                    net.fit(new INDArray[]{f}, new INDArray[]{l});
+                } catch (IllegalStateException e) {
+                    String msg = e.getMessage();
+                    assertTrue(msg, msg.contains("released"));
+                }
+            }
+        }
+    }
+}
@@ -1035,5 +1035,9 @@ public class TestOptimizers extends BaseDL4JTest {
         public boolean updaterDivideByMinibatch(String paramName) {
             return true;
         }
+
+        @Override
+        public void close(){
+        }
     }
 }
@@ -1055,4 +1055,9 @@ public class BarnesHutTsne implements Model {
 
     }
 
+
+    @Override
+    public void close(){
+        //No-op
+    }
 }
@@ -128,8 +128,6 @@ public class KerasInput extends KerasLayer {
                 break;
             case 2:
                 if(this.dimOrder != null) {
-                    System.out.println("Dim order: " + this.dimOrder);
-                    System.out.println("Input shape: " + ArrayUtils.toString(this.inputShape));
                     switch (this.dimOrder) {
                         case TENSORFLOW: //NWC == channels_last
                             myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
@@ -103,7 +103,7 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
             String dtype = inputDataTypes.get(inpName);
             graph = "node{\nname: \"" + inpName + "\"\nop: \"Placeholder\"\nattr{\nkey: \"dtype\"\n value {\n type: " + dtype + "}\n}\n}\n" + graph;
         }
-        log.info(graph);
+        //log.info(graph);
         GraphDef.Builder graphDefBuilder = GraphDef.newBuilder();
         TextFormat.getParser().merge(graph, graphDefBuilder);
         GraphDef graphDef = graphDefBuilder.build();
@@ -63,10 +63,10 @@ public class KerasReLU extends KerasLayer {
         double negativeSlope = 0.0;
         double threshold = 0.0;
         if (innerConfig.containsKey("negative_slope")) {
-            negativeSlope = (double) innerConfig.get("negative_slope");
+            negativeSlope = ((Number)innerConfig.get("negative_slope")).doubleValue();
         }
         if (innerConfig.containsKey("threshold")) {
-            threshold = (double) innerConfig.get("threshold");
+            threshold = ((Number)innerConfig.get("threshold")).doubleValue();
         }
 
         this.layer = new ActivationLayer.Builder().name(this.layerName)
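Note on the KerasReLU change above: when a Keras layer config is deserialized into a Map, a field written as "threshold": 0 typically arrives as an Integer while "threshold": 0.3 arrives as a Double, so the old direct (double) cast could throw ClassCastException for integer-valued configs. A minimal self-contained illustration of the coercion pattern used here (the map contents are hypothetical, not taken from a real model file):

    import java.util.HashMap;
    import java.util.Map;

    public class NumericConfigCoercionExample {
        public static void main(String[] args) {
            Map<String, Object> innerConfig = new HashMap<>();
            innerConfig.put("negative_slope", 0);   // deserialized as Integer
            innerConfig.put("threshold", 0.3);      // deserialized as Double

            // (double) innerConfig.get("negative_slope") would throw ClassCastException here;
            // widening through Number works for both Integer and Double values.
            double negativeSlope = ((Number) innerConfig.get("negative_slope")).doubleValue();
            double threshold = ((Number) innerConfig.get("threshold")).doubleValue();

            System.out.println(negativeSlope + " " + threshold);
        }
    }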
@@ -95,7 +95,6 @@ public class KerasConvolution2D extends KerasConvolution {
         LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
                 layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
 
-        System.out.println("----" + dimOrder);
         ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
                 .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
                 .activation(getIActivationFromConfig(layerConfig, conf))
@@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 

@@ -351,6 +352,10 @@ public class KerasBatchNormalization extends KerasLayer {
     private int getBatchNormAxis(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException {
         Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
-        return (int) innerConfig.get(LAYER_FIELD_AXIS);
+        Object batchNormAxis = innerConfig.get(LAYER_FIELD_AXIS);
+        if (batchNormAxis instanceof List){
+            return ((Number)((List)batchNormAxis).get(0)).intValue();
+        }
+        return ((Number)innerConfig.get(LAYER_FIELD_AXIS)).intValue();
     }
 }
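Note: the "axis" field of a Keras BatchNormalization config can be serialized either as a bare integer or as a one-element list (e.g. "axis": [3]), depending on the Keras version and exporter, so the import now accepts both forms. A small sketch of the same dispatch outside the importer (the helper name and the sample values are illustrative only):

    import java.util.Arrays;
    import java.util.List;

    public class BatchNormAxisExample {
        // Mirrors the branch above: accept either a bare number or a one-element list.
        static int axisOf(Object axisField) {
            if (axisField instanceof List) {
                return ((Number) ((List<?>) axisField).get(0)).intValue();
            }
            return ((Number) axisField).intValue();
        }

        public static void main(String[] args) {
            System.out.println(axisOf(3));                 // plain integer form
            System.out.println(axisOf(Arrays.asList(3)));  // one-element list form, e.g. "axis": [3]
        }
    }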
@@ -0,0 +1,4 @@
+package org.deeplearning4j.nn.modelimport.keras;
+
+public class Temp {
+}
@@ -233,4 +233,7 @@ public interface Model {
      * Apply any constraints to the model
      */
     void applyConstraints(int iteration, int epoch);
+
+
+    void close();
 }
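Note: Model now exposes close() so callers can release native memory deterministically instead of waiting for garbage collection. A minimal usage sketch, assuming a model that has finished training and will not be reused (the helper name is hypothetical, not part of the API):

    import org.deeplearning4j.nn.api.Model;

    public class CloseModelSketch {
        // Hypothetical helper: release a model's native memory exactly once, after its last use.
        public static void runAndRelease(Model model, Runnable trainingLoop) {
            try {
                trainingLoop.run();   // fit / evaluate the model here
            } finally {
                model.close();        // after this, the model must not be used again
            }
        }
    }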
@@ -4824,4 +4824,28 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
         if (cg.getUpdater() != null && cg.getUpdater(false).getStateViewArray() != null)
             this.getUpdater(true).getStateViewArray().assign(cg.getUpdater(false).getStateViewArray());
     }
+
+    /**
+     * Close the network and deallocate all native memory, including: parameters, gradients, updater memory and workspaces
+     * Note that the network should not be used again for any purpose after it has been closed
+     */
+    @Override
+    public void close(){
+        //Close the INDArray and dealloc
+        if(flattenedParams.closeable())
+            flattenedParams.close();
+
+        if(flattenedGradients != null && flattenedGradients.closeable())
+            flattenedGradients.close();
+
+        Updater u = getUpdater(false);
+        if(u != null && u.getStateViewArray() != null) {
+            INDArray state = u.getStateViewArray();
+            if(state.closeable())
+                state.close();
+        }
+
+        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
+        System.gc();
+    }
 }
@@ -428,4 +428,9 @@ public abstract class AbstractLayer<LayerConfT extends org.deeplearning4j.nn.con
         //Majority of params's gradients should be... Exception: batch norm mean/variance estimate
         return true;
     }
+
+    @Override
+    public void close(){
+        //No-op for individual layers
+    }
 }
@@ -599,4 +599,9 @@ public class BidirectionalLayer implements RecurrentLayer {
             return ret;
         }
     }
+
+    @Override
+    public void close(){
+        //No-op for individual layers
+    }
 }
@@ -1144,4 +1144,9 @@ public class VariationalAutoencoder implements Layer {
             }
         }
     }
+
+    @Override
+    public void close(){
+        //No-op for individual layers
+    }
 }
@@ -329,4 +329,9 @@ public abstract class BaseWrapperLayer implements Layer {
     public boolean updaterDivideByMinibatch(String paramName) {
         return underlying.updaterDivideByMinibatch(paramName);
     }
+
+    @Override
+    public void close(){
+        //No-op for individual layers
+    }
 }
@@ -4085,4 +4085,27 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
             this.getUpdater(true).getStateViewArray().assign(mln.getUpdater(false).getStateViewArray());
     }
+
+    /**
+     * Close the network and deallocate all native memory, including: parameters, gradients, updater memory and workspaces
+     * Note that the network should not be used again for any purpose after it has been closed
+     */
+    @Override
+    public void close(){
+        //Close the INDArray and dealloc
+        if(flattenedParams.closeable())
+            flattenedParams.close();
+
+        if(flattenedGradients != null && flattenedGradients.closeable())
+            flattenedGradients.close();
+
+        Updater u = getUpdater(false);
+        if(u != null && u.getStateViewArray() != null) {
+            INDArray state = u.getStateViewArray();
+            if(state.closeable())
+                state.close();
+        }
+
+        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
+        System.gc();
+    }
 }
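Note: both close() implementations repeat the same guard, calling INDArray.close() only when closeable() reports the buffer can actually be released (views and workspace-attached arrays typically cannot). A small sketch expressing that pattern; the helper name is illustrative, not part of the API:

    import org.nd4j.linalg.api.ndarray.INDArray;

    public class CloseGuards {
        // Hypothetical utility mirroring the guard used in MultiLayerNetwork/ComputationGraph.close():
        // release the array's native buffer only if it is non-null and reports itself closeable.
        public static void closeIfCloseable(INDArray array) {
            if (array != null && array.closeable()) {
                array.close();
            }
        }
    }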
@@ -20,6 +20,8 @@ import lombok.*;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.core.storage.StatsStorageRouter;
 import org.deeplearning4j.core.storage.listener.RoutingIterationListener;
+import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler;
+import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
 import org.nd4j.linalg.dataset.AsyncDataSetIterator;;
 import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
 import org.deeplearning4j.datasets.iterator.DummyBlockDataSetIterator;

@@ -688,6 +690,7 @@ public class ParallelWrapper implements AutoCloseable {
         protected Supplier<INDArray> updaterParamsSupplier;
         protected ThresholdAlgorithm thresholdAlgorithm;
         protected ResidualPostProcessor residualPostProcessor;
+        protected Long encoderMemory = -1L;
 
         protected GradientsAccumulator accumulator;
 

@@ -872,6 +875,19 @@ public class ParallelWrapper implements AutoCloseable {
             return this;
         }
 
+        /**
+         * This method allows to define amount of temporary memory that will be used for gradients sharing.
+         * Typically it's safe to keep default value.
+         *
+         * Default value: -1, amount of temporary memory will be calculated automatically
+         * @param numBytes number of bytes to be used
+         * @return
+         */
+        public Builder temporaryMemory(@NonNull Long numBytes) {
+            this.encoderMemory = numBytes;
+            return this;
+        }
+
         /**
          * Set the residual post processor algorithm. Not used for single machine training (only for PW used in a
          * distributed setting), and should not be set by users in most cases.

@@ -907,11 +923,23 @@ public class ParallelWrapper implements AutoCloseable {
                 }
                 break;
             case SHARED_GRADIENTS: {
-                    Preconditions.checkState(thresholdAlgorithm != null, "Cannot use SHARED_GRADIENTS training mode without setting a threshold algorithm");
+                    if (thresholdAlgorithm == null)
+                        thresholdAlgorithm = new AdaptiveThresholdAlgorithm();
+
                     this.trainerContext = new SymmetricTrainerContext();
                     if (this.accumulator == null) {
                         log.info("Creating new GradientsAccumulator instance with default threshold of [5e-4]");
-                        this.accumulator = new EncodedGradientsAccumulator(workers, thresholdAlgorithm, residualPostProcessor, false);
+                        val numParams = model.numParams();
+
+                        // we're limiting max size of updates for Sparse encoding to the size of bitmap encoded message
+                        val maxUpdate = (int) (numParams / 16 + 5);
+
+                        // memory sie in number of bytes
+                        long memorySize = encoderMemory == null || encoderMemory < 0
+                                ? maxUpdate * 4 * (workers + 3)
+                                : encoderMemory;
+
+                        this.accumulator = new EncodedGradientsAccumulator(workers, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, maxUpdate, false), memorySize, workers + 2, Integer.MAX_VALUE, false);
                     }
                 }
                 break;
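Note: a usage sketch for the new temporaryMemory(...) builder option. The model construction is elided and the worker count and 16 MB value are arbitrary example values, not recommendations; passing -1 (the default) keeps automatic sizing.

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.parallelism.ParallelWrapper;

    public class TemporaryMemorySketch {
        // 'model' is any initialized MultiLayerNetwork (or ComputationGraph) built elsewhere.
        public static ParallelWrapper configure(MultiLayerNetwork model) {
            return new ParallelWrapper.Builder<>(model)
                    .workers(4)                              // example value
                    .temporaryMemory(16L * 1024 * 1024)      // 16 MB of encoder memory for gradients sharing
                    .build();
        }
    }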
@@ -450,6 +450,14 @@ public class DefaultTrainer extends Thread implements Trainer {
         } finally {
             log.debug("Terminating all workspaces for trainer_{}", threadId);
             Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
+
+            if (!onRootModel) {
+                replicatedModel.close();
+            }
+
+            // let's try to enforce GC to actually clean all references now
+            replicatedModel.clear();
+            System.gc();
             isStopped.set(true);
         }
     }
@@ -58,7 +58,7 @@ public class ParallelWrapperTest extends BaseDL4JTest {
 
         // for GPU you usually want to have higher batchSize
         int batchSize = 128;
-        int nEpochs = 2;
+        int nEpochs = 5;
         int seed = 123;
 
         log.info("Load data....");
@@ -163,14 +163,13 @@ namespace sd {
            // here we just calculate number of sumBlock arrays
            do {
                int numPrefixBlocks = sd::math::nd4j_max<int>(1, sd::math::nd4j_ceil<float, int>((float) numElts / (2.0f * prefixThreads)));
-               if (numBlocks > 1) {
+               if (numPrefixBlocks > 1) {
                    level++;
                }
                numElts = numPrefixBlocks;
            } while (numElts > 1);
 
-
 
            std::vector<NDArray> tempArrays(level);
            std::vector<Nd4jPointer> pointers(level);

@@ -181,13 +180,13 @@ namespace sd {
                int numPrefixBlocks = sd::math::nd4j_max<int>(1, sd::math::nd4j_ceil<float, int>((float) numElts / (2.0f * prefixThreads)));
                if (numPrefixBlocks > 1) {
                    tempArrays[level] = std::move(NDArrayFactory::create<int>('c', {numPrefixBlocks}));
-                   pointers[level] = tempArrays[level++].specialBuffer();
+                   pointers[level] = tempArrays[level].specialBuffer();;
+                   level++;
                }
                numElts = numPrefixBlocks;
            } while (numElts > 1);
 
            PointersManager pm(LaunchContext::defaultContext(), "thresholdEncode");
-           auto dptr = pm.replicatePointer(pointers.data(), pointers.size() * 8);
            auto offsets = NDArrayFactory::create<int>('c', {numBlocks});
 
            // we want to check, if we're hiting external limit on number of encoded elements

@@ -200,7 +199,7 @@ namespace sd {
            NDArray::prepareSpecialUse({}, {&encoded, &updates});
 
            // filling offsets
-           encodeThresholdP2Int_(reinterpret_cast<void **>(dptr),
+           encodeThresholdP2Int_(reinterpret_cast<void **>(pointers.data()),
                            reinterpret_cast<int*>(blocks.specialBuffer()),
                            numBlocks,
                            reinterpret_cast<int*>(offsets.specialBuffer()));
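Note: the first loop only counts how many intermediate prefix-sum levels are needed (one per pass in which more than one block remains); the old code tested the wrong variable (numBlocks instead of the freshly computed numPrefixBlocks), and in the second loop the combined index/increment is split so the pointer is taken from the current slot before level advances. An illustrative re-expression of the level-counting logic in Java (names and the threads-per-block value are assumptions, not taken from the kernel):

    public class PrefixLevelCountExample {
        // Count reduction levels for a scan over 'numElts' elements with 'prefixThreads' threads per block.
        static int countLevels(long numElts, int prefixThreads) {
            int level = 0;
            do {
                long numPrefixBlocks = Math.max(1L, (long) Math.ceil(numElts / (2.0 * prefixThreads)));
                if (numPrefixBlocks > 1) {   // the fix: test the freshly computed block count
                    level++;
                }
                numElts = numPrefixBlocks;
            } while (numElts > 1);
            return level;
        }

        public static void main(String[] args) {
            // e.g. the array length exercised by the new C++ test below, with an assumed 512 threads
            System.out.println(countLevels(135_079_944L, 512));
        }
    }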
@@ -228,6 +228,53 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) {
     ASSERT_EQ(exp, initial);
 }
 
+TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
+    // [2,1,135079944,1,1,8192,1,99]
+    auto initial = NDArrayFactory::create<float>('c', {1, 135079944});
+    initial = 1.0f;
+    auto exp = initial.dup();
+    auto neg = initial.like();
+    neg = 0.5f;
+
+    sd::ops::encode_threshold enc;
+    auto enc_result = enc.evaluate({&initial}, {0.5f});
+    auto encoded = enc_result.at(1);
+
+    ASSERT_EQ(135079944 + 4, encoded->lengthOf());
+    ASSERT_NE(exp, initial);
+    /*
+    for (int e = 0; e < initial.lengthOf(); e++) {
+        auto f = initial.e<float>(e);
+        if (f != 0.5f) {
+            nd4j_printf("initial[%i] = %f\n", e, f);
+            throw std::runtime_error("");
+        }
+    }
+    */
+    ASSERT_EQ(neg, initial);
+
+    // checking equality of all encoded bits
+    //for (int e = 5; e < encoded->lengthOf() - 1; e++) {
+    //    if (encoded->e<int>(e) != encoded->e<int>(e - 1) + 1)
+    //        nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e<int>(e));
+    //}
+
+    sd::ops::decode_threshold dec;
+    auto status = dec.execute({&initial, encoded}, {&initial});
+    ASSERT_EQ(Status::OK(), status);
+
+    // checking equality of all dedoded bits
+    /*
+    for (int e = 0; e < initial.lengthOf(); e++) {
+        auto f = initial.e<float>(e);
+        if (f != 1.0f)
+            nd4j_printf("initial[%i] = %f\n", e, f);
+    }
+    */
+
+    ASSERT_EQ(exp, initial);
+}
+
+
 TEST_F(DeclarableOpsTests19, test_matmul_ccc) {
     auto x = NDArrayFactory::create<float>('c', {10, 10});
@@ -1957,6 +1957,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
 
     @Override
     public boolean wasClosed() {
+        if (wrappedDataBuffer != null && wrappedDataBuffer != this)
+            return wrappedDataBuffer.wasClosed();
+
         return released;
     }
 
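Note: with this change, a buffer that wraps another buffer (e.g. the buffer behind a view) reports the closed state of the buffer it wraps; the new testClose_2 further down in this commit relies on exactly that. A minimal sketch of the observable behavior, mirroring that test:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class ViewWasClosedSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.create(DataType.FLOAT, 5, 6);
            INDArray row = x.getRow(1);     // a view whose buffer wraps x's buffer

            x.close();                      // release the parent's native buffer

            // The view's buffer delegates wasClosed() to the wrapped parent buffer.
            System.out.println(x.data().wasClosed());    // true
            System.out.println(row.data().wasClosed());  // expected true after this change
        }
    }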
@@ -5521,7 +5521,7 @@ public abstract class BaseNDArray implements INDArray, Iterable {
     public INDArray castTo(DataType dataType) {
         if(dataType == dataType()) //No-op if correct datatype
             return this;
-        if(isEmpty()){
+        if(isEmpty() && rank() == 0){
             return Nd4j.empty(dataType);
         }
         val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering());
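Note: previously every empty array was collapsed to a rank-0 Nd4j.empty(dataType) on cast, losing its shape; now only true rank-0 empties take that path, so an empty array with a concrete shape such as [1, 0, 2] keeps that shape (this is what the new testShape0Casts test later in this commit asserts). A small sketch of the expected behavior:

    import java.util.Arrays;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class EmptyCastSketch {
        public static void main(String[] args) {
            INDArray a = Nd4j.create(DataType.FLOAT, 1, 0, 2);   // empty, but with a rank-3 shape
            INDArray b = a.castTo(DataType.INT32);

            System.out.println(Arrays.toString(b.shape()));      // expected [1, 0, 2] after this change
            System.out.println(b.dataType());                    // INT32
        }
    }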
@@ -41,6 +41,12 @@ public class EncodeThreshold extends DynamicCustomOp {
         this(updates, threshold, Integer.MAX_VALUE);
     }
 
+    public EncodeThreshold(@NonNull INDArray updates, @NonNull INDArray encoded, float threshold, @NonNull Integer boundary) {
+        this(updates, threshold, boundary);
+
+        addOutputArgument(updates, encoded);
+    }
+
     public EncodeThreshold(@NonNull INDArray updates, float threshold, @NonNull Integer boundary) {
         addInputArgument(updates);
 
@@ -692,9 +692,33 @@ public abstract class DefaultOpExecutioner implements OpExecutioner {
         return thresholdEncode(input, threshold, Integer.MAX_VALUE);
     }
 
+    private long _length(long[] shape) {
+        // scalar case
+        if (shape.length == 0)
+            return 1;
+        else if (shape.length == 1)
+            return shape[0];
+        else {
+            long length = 1;
+            for (int e = 0; e < shape.length; e++)
+                length *= shape[e];
+
+            return length;
+        }
+    }
+
     @Override
     public INDArray thresholdEncode(INDArray input, double threshold, Integer boundary) {
-        val result = Nd4j.exec(new EncodeThreshold(input, (float) threshold, boundary))[1];
+        val op_shape = new EncodeThreshold(input, (float) threshold, boundary);
+        val shapes = Nd4j.getExecutioner().calculateOutputShape(op_shape);
+
+        if (_length(shapes.get(1).getShape()) < 2)
+            return null;
+
+        val result = Nd4j.create(DataType.INT32, shapes.get(1).getShape());
+
+        op_shape.addOutputArgument(input, result);
+        Nd4j.exec(op_shape);
+
         return result.getInt(0) > 0 ? result : null;
     }
 
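Note: the reworked method calculates the encoded output shape first, preallocates the INT32 result itself (using the new EncodeThreshold constructor above), and returns null when nothing would be encoded. A caller-side sketch of the executioner API this method implements; the input values and threshold are arbitrary example values:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class ThresholdEncodeSketch {
        public static void main(String[] args) {
            INDArray updates = Nd4j.rand(DataType.FLOAT, 1, 10_000).subi(0.5f);

            // Threshold-encode the update vector; null means no elements were encoded.
            INDArray encoded = Nd4j.getExecutioner().thresholdEncode(updates, 1e-3, Integer.MAX_VALUE);
            if (encoded != null) {
                System.out.println("Encoded length: " + encoded.length());
            } else {
                System.out.println("No updates exceeded the threshold");
            }
        }
    }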
@@ -71,7 +71,13 @@ public class CudaMemoryManager extends BasicMemoryManager {
             return ptr;//allocator.getMemoryHandler().alloc(AllocationStatus.HOST, null, null, initialize).getHostPointer();
         } else if (kind == MemoryKind.DEVICE) {
             val ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().mallocDevice(bytes, 0, 0);
-            //log.info("Allocating {} bytes for device_{}", bytes, Nd4j.getAffinityManager().getDeviceForCurrentThread());
+            log.trace("Allocating {} bytes for device_{}", bytes, Nd4j.getAffinityManager().getDeviceForCurrentThread());
+
+            val ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode();
+            if (ec != 0) {
+                val em = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage();
+                throw new RuntimeException(em + "; Bytes: [" + bytes + "]; Error code [" + ec + "]; DEVICE [" + Nd4j.getAffinityManager().getDeviceForCurrentThread() + "]");
+            }
+
             if (ptr == null)
                 throw new RuntimeException("Failed to allocate " + bytes + " bytes from DEVICE [" + Nd4j.getAffinityManager().getDeviceForCurrentThread() + "] memory");
@@ -191,7 +191,7 @@ public class CudaWorkspace extends Nd4jWorkspace {
                 // spill
                 if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && currentSize.get() > 0 && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
                     //log.info("End of space reached. Current offset: {}; requiredMemory: {}", deviceOffset.get(), requiredMemory);
-                    reset();
+                    deviceOffset.set(0);
                     resetPlanned.set(true);
                     return alloc(requiredMemory, kind, type, initialize);
                 }

@@ -204,7 +204,6 @@ public class CudaWorkspace extends Nd4jWorkspace {
                 if (isDebug.get()) {
                     log.info("Workspace [{}] device_{}: spilled DEVICE array of {} bytes, capacity of {} elements", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory, numElements);
                 }
-                //Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
 
                 val shape = new AllocationShape(requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type);
 

@@ -258,6 +257,12 @@ public class CudaWorkspace extends Nd4jWorkspace {
                 return ptr;
             } else {
                 // log.info("Spilled HOST array of {} bytes, capacity of {} elements", requiredMemory, numElements);
+                if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && currentSize.get() > 0 && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
+                    //log.info("End of space reached. Current offset: {}; requiredMemory: {}", deviceOffset.get(), requiredMemory);
+                    hostOffset.set(0);
+                    //resetPlanned.set(true);
+                    return alloc(requiredMemory, kind, type, initialize);
+                }
 
                 val shape = new AllocationShape(requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type);
 
@@ -85,6 +85,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
     }
 
     public OpaqueDataBuffer getOpaqueDataBuffer() {
+        if (released)
+            throw new IllegalStateException("You can't use DataBuffer once it was released");
+
         return ptrDataBuffer;
     }
 

@@ -104,7 +107,8 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, this.type, pointer, specialPointer);
         this.allocationPoint = new AllocationPoint(ptrDataBuffer, this.type.width() * length);
 
-        Nd4j.getDeallocatorService().pickObject(this);
+        Nd4j.getDeallocatorService().pickObject(this);if (released)
+            throw new IllegalStateException("You can't use DataBuffer once it was released");
     }
 
     /**

@@ -473,6 +477,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
     }
 
     public BaseCudaDataBuffer(@NonNull DataBuffer underlyingBuffer, long length, long offset) {
+        if (underlyingBuffer.wasClosed())
+            throw new IllegalStateException("You can't use DataBuffer once it was released");
+
         //this(length, underlyingBuffer.getElementSize(), offset);
         this.allocationMode = AllocationMode.MIXED_DATA_TYPES;
         initTypeAndSize();

@@ -1630,7 +1637,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
             setIndexer(ShortIndexer.create((ShortPointer) pointer));
         } else if (t == DataType.UINT32) {
             pointer = new PagedPointer(cptr, length).asIntPointer();
-            setIndexer(IntIndexer.create((IntPointer) pointer));
+            setIndexer(UIntIndexer.create((IntPointer) pointer));
         } else if (t == DataType.INT) {
             pointer = new PagedPointer(cptr, length).asIntPointer();
             setIndexer(IntIndexer.create((IntPointer) pointer));

@@ -1699,6 +1706,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
                 indexer = ShortIndexer.create((ShortPointer) pointer);
                 break;
             case UINT32:
+                pointer = nPtr.asIntPointer();
+                indexer = UIntIndexer.create((IntPointer) pointer);
+                break;
             case INT:
                 pointer = nPtr.asIntPointer();
                 indexer = IntIndexer.create((IntPointer) pointer);

@@ -1750,6 +1760,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
                 indexer = ShortIndexer.create((ShortPointer) pointer);
                 break;
             case UINT32:
+                pointer = nPtr.asIntPointer();
+                indexer = UIntIndexer.create((IntPointer) pointer);
+                break;
             case INT:
                 pointer = nPtr.asIntPointer();
                 indexer = IntIndexer.create((IntPointer) pointer);
@@ -0,0 +1,105 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.k.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.jita.workspace;
+
+import lombok.val;
+import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.concurrency.AffinityManager;
+import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
+import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
+import org.nd4j.linalg.api.memory.enums.LearningPolicy;
+import org.nd4j.linalg.api.memory.enums.ResetPolicy;
+import org.nd4j.linalg.api.memory.enums.SpillPolicy;
+import org.nd4j.linalg.factory.Nd4j;
+
+import static org.junit.Assert.*;
+
+public class CudaWorkspaceTest {
+
+    @Test
+    public void testCircularWorkspaceAsymmetry_1() {
+        // circular workspace mode
+        val configuration = WorkspaceConfiguration.builder().initialSize(10 * 1024 * 1024)
+                .policyReset(ResetPolicy.ENDOFBUFFER_REACHED).policyAllocation(AllocationPolicy.STRICT)
+                .policySpill(SpillPolicy.FAIL).policyLearning(LearningPolicy.NONE).build();
+
+
+        try (val ws = (CudaWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "circular_ws")) {
+            val array = Nd4j.create(DataType.FLOAT, 10, 10);
+
+            assertEquals(0, ws.getHostOffset());
+            assertNotEquals(0, ws.getDeviceOffset());
+
+            // we expect that this array has no data/buffer on HOST side
+            assertEquals(AffinityManager.Location.DEVICE, Nd4j.getAffinityManager().getActiveLocation(array));
+
+            // since this array doesn't have HOST buffer - it will allocate one now
+            array.getDouble(3L);
+
+            assertEquals(ws.getHostOffset(), ws.getDeviceOffset());
+        }
+
+        try (val ws = (CudaWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "circular_ws")) {
+            assertEquals(ws.getHostOffset(), ws.getDeviceOffset());
+        }
+
+        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
+    }
+
+    @Test
+    public void testCircularWorkspaceAsymmetry_2() {
+        // circular workspace mode
+        val configuration = WorkspaceConfiguration.builder().initialSize(10 * 1024 * 1024)
+                .policyReset(ResetPolicy.ENDOFBUFFER_REACHED).policyAllocation(AllocationPolicy.STRICT)
+                .policySpill(SpillPolicy.FAIL).policyLearning(LearningPolicy.NONE).build();
+
+        val root = Nd4j.create(DataType.FLOAT, 1000000).assign(119);
+
+        for (int e = 0; e < 100; e++) {
+            try (val ws = (CudaWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "circular_ws")) {
+                val array = Nd4j.create(DataType.FLOAT, root.shape());
+                array.assign(root);
+
+                array.data().getInt(3);
+
+                assertEquals(ws.getHostOffset(), ws.getDeviceOffset());
+            }
+        }
+    }
+
+    @Test
+    public void testCircularWorkspaceAsymmetry_3() {
+        // circular workspace mode
+        val configuration = WorkspaceConfiguration.builder().initialSize(10 * 1024 * 1024)
+                .policyReset(ResetPolicy.ENDOFBUFFER_REACHED).policyAllocation(AllocationPolicy.STRICT)
+                .policySpill(SpillPolicy.FAIL).policyLearning(LearningPolicy.NONE).build();
+
+        val root = Nd4j.create(DataType.FLOAT, 1000000).assign(119);
+
+        for (int e = 0; e < 100; e++) {
+            try (val ws = (CudaWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "circular_ws")) {
+                val array = Nd4j.create(DataType.FLOAT, root.shape());
+                array.assign(root);
+
+                val second = Nd4j.create(DataType.FLOAT, root.shape());
+
+                array.data().getInt(3);
+            }
+        }
+    }
+}
@@ -198,4 +198,27 @@ public class BaseCudaDataBufferTest extends BaseND4JTest {
         // there shoul dbe no exceptions during execution
         assertEquals(Nd4j.getAffinityManager().getNumberOfDevices(), cnt.get());
     }
+
+    @Test
+    public void testClose_1() {
+        val x = Nd4j.createFromArray(1, 2, 3);
+
+        x.close();
+
+        assertTrue(x.wasClosed());
+        assertTrue(x.data().wasClosed());
+    }
+
+    @Test
+    public void testClose_2() {
+        val x = Nd4j.create(DataType.FLOAT, 5, 6);
+        val row = x.getRow(1);
+        x.close();
+
+        assertTrue(x.wasClosed());
+        assertTrue(x.data().wasClosed());
+
+        assertTrue(row.wasClosed());
+        assertTrue(row.data().wasClosed());
+    }
 }
@@ -61,6 +61,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
     }
 
     public OpaqueDataBuffer getOpaqueDataBuffer() {
+        if (released)
+            throw new IllegalStateException("You can't use DataBuffer once it was released");
+
         return ptrDataBuffer;
     }
 

@@ -411,7 +414,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
             setIndexer(ShortIndexer.create((ShortPointer) pointer));
         } else if (t == DataType.UINT32) {
             pointer = new PagedPointer(cptr, length).asIntPointer();
-            setIndexer(IntIndexer.create((IntPointer) pointer));
+            setIndexer(UIntIndexer.create((IntPointer) pointer));
         } else if (t == DataType.INT) {
             pointer = new PagedPointer(cptr, length).asIntPointer();
             setIndexer(IntIndexer.create((IntPointer) pointer));

@@ -514,7 +517,6 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
             attached = true;
             parentWorkspace = workspace;
 
-            // FIXME: need unsigned indexer here
             pointer = workspace.alloc(length * getElementSize(), dataType(), initialize).asIntPointer(); //new IntPointer(length());
             setIndexer(UIntIndexer.create((IntPointer) pointer));
 

@@ -882,6 +884,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
                 indexer = ShortIndexer.create((ShortPointer) pointer);
                 break;
             case UINT32:
+                pointer = nPtr.asIntPointer();
+                indexer = UIntIndexer.create((IntPointer) pointer);
+                break;
             case INT:
                 pointer = nPtr.asIntPointer();
                 indexer = IntIndexer.create((IntPointer) pointer);

@@ -932,6 +937,9 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
                 indexer = ShortIndexer.create((ShortPointer) pointer);
                 break;
             case UINT32:
+                pointer = nPtr.asIntPointer();
+                indexer = UIntIndexer.create((IntPointer) pointer);
+                break;
             case INT:
                 pointer = nPtr.asIntPointer();
                 indexer = IntIndexer.create((IntPointer) pointer);
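Note: the UINT32 cases above switch from the signed IntIndexer to UIntIndexer because the backing storage is 32-bit ints while the logical values are unsigned; a signed indexer would hand back negative numbers for any value at or above 2^31. A plain-Java illustration of the zero-extension an unsigned indexer has to perform (no ND4J types involved):

    public class UnsignedWideningExample {
        public static void main(String[] args) {
            int raw = (int) 3_000_000_000L;         // bit pattern of an unsigned value >= 2^31

            long asSigned = raw;                    // sign-extended: -1294967296
            long asUnsigned = raw & 0xFFFFFFFFL;    // zero-extended: 3000000000

            System.out.println(asSigned);
            System.out.println(asUnsigned);
        }
    }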
@@ -16619,6 +16619,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
         private native void allocate();
         public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
     }
+    @Namespace("sd::ops") public static class clipbyavgnorm_bp extends DeclarableCustomOp {
+        static { Loader.load(); }
+        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
+        public clipbyavgnorm_bp(Pointer p) { super(p); }
+        /** Native array allocator. Access with {@link Pointer#position(long)}. */
+        public clipbyavgnorm_bp(long size) { super((Pointer)null); allocateArray(size); }
+        private native void allocateArray(long size);
+        @Override public clipbyavgnorm_bp position(long position) {
+            return (clipbyavgnorm_bp)super.position(position);
+        }
+
+        public clipbyavgnorm_bp() { super((Pointer)null); allocate(); }
+        private native void allocate();
+        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
+    }
 // #endif
 
 // #if NOT_EXCLUDED(OP_cumsum)
@@ -123,8 +123,61 @@ public class TFGraphTestAllSameDiff {   //Note: Can't extend BaseNd4jTest here a
             //AB 2020/01/07 - Known issues
             "bitcast/from_float64_to_int64",
             "bitcast/from_rank2_float64_to_int64",
-            "bitcast/from_float64_to_uint64"
-    };
+            "bitcast/from_float64_to_uint64",
+
+            //NEWLY ADDED TESTCASES from 27/04/2020
+            "non_max_suppression_v2/.*", "non_max_suppression/.*",
+            "random_gamma/.*",
+            "non_max_suppression_v5/.*",
+            "non_max_suppression_v4/.*",
+            "non_max_suppression_v3/.*",
+            "dropout/.*",
+            "max_pool_with_argmax/.*",
+            "conv2d_transpose/.*",
+            "Conv3DBackpropInputV2/.*",
+            "Conv3DBackpropInput/.*",
+            "mod/.*",
+            "leaky_relu/.*",
+            "DeepCopy/.*",
+            "empty/.*",
+            "ones_like/.*",
+            "is_non_decreasing/.*",
+            "div/.*",
+            "lgamma/.*",
+            "random_uniform/.*",
+            "random_uniform_int/.*",
+            "resize_area/.*",
+            "zeros_like_tf1/.*",
+            "Conv2DTranspose/.*",
+            "rgb_to_yuv/.*",
+            "rgb_to_grayscale/.*",
+            "rgb_to_yiq/.*",
+            "losses/.*",
+            "yiq_to_rgb/.*",
+            "yuv_to_rgb/.*",
+            "emptyArrayTests/.*",
+            "random_normal/.*",
+            "random_shuffle/.*",
+            "random_poisson_v2/.*",
+            "random_poisson/.*",
+            "random_crop/.*",
+            "compare_and_bitpack/.*",
+            "adjust_contrast/.*",
+            "confusion/.*",
+            "bitcast/.*",
+            "roll/.*",
+            "matrix_band_part/.*",
+            "conv3d_transpose_layers/.*",
+            "multinomial/.*",
+            "unsorted_segment/.*",
+            "cnn2d_nn/.*",
+            "truncatemod/.*",
+            "bincount/.*",
+            "slogdet/.*",
+            "adjust_contrast_v2/.*"
+
+    };
 
     /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
        all arrays printed during execution.
@@ -8411,6 +8411,76 @@ public class Nd4jTestsC extends BaseNd4jTest {
             INDArray arr = Nd4j.create(db, new long[]{lengthElements});
 
             arr.toStringFull();
+            arr.toString();
+
+            for(DataType dt2 : DataType.values()) {
+                if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN)
+                    continue;
+                INDArray a2 = arr.castTo(dt2);
+                a2.toStringFull();
+            }
+        }
+    }
+
+    @Test
+    public void testCreateBufferFromByteBufferViews(){
+
+        for(DataType dt : DataType.values()){
+            if(dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
+                continue;
+            // System.out.println(dt);
+
+            int lengthBytes = 256;
+            int lengthElements = lengthBytes / dt.width();
+            ByteBuffer bb = ByteBuffer.allocateDirect(lengthBytes);
+
+            DataBuffer db = Nd4j.createBuffer(bb, dt, lengthElements, 0);
+            INDArray arr = Nd4j.create(db, new long[]{lengthElements/2, 2});
+
+            arr.toStringFull();
+
+            INDArray view = arr.get(NDArrayIndex.all(), NDArrayIndex.point(0));
+            INDArray view2 = arr.get(NDArrayIndex.point(1), NDArrayIndex.all());
+
+            view.toStringFull();
+            view2.toStringFull();
+        }
+    }
+
+    @Test
+    public void testTypeCastingToString(){
+
+        for(DataType dt : DataType.values()) {
+            if (dt == DataType.COMPRESSED || dt == DataType.UTF8 || dt == DataType.UNKNOWN)
+                continue;
+            INDArray a1 = Nd4j.create(dt, 10);
+            for(DataType dt2 : DataType.values()) {
+                if (dt2 == DataType.COMPRESSED || dt2 == DataType.UTF8 || dt2 == DataType.UNKNOWN)
+                    continue;
+
+                INDArray a2 = a1.castTo(dt2);
+                a2.toStringFull();
+            }
+        }
+    }
+
+
+    @Test
+    public void testShape0Casts(){
+        for(DataType dt : DataType.values()){
+            if(!dt.isNumerical())
+                continue;
+
+            INDArray a1 = Nd4j.create(dt, 1,0,2);
+
+            for(DataType dt2 : DataType.values()){
+                if(!dt2.isNumerical())
+                    continue;
+                INDArray a2 = a1.castTo(dt2);
+
+                assertArrayEquals(a1.shape(), a2.shape());
+                assertEquals(dt2, a2.dataType());
+            }
         }
     }
 
 
@@ -26,6 +26,7 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.concurrency.AffinityManager;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
 import org.nd4j.linalg.api.memory.enums.*;

@@ -1219,6 +1220,30 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
         Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
     }
 
+    @Test
+    public void testCircularWorkspaceAsymmetry_1() {
+        // nothing to test on CPU here
+        if (Nd4j.getEnvironment().isCPU())
+            return;
+
+        // circular workspace mode
+        val configuration = WorkspaceConfiguration.builder().initialSize(10 * 1024 * 1024)
+                .policyReset(ResetPolicy.ENDOFBUFFER_REACHED).policyAllocation(AllocationPolicy.STRICT)
+                .policySpill(SpillPolicy.FAIL).policyLearning(LearningPolicy.NONE).build();
+
+
+        try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "circular_ws")) {
+            val array = Nd4j.create(DataType.FLOAT, 10, 10);
+
+            // we expect that this array has no data/buffer on HOST side
+            assertEquals(AffinityManager.Location.DEVICE, Nd4j.getAffinityManager().getActiveLocation(array));
+
+            // since this array doesn't have HOST buffer - it will allocate one now
+            array.getDouble(3L);
+        }
+
+        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
+    }
+
     @Override
     public char ordering() {