Few fixes (#361)
* INDArray.close() fix for CPU Signed-off-by: raver119 <raver119@gmail.com> * - BroadcastableBoolOp introduced - ConfusionMatrix now supports explicit DataType argument Signed-off-by: raver119 <raver119@gmail.com> * confusion_matrix: dtype is still optional Signed-off-by: raver119 <raver119@gmail.com> * disable bert tests in debug builds Signed-off-by: raver119 <raver119@gmail.com> * Affinity fix Signed-off-by: raver119 <raver119@gmail.com> * minor workspace tweak to allow close() on scoped out borrowed workspace Signed-off-by: raver119 <raver119@gmail.com>master
parent
986ec4b51a
commit
04b2b4f9b6
|
@ -0,0 +1,43 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver on 6/6/2018.
|
||||
//
|
||||
|
||||
#ifndef SD_BROADCASTABLEBOOLOP_H
|
||||
#define SD_BROADCASTABLEBOOLOP_H
|
||||
|
||||
#include <graph/Context.h>
|
||||
#include "OpDescriptor.h"
|
||||
#include "DeclarableOp.h"
|
||||
#include "DeclarableCustomOp.h"
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
class ND4J_EXPORT BroadcastableBoolOp : public DeclarableCustomOp{
|
||||
protected:
|
||||
Nd4jStatus validateAndExecute(Context& block) override = 0;
|
||||
public:
|
||||
BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs);
|
||||
|
||||
ShapeList *calculateOutputShape(ShapeList *inputShape, sd::graph::Context& block) override;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif //SD_BROADCASTABLEBOOLOP_H
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(equals, 0, 0) {
|
||||
BROADCASTABLE_BOOL_OP_IMPL(equals, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(greater, 0, 0) {
|
||||
BROADCASTABLE_BOOL_OP_IMPL(greater, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(greater_equal, 0, 0) {
|
||||
BROADCASTABLE_BOOL_OP_IMPL(greater_equal, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(less, 0, 0) {
|
||||
BROADCASTABLE_BOOL_OP_IMPL(less, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(less_equal, 0, 0) {
|
||||
BROADCASTABLE_BOOL_OP_IMPL(less_equal, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
BROADCASTABLE_OP_IMPL(not_equals, 0, 0) {
|
||||
BROADCASTABLE_BOOL_OP_IMPL(not_equals, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
|
|
@ -45,8 +45,8 @@ namespace sd {
|
|||
weights = INPUT_VARIABLE(2);
|
||||
REQUIRE_TRUE(weights->isSameShape(predictions),0, "CONFUSION_MATRIX: Weights and predictions should have equal shape");
|
||||
}
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
output->assign(0.);
|
||||
auto output = OUTPUT_NULLIFIED(0);
|
||||
|
||||
int minPrediction = predictions->reduceNumber(reduce::Min).e<int>(0);
|
||||
int minLabel = labels->reduceNumber(reduce::Min).e<int>(0);
|
||||
|
||||
|
@ -64,11 +64,7 @@ namespace sd {
|
|||
DECLARE_SHAPE_FN(confusion_matrix) {
|
||||
auto labels = INPUT_VARIABLE(0);
|
||||
auto predictions = INPUT_VARIABLE(1);
|
||||
auto dtype = block.dataType();
|
||||
dtype = sd::DataType::INT64; // dtype - should be a param with int argument
|
||||
if (block.numI() > 1)
|
||||
dtype = (sd::DataType)INT_ARG(1);
|
||||
|
||||
auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
|
||||
int numClasses = 0;
|
||||
|
||||
if (block.getIArguments()->size() > 0) {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#define LIBND4J_HEADERS_BROADCASTABLE_H
|
||||
|
||||
#include <ops/declarable/BroadcastableOp.h>
|
||||
#include <ops/declarable/BroadcastableBoolOp.h>
|
||||
#include <ops/declarable/headers/common.h>
|
||||
#include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||
|
||||
|
@ -261,7 +262,7 @@ namespace sd {
|
|||
*
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_equals)
|
||||
DECLARE_BROADCASTABLE_OP(equals, 0, 0);
|
||||
DECLARE_BROADCASTABLE_BOOL_OP(equals, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -269,7 +270,7 @@ namespace sd {
|
|||
* Math is: _x != _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_not_equals)
|
||||
DECLARE_BROADCASTABLE_OP(not_equals, 0, 0);
|
||||
DECLARE_BROADCASTABLE_BOOL_OP(not_equals, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -277,7 +278,7 @@ namespace sd {
|
|||
* Math is: _x <= _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_less_equal)
|
||||
DECLARE_BROADCASTABLE_OP(less_equal, 0, 0);
|
||||
DECLARE_BROADCASTABLE_BOOL_OP(less_equal, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -285,7 +286,7 @@ namespace sd {
|
|||
* Math is: _x >= _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_greater_equal)
|
||||
DECLARE_BROADCASTABLE_OP(greater_equal, 0, 0);
|
||||
DECLARE_BROADCASTABLE_BOOL_OP(greater_equal, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -293,7 +294,7 @@ namespace sd {
|
|||
* Math is: _x < _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_less)
|
||||
DECLARE_BROADCASTABLE_OP(less, 0, 0);
|
||||
DECLARE_BROADCASTABLE_BOOL_OP(less, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -301,7 +302,7 @@ namespace sd {
|
|||
* Math is: _x > _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_greater)
|
||||
DECLARE_BROADCASTABLE_OP(greater, 0, 0);
|
||||
DECLARE_BROADCASTABLE_BOOL_OP(greater, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver on 6/6/2018.
|
||||
//
|
||||
|
||||
#include <system/op_boilerplate.h>
|
||||
#include <system/pointercast.h>
|
||||
#include <ops/declarable/BroadcastableBoolOp.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
BroadcastableBoolOp::BroadcastableBoolOp(const char *name, int numTArgs, int numIArgs) : DeclarableCustomOp::DeclarableCustomOp(2, 1, name, false, numTArgs, numIArgs) {
|
||||
//
|
||||
}
|
||||
|
||||
ShapeList *BroadcastableBoolOp::calculateOutputShape(ShapeList *inputShape, sd::graph::Context &block) {
|
||||
auto shapeList = SHAPELIST();
|
||||
auto x = inputShape->at(0);
|
||||
auto y = inputShape->at(1);
|
||||
sd::DataType dtype = sd::DataType::BOOL;
|
||||
|
||||
if(shape::isEmpty(x) || shape::isEmpty(y)) {
|
||||
// this is edge case, [3, 4] + [] = []
|
||||
if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)));
|
||||
return shapeList;
|
||||
}
|
||||
|
||||
Nd4jLong *newshape = nullptr;
|
||||
ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace());
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype)));
|
||||
} else if (shape::isScalar(x) && shape::isScalar(y)) {
|
||||
if (shape::rank(x) >= shape::rank(y)) {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
|
||||
} else {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype)));
|
||||
}
|
||||
} else if (shape::equalsSoft(x, y)) {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
|
||||
} else if (shape::isScalar(x) && !shape::isScalar(y)) {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(y, dtype)));
|
||||
} else if (!shape::isScalar(x) && shape::isScalar(y)) {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
|
||||
} else if (ShapeUtils::areShapesBroadcastable(x, y)) {
|
||||
Nd4jLong *newshape = nullptr;
|
||||
ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace());
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype)));
|
||||
} else {
|
||||
// in this case we'll throw exception later
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
|
||||
}
|
||||
|
||||
return shapeList;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1446,10 +1446,24 @@
|
|||
};\
|
||||
REGISTER_H(NAME)
|
||||
|
||||
#define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp { \
|
||||
protected: \
|
||||
void registerTypes(); \
|
||||
Nd4jStatus validateAndExecute(Context& block); \
|
||||
public:\
|
||||
NAME(); \
|
||||
};\
|
||||
REGISTER_H(NAME)
|
||||
|
||||
|
||||
#define BROADCASTABLE_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableOp(#NAME, TARGS, IARGS) { }; \
|
||||
REGISTER_C(NAME) \
|
||||
Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block)
|
||||
|
||||
#define BROADCASTABLE_BOOL_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableBoolOp(#NAME, TARGS, IARGS) { }; \
|
||||
REGISTER_C(NAME) \
|
||||
Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block)
|
||||
|
||||
|
||||
#define DECLARE_DEVICE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS)
|
||||
|
||||
|
|
|
@ -515,7 +515,7 @@ TEST_F(DeclarableOpsTests12, TestConfusionZero_1) {
|
|||
//exp1.assign(1.);
|
||||
//exp2.assign(-2.);
|
||||
sd::ops::confusion_matrix op;
|
||||
Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {});
|
||||
Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(output.equalsTo(exp));
|
||||
|
|
|
@ -2374,7 +2374,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test4) {
|
|||
auto expected = NDArrayFactory::create<double>('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200});
|
||||
|
||||
sd::ops::confusion_matrix op;
|
||||
auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3, sd::DataType::DOUBLE});
|
||||
auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}, {}, {sd::DataType::DOUBLE});
|
||||
auto output = results.at(0);
|
||||
|
||||
|
||||
|
|
|
@ -91,6 +91,8 @@ TEST_F(PlaygroundTests, test_biasAdd_1) {
|
|||
|
||||
|
||||
TEST_F(PlaygroundTests, test_bert_full_1) {
|
||||
#ifdef _RELEASE
|
||||
|
||||
// this test will run ONLY if this model exists
|
||||
if (sd::graph::getFileSize("/home/raver119/Downloads/BertFull/model.fb") < 0)
|
||||
return;
|
||||
|
@ -147,10 +149,12 @@ TEST_F(PlaygroundTests, test_bert_full_1) {
|
|||
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
||||
*/
|
||||
delete graph;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
TEST_F(PlaygroundTests, test_bert_1) {
|
||||
#ifdef _RELEASE
|
||||
// this test will run ONLY if this model exists
|
||||
if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0)
|
||||
return;
|
||||
|
@ -206,9 +210,11 @@ TEST_F(PlaygroundTests, test_bert_1) {
|
|||
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
||||
*/
|
||||
delete graph;
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_F(PlaygroundTests, test_bert_2) {
|
||||
#ifdef _RELEASE
|
||||
// this test will run ONLY if this model exists
|
||||
if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb") < 0)
|
||||
return;
|
||||
|
@ -256,6 +262,7 @@ TEST_F(PlaygroundTests, test_bert_2) {
|
|||
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
||||
*/
|
||||
delete graph;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1930,7 +1930,6 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
protected void release() {
|
||||
this.released = true;
|
||||
this.pointer.deallocate();
|
||||
this.indexer = null;
|
||||
this.pointer = null;
|
||||
}
|
||||
|
|
|
@ -580,6 +580,13 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace {
|
|||
public void close() {
|
||||
// first we check if this workspace was borrowed. if yes - just close without reset.
|
||||
if (isBorrowed.get()) {
|
||||
if (tagScope.get() > 0) {
|
||||
if (tagScope.decrementAndGet() == 0) {
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(this);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
isBorrowed.set(false);
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(borrowingWorkspace);
|
||||
return;
|
||||
|
|
|
@ -42,6 +42,7 @@ public class ConfusionMatrix extends DynamicCustomOp {
|
|||
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, @NonNull DataType dataType){
|
||||
super(new INDArray[]{labels, predicted}, null);
|
||||
this.outputType = dataType;
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
public ConfusionMatrix(@NonNull INDArray labels, @NonNull INDArray predicted, int numClasses){
|
||||
|
@ -66,6 +67,7 @@ public class ConfusionMatrix extends DynamicCustomOp {
|
|||
if(numClasses != null) {
|
||||
addIArgument(numClasses);
|
||||
}
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
|
||||
|
@ -77,6 +79,7 @@ public class ConfusionMatrix extends DynamicCustomOp {
|
|||
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, DataType dataType){
|
||||
super(null, sameDiff, new SDVariable[]{labels, pred});
|
||||
this.outputType = dataType;
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
public ConfusionMatrix(SameDiff sameDiff, SDVariable labels, SDVariable pred, SDVariable weights){
|
||||
|
|
|
@ -209,4 +209,11 @@ public class OpaqueDataBuffer extends Pointer {
|
|||
public void syncToPrimary() {
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method releases underlying buffer
|
||||
*/
|
||||
public void closeBuffer() {
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().dbClose(this);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -72,9 +72,16 @@ public class CudaAffinityManager extends BasicAffinityManager {
|
|||
*/
|
||||
@Override
|
||||
public Integer getDeviceForThread(long threadId) {
|
||||
val id = affinityMap.get(threadId);
|
||||
if (id == null)
|
||||
throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet");
|
||||
Integer id = affinityMap.get(threadId);
|
||||
if (id == null) {
|
||||
// if this is current thread - we're still able to fetch id from native side, and update map
|
||||
if (threadId == Thread.currentThread().getId()) {
|
||||
id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
|
||||
affinityMap.put(Long.valueOf(threadId), id);
|
||||
} else
|
||||
// TODO: we should get rid of this method, and forbid such kind of queries
|
||||
throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet");
|
||||
}
|
||||
|
||||
return id;
|
||||
}
|
||||
|
|
|
@ -1792,11 +1792,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
@Override
|
||||
protected void release() {
|
||||
if (!released) {
|
||||
//AtomicAllocator.getInstance().freeMemory(allocationPoint);n
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().dbClose(allocationPoint.getPtrDataBuffer());
|
||||
ptrDataBuffer.closeBuffer();
|
||||
allocationPoint.setReleased(true);
|
||||
}
|
||||
released = true;
|
||||
|
||||
super.release();
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
@ -9553,6 +9553,50 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
// #endif //LIBND4J_BROADCASTABLEOP_H
|
||||
|
||||
|
||||
// Parsed from ops/declarable/BroadcastableBoolOp.h
|
||||
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver on 6/6/2018.
|
||||
//
|
||||
|
||||
// #ifndef SD_BROADCASTABLEBOOLOP_H
|
||||
// #define SD_BROADCASTABLEBOOLOP_H
|
||||
|
||||
// #include <graph/Context.h>
|
||||
// #include "OpDescriptor.h"
|
||||
// #include "DeclarableOp.h"
|
||||
// #include "DeclarableCustomOp.h"
|
||||
@Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public BroadcastableBoolOp(Pointer p) { super(p); }
|
||||
|
||||
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// #endif //SD_BROADCASTABLEBOOLOP_H
|
||||
|
||||
|
||||
// Parsed from helpers/OpArgsHolder.h
|
||||
|
||||
/*******************************************************************************
|
||||
|
|
|
@ -76,7 +76,8 @@ import org.bytedeco.javacpp.tools.InfoMapper;
|
|||
"ops/InputType.h",
|
||||
"ops/declarable/OpDescriptor.h",
|
||||
"ops/declarable/PlatformHelper.h",
|
||||
"ops/declarable/BroadcastableOp.h",
|
||||
"ops/declarable/BroadcastableOp.h",
|
||||
"ops/declarable/BroadcastableBoolOp.h",
|
||||
"helpers/OpArgsHolder.h",
|
||||
"ops/declarable/DeclarableOp.h",
|
||||
"ops/declarable/DeclarableListOp.h",
|
||||
|
|
|
@ -837,6 +837,12 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
|
|||
this(data, true, workspace);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void release() {
|
||||
ptrDataBuffer.closeBuffer();
|
||||
super.release();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reallocate the native memory of the buffer
|
||||
* @param length the new length of the buffer
|
||||
|
|
|
@ -6,6 +6,7 @@ import java.nio.*;
|
|||
import org.bytedeco.javacpp.*;
|
||||
import org.bytedeco.javacpp.annotation.*;
|
||||
|
||||
import static org.bytedeco.javacpp.presets.javacpp.*;
|
||||
import static org.bytedeco.openblas.global.openblas_nolapack.*;
|
||||
import static org.bytedeco.openblas.global.openblas.*;
|
||||
|
||||
|
@ -11406,10 +11407,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
// };
|
||||
// REGISTER_H(NAME)
|
||||
|
||||
// #define DECLARE_BROADCASTABLE_BOOL_OP(NAME,TARGS, IARGS) class ND4J_EXPORT NAME: public sd::ops::BroadcastableBoolOp {
|
||||
// protected:
|
||||
// void registerTypes();
|
||||
// Nd4jStatus validateAndExecute(Context& block);
|
||||
// public:
|
||||
// NAME();
|
||||
// };
|
||||
// REGISTER_H(NAME)
|
||||
|
||||
|
||||
// #define BROADCASTABLE_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableOp(#NAME, TARGS, IARGS) { };
|
||||
// REGISTER_C(NAME)
|
||||
// Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block)
|
||||
|
||||
// #define BROADCASTABLE_BOOL_OP_IMPL(NAME, TARGS, IARGS) NAME::NAME(): sd::ops::BroadcastableBoolOp(#NAME, TARGS, IARGS) { };
|
||||
// REGISTER_C(NAME)
|
||||
// Nd4jStatus sd::ops::NAME::validateAndExecute(sd::graph::Context& block)
|
||||
|
||||
|
||||
// #define DECLARE_DEVICE_OP(NAME, NIN, NOUT, INPLACEABLE, TARGS, IARGS)
|
||||
|
||||
|
@ -11871,6 +11886,50 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
// #endif //LIBND4J_BROADCASTABLEOP_H
|
||||
|
||||
|
||||
// Parsed from ops/declarable/BroadcastableBoolOp.h
|
||||
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver on 6/6/2018.
|
||||
//
|
||||
|
||||
// #ifndef SD_BROADCASTABLEBOOLOP_H
|
||||
// #define SD_BROADCASTABLEBOOLOP_H
|
||||
|
||||
// #include <graph/Context.h>
|
||||
// #include "OpDescriptor.h"
|
||||
// #include "DeclarableOp.h"
|
||||
// #include "DeclarableCustomOp.h"
|
||||
@Namespace("sd::ops") public static class BroadcastableBoolOp extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public BroadcastableBoolOp(Pointer p) { super(p); }
|
||||
|
||||
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// #endif //SD_BROADCASTABLEBOOLOP_H
|
||||
|
||||
|
||||
// Parsed from ops/declarable/DeclarableOp.h
|
||||
|
||||
/*******************************************************************************
|
||||
|
@ -13636,6 +13695,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
// #define LIBND4J_HEADERS_BROADCASTABLE_H
|
||||
|
||||
// #include <ops/declarable/BroadcastableOp.h>
|
||||
// #include <ops/declarable/BroadcastableBoolOp.h>
|
||||
// #include <ops/declarable/headers/common.h>
|
||||
// #include <ops/declarable/generic/helpers/BroadcastHelper.h>
|
||||
// TODO: make broadcastables separate class
|
||||
|
@ -14317,7 +14377,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
*
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_equals)
|
||||
@Namespace("sd::ops") public static class equals extends BroadcastableOp {
|
||||
@Namespace("sd::ops") public static class equals extends BroadcastableBoolOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public equals(Pointer p) { super(p); }
|
||||
|
@ -14338,7 +14398,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
* Math is: _x != _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_not_equals)
|
||||
@Namespace("sd::ops") public static class not_equals extends BroadcastableOp {
|
||||
@Namespace("sd::ops") public static class not_equals extends BroadcastableBoolOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public not_equals(Pointer p) { super(p); }
|
||||
|
@ -14359,7 +14419,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
* Math is: _x <= _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_less_equal)
|
||||
@Namespace("sd::ops") public static class less_equal extends BroadcastableOp {
|
||||
@Namespace("sd::ops") public static class less_equal extends BroadcastableBoolOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public less_equal(Pointer p) { super(p); }
|
||||
|
@ -14380,7 +14440,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
* Math is: _x >= _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_greater_equal)
|
||||
@Namespace("sd::ops") public static class greater_equal extends BroadcastableOp {
|
||||
@Namespace("sd::ops") public static class greater_equal extends BroadcastableBoolOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public greater_equal(Pointer p) { super(p); }
|
||||
|
@ -14401,7 +14461,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
* Math is: _x < _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_less)
|
||||
@Namespace("sd::ops") public static class less extends BroadcastableOp {
|
||||
@Namespace("sd::ops") public static class less extends BroadcastableBoolOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public less(Pointer p) { super(p); }
|
||||
|
@ -14422,7 +14482,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
* Math is: _x > _y ? (T) 1.0f : (T) 0.0f;
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_greater)
|
||||
@Namespace("sd::ops") public static class greater extends BroadcastableOp {
|
||||
@Namespace("sd::ops") public static class greater extends BroadcastableBoolOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public greater(Pointer p) { super(p); }
|
||||
|
@ -16672,6 +16732,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
@Namespace("sd::ops") public static class mergemax_bp extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public mergemax_bp(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public mergemax_bp(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public mergemax_bp position(long position) {
|
||||
return (mergemax_bp)super.position(position);
|
||||
}
|
||||
|
||||
public mergemax_bp() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
/*
|
||||
* Complete tensor with max indices merged from all input tensors list
|
||||
|
@ -16714,6 +16789,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
@Namespace("sd::ops") public static class mergeadd_bp extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public mergeadd_bp(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public mergeadd_bp(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public mergeadd_bp position(long position) {
|
||||
return (mergeadd_bp)super.position(position);
|
||||
}
|
||||
|
||||
public mergeadd_bp() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
// #if NOT_EXCLUDED(OP_mergeavg)
|
||||
|
@ -16732,6 +16822,21 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
@Namespace("sd::ops") public static class mergeavg_bp extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public mergeavg_bp(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public mergeavg_bp(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public mergeavg_bp position(long position) {
|
||||
return (mergeavg_bp)super.position(position);
|
||||
}
|
||||
|
||||
public mergeavg_bp() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
// #if NOT_EXCLUDED(OP_scatter_update)
|
||||
|
@ -19074,23 +19179,40 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
* - 2D matrix MxN
|
||||
* - 1D vector with N elements
|
||||
* output value - 2D matrix NxN as multiply of matrixes and add vector
|
||||
* Int args:
|
||||
* 0 - optional switcher of weights format, if int arg == 1 - mkldnn, else mmul
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_xw_plus_b)
|
||||
@Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public xw_plus_b(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public xw_plus_b position(long position) {
|
||||
return (xw_plus_b)super.position(position);
|
||||
}
|
||||
|
||||
@Namespace("sd::ops") public static class xw_plus_b extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public xw_plus_b(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public xw_plus_b(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public xw_plus_b position(long position) {
|
||||
return (xw_plus_b)super.position(position);
|
||||
}
|
||||
|
||||
public xw_plus_b() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
@Namespace("sd::ops") public static class xw_plus_b_bp extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public xw_plus_b_bp(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public xw_plus_b_bp(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public xw_plus_b_bp position(long position) {
|
||||
return (xw_plus_b_bp)super.position(position);
|
||||
}
|
||||
|
||||
public xw_plus_b_bp() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -81,7 +81,8 @@ import java.util.Scanner;
|
|||
"ops/InputType.h",
|
||||
"ops/declarable/OpDescriptor.h",
|
||||
"ops/declarable/PlatformHelper.h",
|
||||
"ops/declarable/BroadcastableOp.h",
|
||||
"ops/declarable/BroadcastableOp.h",
|
||||
"ops/declarable/BroadcastableBoolOp.h",
|
||||
"ops/declarable/DeclarableOp.h",
|
||||
"ops/declarable/DeclarableListOp.h",
|
||||
"ops/declarable/DeclarableReductionOp.h",
|
||||
|
|
|
@ -24,9 +24,11 @@ import org.junit.runners.Parameterized;
|
|||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
|
@ -316,6 +318,17 @@ public class BasicBroadcastTests extends BaseNd4jTest {
|
|||
assertEquals(exp, sum);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBroadcatableBool_1() {
|
||||
val op = DynamicCustomOp.builder("greater_equal")
|
||||
.addInputs(Nd4j.create(DataType.FLOAT, 3), Nd4j.create(DataType.FLOAT, 3))
|
||||
.build();
|
||||
|
||||
val l = op.calculateOutputShape();
|
||||
assertEquals(1, l.size());
|
||||
assertEquals(DataType.BOOL, l.get(0).dataType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
|
@ -36,6 +36,9 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
|
@ -298,6 +301,40 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
|||
log.info("{} ns", ((timeEnd - timeStart) / (double) iterations));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWorkspaceOrder_1(){
|
||||
WorkspaceConfiguration conf = WorkspaceConfiguration.builder()
|
||||
.initialSize(1_000_000)
|
||||
.overallocationLimit(0.05)
|
||||
.policyLearning(LearningPolicy.NONE)
|
||||
.build();
|
||||
|
||||
val exp = Arrays.asList("outer", null, "outer", "inner", "outer", null);
|
||||
val res = new ArrayList<String>();
|
||||
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(conf, "outer")){
|
||||
try(MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(conf, "inner")){
|
||||
try(MemoryWorkspace ws3 = ws.notifyScopeBorrowed()){
|
||||
System.out.println("X: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //outer
|
||||
res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId());
|
||||
try(MemoryWorkspace ws4 = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
||||
System.out.println("A: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //None (null)
|
||||
res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId());
|
||||
}
|
||||
System.out.println("B: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //outer
|
||||
res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId());
|
||||
}
|
||||
System.out.println("C: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //inner
|
||||
res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId());
|
||||
}
|
||||
System.out.println("D: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //outer
|
||||
res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId());
|
||||
}
|
||||
System.out.println("E: " + Nd4j.getMemoryManager().getCurrentWorkspace()); //None (null)
|
||||
res.add(Nd4j.getMemoryManager().getCurrentWorkspace() == null ? null : Nd4j.getMemoryManager().getCurrentWorkspace().getId());
|
||||
|
||||
assertEquals(exp, res);
|
||||
}
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
|
@ -616,19 +616,17 @@ public class WorkspaceProviderTests extends BaseNd4jTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
@Ignore("raver119: This test doesn't make any sense to me these days. We're borrowing from the same workspace. Why?")
|
||||
public void testNestedWorkspaces11() {
|
||||
for (int x = 1; x < 10; x++) {
|
||||
try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
|
||||
INDArray array1 = Nd4j.create(100 * x);
|
||||
|
||||
for (int i = 1; i < 10; i++) {
|
||||
try (MemoryWorkspace ws2 =
|
||||
Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
|
||||
try (MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
|
||||
INDArray array2 = Nd4j.create(100 * x);
|
||||
for (int e = 1; e < 10; e++) {
|
||||
try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager()
|
||||
.getWorkspaceForCurrentThread(basicConfiguration, "WS_1")
|
||||
.notifyScopeBorrowed()) {
|
||||
try (MemoryWorkspace ws3 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(basicConfiguration, "WS_1").notifyScopeBorrowed()) {
|
||||
INDArray array3 = Nd4j.create(100 * x);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue