Few fixes (#361)

* INDArray.close() fix for CPU

Signed-off-by: raver119 <raver119@gmail.com>

* - BroadcastableBoolOp introduced
- ConfusionMatrix now supports explicit DataType argument

Signed-off-by: raver119 <raver119@gmail.com>

* confusion_matrix: dtype is still optional

Signed-off-by: raver119 <raver119@gmail.com>

* disable bert tests in debug builds

Signed-off-by: raver119 <raver119@gmail.com>

* Affinity fix

Signed-off-by: raver119 <raver119@gmail.com>

* minor workspace tweak to allow close() on scoped out borrowed workspace

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-04-06 21:01:59 +03:00 committed by GitHub
parent 986ec4b51a
commit 04b2b4f9b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 430 additions and 52 deletions

View File

@ -0,0 +1,43 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver on 6/6/2018.
//
#ifndef SD_BROADCASTABLEBOOLOP_H
#define SD_BROADCASTABLEBOOLOP_H
#include <graph/Context.h>
#include "OpDescriptor.h"
#include "DeclarableOp.h"
#include "DeclarableCustomOp.h"
namespace sd {
namespace ops {
/**
 * Base class for broadcastable operations that always produce a BOOL
 * output regardless of input data types — i.e. the comparison ops
 * (equals, not_equals, less, less_equal, greater, greater_equal).
 * Concrete ops are declared/implemented via the
 * DECLARE_BROADCASTABLE_BOOL_OP / BROADCASTABLE_BOOL_OP_IMPL macros.
 */
class ND4J_EXPORT BroadcastableBoolOp : public DeclarableCustomOp{
protected:
// per-op comparison logic, supplied by each concrete subclass
Nd4jStatus validateAndExecute(Context& block) override = 0;
public:
// name: op name; numTArgs/numIArgs: expected T/I argument counts.
// The ctor definition fixes the op to 2 inputs, 1 output, non-inplace.
BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs);

// computes the broadcast output shape, forcing BOOL as output dtype
ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override;
};
}
}
#endif //SD_BROADCASTABLEBOOLOP_H

View File

@ -23,7 +23,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(equals, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(equals, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -23,7 +23,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(greater, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(greater, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -22,7 +22,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(greater_equal, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(greater_equal, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -22,7 +22,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(less, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(less, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -22,7 +22,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(less_equal, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(less_equal, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -22,7 +22,7 @@
namespace sd { namespace sd {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(not_equals, 0, 0) { BROADCASTABLE_BOOL_OP_IMPL(not_equals, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);

View File

@ -45,8 +45,8 @@ namespace sd {
weights = INPUT_VARIABLE(2); weights = INPUT_VARIABLE(2);
REQUIRE_TRUE(weights->isSameShape(predictions),0, "CONFUSION_MATRIX: Weights and predictions should have equal shape"); REQUIRE_TRUE(weights->isSameShape(predictions),0, "CONFUSION_MATRIX: Weights and predictions should have equal shape");
} }
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_NULLIFIED(0);
output->assign(0.);
int minPrediction = predictions->reduceNumber(reduce::Min).e<int>(0); int minPrediction = predictions->reduceNumber(reduce::Min).e<int>(0);
int minLabel = labels->reduceNumber(reduce::Min).e<int>(0); int minLabel = labels->reduceNumber(reduce::Min).e<int>(0);
@ -64,11 +64,7 @@ namespace sd {
DECLARE_SHAPE_FN(confusion_matrix) { DECLARE_SHAPE_FN(confusion_matrix) {
auto labels = INPUT_VARIABLE(0); auto labels = INPUT_VARIABLE(0);
auto predictions = INPUT_VARIABLE(1); auto predictions = INPUT_VARIABLE(1);
auto dtype = block.dataType(); auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
dtype = sd::DataType::INT64; // dtype - should be a param with int argument
if (block.numI() > 1)
dtype = (sd::DataType)INT_ARG(1);
int numClasses = 0; int numClasses = 0;
if (block.getIArguments()->size() > 0) { if (block.getIArguments()->size() > 0) {

View File

@ -22,6 +22,7 @@
#define LIBND4J_HEADERS_BROADCASTABLE_H #define LIBND4J_HEADERS_BROADCASTABLE_H
#include <ops/declarable/BroadcastableOp.h> #include <ops/declarable/BroadcastableOp.h>
#include <ops/declarable/BroadcastableBoolOp.h>
#include <ops/declarable/headers/common.h> #include <ops/declarable/headers/common.h>
#include <ops/declarable/generic/helpers/BroadcastHelper.h> #include <ops/declarable/generic/helpers/BroadcastHelper.h>
@ -261,7 +262,7 @@ namespace sd {
* *
*/ */
#if NOT_EXCLUDED(OP_equals) #if NOT_EXCLUDED(OP_equals)
DECLARE_BROADCASTABLE_OP(equals, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(equals, 0, 0);
#endif #endif
/** /**
@ -269,7 +270,7 @@ namespace sd {
* Math is: _x != _y ? (T) 1.0f : (T) 0.0f; * Math is: _x != _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_not_equals) #if NOT_EXCLUDED(OP_not_equals)
DECLARE_BROADCASTABLE_OP(not_equals, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(not_equals, 0, 0);
#endif #endif
/** /**
@ -277,7 +278,7 @@ namespace sd {
* Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_less_equal) #if NOT_EXCLUDED(OP_less_equal)
DECLARE_BROADCASTABLE_OP(less_equal, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(less_equal, 0, 0);
#endif #endif
/** /**
@ -285,7 +286,7 @@ namespace sd {
* Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_greater_equal) #if NOT_EXCLUDED(OP_greater_equal)
DECLARE_BROADCASTABLE_OP(greater_equal, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(greater_equal, 0, 0);
#endif #endif
/** /**
@ -293,7 +294,7 @@ namespace sd {
* Math is: _x < _y ? (T) 1.0f : (T) 0.0f; * Math is: _x < _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_less) #if NOT_EXCLUDED(OP_less)
DECLARE_BROADCASTABLE_OP(less, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(less, 0, 0);
#endif #endif
/** /**
@ -301,7 +302,7 @@ namespace sd {
* Math is: _x > _y ? (T) 1.0f : (T) 0.0f; * Math is: _x > _y ? (T) 1.0f : (T) 0.0f;
*/ */
#if NOT_EXCLUDED(OP_greater) #if NOT_EXCLUDED(OP_greater)
DECLARE_BROADCASTABLE_OP(greater, 0, 0); DECLARE_BROADCASTABLE_BOOL_OP(greater, 0, 0);
#endif #endif
/** /**

View File

@ -0,0 +1,72 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver on 6/6/2018.
//
#include <system/op_boilerplate.h>
#include <system/pointercast.h>
#include <ops/declarable/BroadcastableBoolOp.h>
#include <helpers/ShapeUtils.h>
namespace sd {
namespace ops {
/**
 * Base ctor for BOOL-output broadcastable ops: always 2 inputs,
 * 1 output, never inplace; T/I argument counts come from the subclass.
 */
BroadcastableBoolOp::BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs) : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, numIArgs) {
    // no state of our own beyond the base class
}

/**
 * Computes the output shape for a pair of broadcastable inputs.
 * The result dtype is always BOOL; the shape itself follows the usual
 * broadcast rules (empty edge cases, scalars, equal shapes, generic
 * broadcast), checked in exactly this order.
 */
ShapeList *BroadcastableBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) {
    auto result = SHAPELIST();
    auto xShape = inputShape->at(0);
    auto yShape = inputShape->at(1);
    sd::DataType outType = sd::DataType::BOOL;

    // helper: materialize a descriptor through the shape cache and record it
    auto push = [&result](const ShapeDescriptor &descriptor) {
        result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor));
    };

    if (shape::isEmpty(xShape) || shape::isEmpty(yShape)) {
        // edge case: [3, 4] + [] = []
        if ((shape::isEmpty(xShape) && shape::rank(xShape) == 0) || (shape::isEmpty(yShape) && shape::rank(yShape) == 0)) {
            push(ShapeDescriptor::emptyDescriptor(outType));
            return result;
        }

        Nd4jLong *broadcasted = nullptr;
        ShapeUtils::evalBroadcastShapeInfo(xShape, yShape, true, broadcasted, block.workspace());
        push(ShapeDescriptor(broadcasted, outType));
        return result;
    }

    if (shape::isScalar(xShape) && shape::isScalar(yShape)) {
        // two scalars: keep the higher-ranked shape info
        push(ShapeDescriptor(shape::rank(xShape) >= shape::rank(yShape) ? xShape : yShape, outType));
        return result;
    }

    if (shape::equalsSoft(xShape, yShape)) {
        push(ShapeDescriptor(xShape, outType));
        return result;
    }

    if (shape::isScalar(xShape) && !shape::isScalar(yShape)) {
        // scalar op array: output follows the array shape
        push(ShapeDescriptor(yShape, outType));
        return result;
    }

    if (!shape::isScalar(xShape) && shape::isScalar(yShape)) {
        push(ShapeDescriptor(xShape, outType));
        return result;
    }

    if (ShapeUtils::areShapesBroadcastable(xShape, yShape)) {
        Nd4jLong *broadcasted = nullptr;
        ShapeUtils::evalBroadcastShapeInfo(xShape, yShape, true, broadcasted, block.workspace());
        push(ShapeDescriptor(broadcasted, outType));
        return result;
    }

    // incompatible shapes: return an x-shaped placeholder; the actual
    // exception is thrown later at execution time
    push(ShapeDescriptor(xShape, outType));
    return result;
}
}
}

View File

@ -1446,10 +1446,24 @@
};\ };\
REGISTER_H(NAME) REGISTER_H(NAME)
#define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp { \
protected: \
void registerTypes(); \
Nd4jStatus validateAndExecute(Context& block); \
public:\
NAME(); \
};\
REGISTER_H(NAME)
#define BROADCASTABLE_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableOp(#NAME, TARGS, IARGS) { }; \ #define BROADCASTABLE_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableOp(#NAME, TARGS, IARGS) { }; \
REGISTER_C(NAME) \ REGISTER_C(NAME) \
Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block)
#define BROADCASTABLE_BOOL_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableBoolOp(#NAME, TARGS, IARGS) { }; \
REGISTER_C(NAME) \
Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block)
#define DECLARE_DEVICE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) #define DECLARE_DEVICE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS)

View File

@ -515,7 +515,7 @@ TEST_F(DeclarableOpsTests12, TestConfusionZero_1) {
//exp1.assign(1.); //exp1.assign(1.);
//exp2.assign(-2.); //exp2.assign(-2.);
sd::ops::confusion_matrix op; sd::ops::confusion_matrix op;
Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}); Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(output.equalsTo(exp)); ASSERT_TRUE(output.equalsTo(exp));

View File

@ -2374,7 +2374,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test4) {
auto expected = NDArrayFactory::create<double>('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); auto expected = NDArrayFactory::create<double>('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200});
sd::ops::confusion_matrix op; sd::ops::confusion_matrix op;
auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3, sd::DataType::DOUBLE}); auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}, {}, {sd::DataType::DOUBLE});
auto output = results.at(0); auto output = results.at(0);

View File

@ -91,6 +91,8 @@ TEST_F(PlaygroundTests, test_biasAdd_1) {
TEST_F(PlaygroundTests, test_bert_full_1) { TEST_F(PlaygroundTests, test_bert_full_1) {
#ifdef _RELEASE
// this test will run ONLY if this model exists // this test will run ONLY if this model exists
if (sd::graph::getFileSize("/home/raver119/Downloads/BertFull/model.fb") < 0) if (sd::graph::getFileSize("/home/raver119/Downloads/BertFull/model.fb") < 0)
return; return;
@ -147,10 +149,12 @@ TEST_F(PlaygroundTests, test_bert_full_1) {
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
*/ */
delete graph; delete graph;
#endif
} }
TEST_F(PlaygroundTests, test_bert_1) { TEST_F(PlaygroundTests, test_bert_1) {
#ifdef _RELEASE
// this test will run ONLY if this model exists // this test will run ONLY if this model exists
if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0) if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0)
return; return;
@ -206,9 +210,11 @@ TEST_F(PlaygroundTests, test_bert_1) {
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
*/ */
delete graph; delete graph;
#endif
} }
TEST_F(PlaygroundTests, test_bert_2) { TEST_F(PlaygroundTests, test_bert_2) {
#ifdef _RELEASE
// this test will run ONLY if this model exists // this test will run ONLY if this model exists
if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb") < 0) if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb") < 0)
return; return;
@ -256,6 +262,7 @@ TEST_F(PlaygroundTests, test_bert_2) {
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
*/ */
delete graph; delete graph;
#endif
} }

View File

@ -1930,7 +1930,6 @@ public abstract class BaseDataBuffer implements DataBuffer {
protected void release() { protected void release() {
this.released = true; this.released = true;
this.pointer.deallocate();
this.indexer = null; this.indexer = null;
this.pointer = null; this.pointer = null;
} }

View File

@ -580,6 +580,13 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace {
public void close() { public void close() {
// first we check if this workspace was borrowed. if yes - just close without reset. // first we check if this workspace was borrowed. if yes - just close without reset.
if (isBorrowed.get()) { if (isBorrowed.get()) {
if (tagScope.get() > 0) {
if (tagScope.decrementAndGet() == 0) {
Nd4j.getMemoryManager().setCurrentWorkspace(this);
}
return;
}
isBorrowed.set(false); isBorrowed.set(false);
Nd4j.getMemoryManager().setCurrentWorkspace(borrowingWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(borrowingWorkspace);
return; return;

View File

@ -42,6 +42,7 @@ public class ConfusionMatrix extends DynamicCustomOp {
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, @NonNull DataType dataType){ public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, @NonNull DataType dataType){
super(new INDArray[]{labels, predicted}, null); super(new INDArray[]{labels, predicted}, null);
this.outputType = dataType; this.outputType = dataType;
addDArgument(dataType);
} }
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, int numClasses){ public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, int numClasses){
@ -66,6 +67,7 @@ public class ConfusionMatrix extends DynamicCustomOp {
if(numClasses != null) { if(numClasses != null) {
addIArgument(numClasses); addIArgument(numClasses);
} }
addDArgument(dataType);
} }
@ -77,6 +79,7 @@ public class ConfusionMatrix extends DynamicCustomOp {
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){ public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){
super(null, sameDiff, new SDVariable[]{labels, pred}); super(null, sameDiff, new SDVariable[]{labels, pred});
this.outputType = dataType; this.outputType = dataType;
addDArgument(dataType);
} }
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights){ public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights){

View File

@ -209,4 +209,11 @@ public class OpaqueDataBuffer extends Pointer {
public void syncToPrimary() { public void syncToPrimary() {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(this); NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(this);
} }
/**
 * Releases the underlying native buffer by delegating to the native
 * {@code dbClose()} call. The buffer must not be accessed afterwards.
 */
public void closeBuffer() {
NativeOpsHolder.getInstance().getDeviceNativeOps().dbClose(this);
}
} }

View File

@ -72,9 +72,16 @@ public class CudaAffinityManager extends BasicAffinityManager {
*/ */
@Override @Override
public Integer getDeviceForThread(long threadId) { public Integer getDeviceForThread(long threadId) {
val id = affinityMap.get(threadId); Integer id = affinityMap.get(threadId);
if (id == null) if (id == null) {
throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet"); // if this is current thread - we're still able to fetch id from native side, and update map
if (threadId == Thread.currentThread().getId()) {
id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
affinityMap.put(Long.valueOf(threadId), id);
} else
// TODO: we should get rid of this method, and forbid such kind of queries
throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet");
}
return id; return id;
} }

View File

@ -1792,11 +1792,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override @Override
protected void release() { protected void release() {
if (!released) { if (!released) {
//AtomicAllocator.getInstance().freeMemory(allocationPoint);n ptrDataBuffer.closeBuffer();
NativeOpsHolder.getInstance().getDeviceNativeOps().dbClose(allocationPoint.getPtrDataBuffer());
allocationPoint.setReleased(true); allocationPoint.setReleased(true);
} }
released = true;
super.release();
} }
/* /*

View File

@ -9553,6 +9553,50 @@ public static final int PREALLOC_SIZE = 33554432;
// #endif //LIBND4J_BROADCASTABLEOP_H // #endif //LIBND4J_BROADCASTABLEOP_H
// Parsed from ops/declarable/BroadcastableBoolOp.h
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver on 6/6/2018.
//
// #ifndef SD_BROADCASTABLEBOOLOP_H
// #define SD_BROADCASTABLEBOOLOP_H
// #include <graph/Context.h>
// #include "OpDescriptor.h"
// #include "DeclarableOp.h"
// #include "DeclarableCustomOp.h"
@Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public BroadcastableBoolOp(Pointer p) { super(p); }
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif //SD_BROADCASTABLEBOOLOP_H
// Parsed from helpers/OpArgsHolder.h // Parsed from helpers/OpArgsHolder.h
/******************************************************************************* /*******************************************************************************

View File

@ -77,6 +77,7 @@ import org.bytedeco.javacpp.tools.InfoMapper;
"ops/declarable/OpDescriptor.h", "ops/declarable/OpDescriptor.h",
"ops/declarable/PlatformHelper.h", "ops/declarable/PlatformHelper.h",
"ops/declarable/BroadcastableOp.h", "ops/declarable/BroadcastableOp.h",
"ops/declarable/BroadcastableBoolOp.h",
"helpers/OpArgsHolder.h", "helpers/OpArgsHolder.h",
"ops/declarable/DeclarableOp.h", "ops/declarable/DeclarableOp.h",
"ops/declarable/DeclarableListOp.h", "ops/declarable/DeclarableListOp.h",

View File

@ -837,6 +837,12 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
this(data, true, workspace); this(data, true, workspace);
} }
@Override
protected void release() {
// close the native-side buffer first, then let the base class
// clear its own bookkeeping (released flag, pointer, indexer)
ptrDataBuffer.closeBuffer();
super.release();
}
/** /**
* Reallocate the native memory of the buffer * Reallocate the native memory of the buffer
* @param length the new length of the buffer * @param length the new length of the buffer

View File

@ -6,6 +6,7 @@ import java.nio.*;
import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*; import org.bytedeco.javacpp.annotation.*;
import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*; import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*; import static org.bytedeco.openblas.global.openblas.*;
@ -11406,10 +11407,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// }; // };
// REGISTER_H(NAME) // REGISTER_H(NAME)
// #define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp {
// protected:
// void registerTypes();
// Nd4jStatus validateAndExecute(Context& block);
// public:
// NAME();
// };
// REGISTER_H(NAME)
// #define BROADCASTABLE_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableOp(#NAME, TARGS, IARGS) { }; // #define BROADCASTABLE_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableOp(#NAME, TARGS, IARGS) { };
// REGISTER_C(NAME) // REGISTER_C(NAME)
// Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block) // Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block)
// #define BROADCASTABLE_BOOL_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableBoolOp(#NAME, TARGS, IARGS) { };
// REGISTER_C(NAME)
// Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block)
// #define DECLARE_DEVICE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS) // #define DECLARE_DEVICE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS)
@ -11871,6 +11886,50 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #endif //LIBND4J_BROADCASTABLEOP_H // #endif //LIBND4J_BROADCASTABLEOP_H
// Parsed from ops/declarable/BroadcastableBoolOp.h
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver on 6/6/2018.
//
// #ifndef SD_BROADCASTABLEBOOLOP_H
// #define SD_BROADCASTABLEBOOLOP_H
// #include <graph/Context.h>
// #include "OpDescriptor.h"
// #include "DeclarableOp.h"
// #include "DeclarableCustomOp.h"
@Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public BroadcastableBoolOp(Pointer p) { super(p); }
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif //SD_BROADCASTABLEBOOLOP_H
// Parsed from ops/declarable/DeclarableOp.h // Parsed from ops/declarable/DeclarableOp.h
/******************************************************************************* /*******************************************************************************
@ -13636,6 +13695,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #define LIBND4J_HEADERS_BROADCASTABLE_H // #define LIBND4J_HEADERS_BROADCASTABLE_H
// #include <ops/declarable/BroadcastableOp.h> // #include <ops/declarable/BroadcastableOp.h>
// #include <ops/declarable/BroadcastableBoolOp.h>
// #include <ops/declarable/headers/common.h> // #include <ops/declarable/headers/common.h>
// #include <ops/declarable/generic/helpers/BroadcastHelper.h> // #include <ops/declarable/generic/helpers/BroadcastHelper.h>
// TODO: make broadcastables separate class // TODO: make broadcastables separate class
@ -14317,7 +14377,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* *
*/ */
// #if NOT_EXCLUDED(OP_equals) // #if NOT_EXCLUDED(OP_equals)
@Namespace("sd::ops") public static class equals extends BroadcastableOp { @Namespace("sd::ops") public static class equals extends BroadcastableBoolOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public equals(Pointer p) { super(p); } public equals(Pointer p) { super(p); }
@ -14338,7 +14398,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* Math is: _x != _y ? (T) 1.0f : (T) 0.0f; * Math is: _x != _y ? (T) 1.0f : (T) 0.0f;
*/ */
// #if NOT_EXCLUDED(OP_not_equals) // #if NOT_EXCLUDED(OP_not_equals)
@Namespace("sd::ops") public static class not_equals extends BroadcastableOp { @Namespace("sd::ops") public static class not_equals extends BroadcastableBoolOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public not_equals(Pointer p) { super(p); } public not_equals(Pointer p) { super(p); }
@ -14359,7 +14419,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* Math is: _x <= _y ? (T) 1.0f : (T) 0.0f; * Math is: _x <= _y ? (T) 1.0f : (T) 0.0f;
*/ */
// #if NOT_EXCLUDED(OP_less_equal) // #if NOT_EXCLUDED(OP_less_equal)
@Namespace("sd::ops") public static class less_equal extends BroadcastableOp { @Namespace("sd::ops") public static class less_equal extends BroadcastableBoolOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public less_equal(Pointer p) { super(p); } public less_equal(Pointer p) { super(p); }
@ -14380,7 +14440,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* Math is: _x >= _y ? (T) 1.0f : (T) 0.0f; * Math is: _x >= _y ? (T) 1.0f : (T) 0.0f;
*/ */
// #if NOT_EXCLUDED(OP_greater_equal) // #if NOT_EXCLUDED(OP_greater_equal)
@Namespace("sd::ops") public static class greater_equal extends BroadcastableOp { @Namespace("sd::ops") public static class greater_equal extends BroadcastableBoolOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public greater_equal(Pointer p) { super(p); } public greater_equal(Pointer p) { super(p); }
@ -14401,7 +14461,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* Math is: _x < _y ? (T) 1.0f : (T) 0.0f; * Math is: _x < _y ? (T) 1.0f : (T) 0.0f;
*/ */
// #if NOT_EXCLUDED(OP_less) // #if NOT_EXCLUDED(OP_less)
@Namespace("sd::ops") public static class less extends BroadcastableOp { @Namespace("sd::ops") public static class less extends BroadcastableBoolOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public less(Pointer p) { super(p); } public less(Pointer p) { super(p); }
@ -14422,7 +14482,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* Math is: _x > _y ? (T) 1.0f : (T) 0.0f; * Math is: _x > _y ? (T) 1.0f : (T) 0.0f;
*/ */
// #if NOT_EXCLUDED(OP_greater) // #if NOT_EXCLUDED(OP_greater)
@Namespace("sd::ops") public static class greater extends BroadcastableOp { @Namespace("sd::ops") public static class greater extends BroadcastableBoolOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public greater(Pointer p) { super(p); } public greater(Pointer p) { super(p); }
@ -16672,6 +16732,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
private native void allocate(); private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
@Namespace("sd::ops") public static class mergemax_bp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public mergemax_bp(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public mergemax_bp(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public mergemax_bp position(long position) {
return (mergemax_bp)super.position(position);
}
public mergemax_bp() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif // #endif
/* /*
* Complete tensor with max indices merged from all input tensors list * Complete tensor with max indices merged from all input tensors list
@ -16714,6 +16789,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
private native void allocate(); private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// JavaCPP binding for the native sd::ops::mergeadd_bp declarable op (backprop
// counterpart of mergeadd). NOTE(review): this class appears machine-generated by
// the JavaCPP parser — regenerate from the native headers rather than hand-editing.
@Namespace("sd::ops") public static class mergeadd_bp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public mergeadd_bp(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public mergeadd_bp(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public mergeadd_bp position(long position) {
return (mergeadd_bp)super.position(position);
}
// Default constructor: allocates a fresh native op instance.
public mergeadd_bp() { super((Pointer)null); allocate(); }
private native void allocate();
// Native shape-inference hook: derives output shapes from the input ShapeList and Context.
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif // #endif
// #if NOT_EXCLUDED(OP_mergeavg) // #if NOT_EXCLUDED(OP_mergeavg)
@ -16732,6 +16822,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
private native void allocate(); private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// JavaCPP binding for the native sd::ops::mergeavg_bp declarable op (backprop
// counterpart of mergeavg). NOTE(review): this class appears machine-generated by
// the JavaCPP parser — regenerate from the native headers rather than hand-editing.
@Namespace("sd::ops") public static class mergeavg_bp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public mergeavg_bp(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public mergeavg_bp(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public mergeavg_bp position(long position) {
return (mergeavg_bp)super.position(position);
}
// Default constructor: allocates a fresh native op instance.
public mergeavg_bp() { super((Pointer)null); allocate(); }
private native void allocate();
// Native shape-inference hook: derives output shapes from the input ShapeList and Context.
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif // #endif
// #if NOT_EXCLUDED(OP_scatter_update) // #if NOT_EXCLUDED(OP_scatter_update)
@ -19074,23 +19179,40 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* - 2D matrix MxN * - 2D matrix MxN
* - 1D vector with N elements * - 1D vector with N elements
* output value - 2D matrix NxN as multiply of matrixes and add vector * output value - 2D matrix NxN as multiply of matrixes and add vector
* Int args:
* 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul
*/ */
// #if NOT_EXCLUDED(OP_xw_plus_b) // #if NOT_EXCLUDED(OP_xw_plus_b)
@Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp { @Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public xw_plus_b(Pointer p) { super(p); } public xw_plus_b(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */ /** Native array allocator. Access with {@link Pointer#position(long)}. */
public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); } public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size); private native void allocateArray(long size);
@Override public xw_plus_b position(long position) { @Override public xw_plus_b position(long position) {
return (xw_plus_b)super.position(position); return (xw_plus_b)super.position(position);
} }
public xw_plus_b() { super((Pointer)null); allocate(); } public xw_plus_b() { super((Pointer)null); allocate(); }
private native void allocate(); private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// JavaCPP binding for the native sd::ops::xw_plus_b_bp declarable op (backprop
// counterpart of xw_plus_b, i.e. the gradient of x*W + b per the surrounding docs).
// NOTE(review): this class appears machine-generated by the JavaCPP parser —
// regenerate from the native headers rather than hand-editing.
@Namespace("sd::ops") public static class xw_plus_b_bp extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public xw_plus_b_bp(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public xw_plus_b_bp(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public xw_plus_b_bp position(long position) {
return (xw_plus_b_bp)super.position(position);
}
// Default constructor: allocates a fresh native op instance.
public xw_plus_b_bp() { super((Pointer)null); allocate(); }
private native void allocate();
// Native shape-inference hook: derives output shapes from the input ShapeList and Context.
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif // #endif
/** /**

View File

@ -82,6 +82,7 @@ import java.util.Scanner;
"ops/declarable/OpDescriptor.h", "ops/declarable/OpDescriptor.h",
"ops/declarable/PlatformHelper.h", "ops/declarable/PlatformHelper.h",
"ops/declarable/BroadcastableOp.h", "ops/declarable/BroadcastableOp.h",
"ops/declarable/BroadcastableBoolOp.h",
"ops/declarable/DeclarableOp.h", "ops/declarable/DeclarableOp.h",
"ops/declarable/DeclarableListOp.h", "ops/declarable/DeclarableListOp.h",
"ops/declarable/DeclarableReductionOp.h", "ops/declarable/DeclarableReductionOp.h",

View File

@ -24,9 +24,11 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@ -316,6 +318,17 @@ public class BasicBroadcastTests extends BaseNd4jTest {
assertEquals(exp, sum); assertEquals(exp, sum);
} }
@Test
public void testBroadcatableBool_1() {
// Shape inference for a broadcastable boolean op: two FLOAT operands must
// produce a single BOOL output descriptor (not FLOAT).
// (Method name typo "Broadcatable" kept: it is the test's public identifier.)
val x = Nd4j.create(DataType.FLOAT, 3);
val y = Nd4j.create(DataType.FLOAT, 3);

val op = DynamicCustomOp.builder("greater_equal")
.addInputs(x, y)
.build();

val shapes = op.calculateOutputShape();
assertEquals(1, shapes.size());
assertEquals(DataType.BOOL, shapes.get(0).dataType());
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';

View File

@ -36,6 +36,9 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
import java.util.ArrayList;
import java.util.Arrays;
import static org.junit.Assert.*; import static org.junit.Assert.*;
/** /**
@ -298,6 +301,40 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
log.info("{} ns", ((timeEnd - timeStart) / (double) iterations)); log.info("{} ns", ((timeEnd - timeStart) / (double) iterations));
} }
@Test
public void testWorkspaceOrder_1(){
WorkspaceConfiguration conf = WorkspaceConfiguration.builder()
.initialSize(1_000_000)
.overallocationLimit(0.05)
.policyLearning(LearningPolicy.NONE)
.build();

// Expected sequence of current-workspace ids as scopes open and close:
// borrowed outer, scoped-out (null), outer again, inner, outer, then none.
val exp = Arrays.asList("outer", null, "outer", "inner", "outer", null);
val res = new ArrayList<String>();

// Prints the current workspace under the given label (as the original inline
// statements did) and records its id, or null when no workspace is active.
final java.util.function.Consumer<String> probe = label -> {
final MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
System.out.println(label + ": " + current);
res.add(current == null ? null : current.getId());
};

try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(conf, "outer")){
try(MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(conf, "inner")){
try(MemoryWorkspace ws3 = ws.notifyScopeBorrowed()){
probe.accept("X"); //outer
try(MemoryWorkspace ws4 = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
probe.accept("A"); //None (null)
}
probe.accept("B"); //outer
}
probe.accept("C"); //inner
}
probe.accept("D"); //outer
}
probe.accept("E"); //None (null)

assertEquals(exp, res);
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';

View File

@ -616,19 +616,17 @@ public class WorkspaceProviderTests extends BaseNd4jTest {
} }
@Test @Test
@Ignore("raver119: This test doesn't make any sense to me these days. We're borrowing from the same workspace. Why?")
public void testNestedWorkspaces11() { public void testNestedWorkspaces11() {
for (int x = 1; x < 10; x++) { for (int x = 1; x < 10; x++) {
try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) { try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
INDArray array1 = Nd4j.create(100 * x); INDArray array1 = Nd4j.create(100 * x);
for (int i = 1; i < 10; i++) { for (int i = 1; i < 10; i++) {
try (MemoryWorkspace ws2 = try (MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
INDArray array2 = Nd4j.create(100 * x); INDArray array2 = Nd4j.create(100 * x);
for (int e = 1; e < 10; e++) { for (int e = 1; e < 10; e++) {
try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager() try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(basicConfiguration, "WS_1").notifyScopeBorrowed()) {
.getWorkspaceForCurrentThread(basicConfiguration, "WS_1")
.notifyScopeBorrowed()) {
INDArray array3 = Nd4j.create(100 * x); INDArray array3 = Nd4j.create(100 * x);
} }
} }