[WIP] More fixes (#73)
* special tests for ConstantTadHelper/ConstantShapeHelper Signed-off-by: raver119 <raver119@gmail.com> * release methods for data buffers Signed-off-by: raver119 <raver119@gmail.com> * delete temporary buffer Java side Signed-off-by: raver119 <raver119@gmail.com> * delete temporary buffer Java side Signed-off-by: raver119 <raver119@gmail.com>master
parent
ce9c372974
commit
59a006ce29
|
@ -1699,6 +1699,7 @@ public:
|
||||||
nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, double *data, int length);
|
nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, double *data, int length);
|
||||||
nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
|
nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
|
||||||
|
|
||||||
|
void deleteShapeBuffer(Nd4jPointer ptr);
|
||||||
|
|
||||||
const char* runLightBenchmarkSuit(bool printOut);
|
const char* runLightBenchmarkSuit(bool printOut);
|
||||||
const char* runFullBenchmarkSuit(bool printOut);
|
const char* runFullBenchmarkSuit(bool printOut);
|
||||||
|
|
|
@ -2700,6 +2700,11 @@ nd4j::ConstantDataBuffer* NativeOps::shapeBuffer(int rank, Nd4jLong *shape, Nd4j
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NativeOps::deleteShapeBuffer(Nd4jPointer ptr) {
|
||||||
|
auto buffer = reinterpret_cast<nd4j::ConstantDataBuffer*>(ptr);
|
||||||
|
delete buffer;
|
||||||
|
}
|
||||||
|
|
||||||
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -3241,6 +3241,11 @@ nd4j::ConstantDataBuffer* NativeOps::shapeBuffer(int rank, Nd4jLong *shape, Nd4j
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NativeOps::deleteShapeBuffer(Nd4jPointer ptr) {
|
||||||
|
auto buffer = reinterpret_cast<nd4j::ConstantDataBuffer*>(ptr);
|
||||||
|
delete buffer;
|
||||||
|
}
|
||||||
|
|
||||||
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
||||||
return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
|
return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@
|
||||||
#include <ShapeDescriptor.h>
|
#include <ShapeDescriptor.h>
|
||||||
#include <array/ConstantDataBuffer.h>
|
#include <array/ConstantDataBuffer.h>
|
||||||
#include <memory/Workspace.h>
|
#include <memory/Workspace.h>
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
|
||||||
|
@ -64,6 +65,31 @@ namespace nd4j {
|
||||||
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true);
|
Nd4jLong* createFromExisting(Nd4jLong *shapeInfo, bool destroyOriginal = true);
|
||||||
|
|
||||||
bool checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor);
|
bool checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns number of cached TAD shapes/offsets on specific device
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
FORCEINLINE int cachedEntriesForDevice(int deviceId) {
|
||||||
|
if (deviceId > _cache.size())
|
||||||
|
throw std::runtime_error("deviceId > number of actual devices");
|
||||||
|
|
||||||
|
return _cache[deviceId].size();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns total number of cached TAD shapes/offsets on all devices
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
FORCEINLINE int totalCachedEntries() {
|
||||||
|
int total = 0;
|
||||||
|
|
||||||
|
for (int e = 0; e < _cache.size(); e++)
|
||||||
|
total += _cache[e].size();
|
||||||
|
|
||||||
|
return total;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#define DEV_TESTS_CONSTANTTADHELPER_H
|
#define DEV_TESTS_CONSTANTTADHELPER_H
|
||||||
|
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
|
#include <op_boilerplate.h>
|
||||||
#include <pointercast.h>
|
#include <pointercast.h>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -45,11 +46,43 @@ namespace nd4j {
|
||||||
|
|
||||||
static ConstantTadHelper* getInstance();
|
static ConstantTadHelper* getInstance();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* These methods calculate Tensor-Along-Dimension(s) shape and offsets
|
||||||
|
*
|
||||||
|
* @param originalShape
|
||||||
|
* @param dimensions
|
||||||
|
* @param keepUnitiesInShape
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(TadDescriptor &descriptor);
|
TadPack& tadForDimensions(TadDescriptor &descriptor);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns number of cached TAD shapes/offsets on specific device
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
FORCEINLINE int cachedEntriesForDevice(int deviceId) {
|
||||||
|
if (deviceId > _cache.size())
|
||||||
|
throw std::runtime_error("deviceId > number of actual devices");
|
||||||
|
|
||||||
|
return _cache[deviceId].size();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns total number of cached TAD shapes/offsets on all devices
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
FORCEINLINE int totalCachedEntries() {
|
||||||
|
int total = 0;
|
||||||
|
|
||||||
|
for (int e = 0; e < _cache.size(); e++)
|
||||||
|
total += _cache[e].size();
|
||||||
|
|
||||||
|
return total;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,42 @@ public:
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ConstantTadHelperTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ConstantShapeHelperTests, test_cachedAmount_1) {
|
||||||
|
auto ttlBefore = ConstantShapeHelper::getInstance()->totalCachedEntries();
|
||||||
|
|
||||||
|
auto arrayA = NDArrayFactory::create<bool>('c', {7, 11, 17, 23, 31, 43});
|
||||||
|
|
||||||
|
auto ttlMiddle = ConstantShapeHelper::getInstance()->totalCachedEntries();
|
||||||
|
|
||||||
|
auto arrayB = NDArrayFactory::create<bool>('c', {7, 11, 17, 23, 31, 43});
|
||||||
|
|
||||||
|
auto ttlAfter = ConstantShapeHelper::getInstance()->totalCachedEntries();
|
||||||
|
|
||||||
|
ASSERT_TRUE(ttlBefore <= ttlMiddle);
|
||||||
|
ASSERT_EQ(ttlMiddle, ttlAfter);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ConstantTadHelperTests, test_cachedAmount_1) {
|
||||||
|
auto arrayA = NDArrayFactory::create<bool>('c', {7, 11, 17, 23, 31, 43});
|
||||||
|
auto ttlBefore = ConstantTadHelper::getInstance()->totalCachedEntries();
|
||||||
|
|
||||||
|
auto packAA = ConstantTadHelper::getInstance()->tadForDimensions(arrayA.shapeInfo(), {3, 4});
|
||||||
|
|
||||||
|
auto ttlMiddle = ConstantTadHelper::getInstance()->totalCachedEntries();
|
||||||
|
|
||||||
|
auto packAB = ConstantTadHelper::getInstance()->tadForDimensions(arrayA.shapeInfo(), {3, 4});
|
||||||
|
|
||||||
|
auto ttlAfter = ConstantTadHelper::getInstance()->totalCachedEntries();
|
||||||
|
|
||||||
|
ASSERT_TRUE(ttlBefore <= ttlMiddle);
|
||||||
|
ASSERT_EQ(ttlMiddle, ttlAfter);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ConstantShapeHelperTests, basic_test_1) {
|
TEST_F(ConstantShapeHelperTests, basic_test_1) {
|
||||||
auto ptr = ShapeBuilders::createShapeInfo(nd4j::DataType::BFLOAT16, 'f', {5, 10, 15});
|
auto ptr = ShapeBuilders::createShapeInfo(nd4j::DataType::BFLOAT16, 'f', {5, 10, 15});
|
||||||
ShapeDescriptor descriptor(ptr);
|
ShapeDescriptor descriptor(ptr);
|
||||||
|
|
|
@ -1,105 +1,105 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* License for the specific language governing permissions and limitations
|
* License for the specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The CropAndResizeDataSetPreProcessor will crop and resize the processed dataset.
|
* The CropAndResizeDataSetPreProcessor will crop and resize the processed dataset.
|
||||||
* NOTE: The data format must be NHWC
|
* NOTE: The data format must be NHWC
|
||||||
*
|
*
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
public class CropAndResizeDataSetPreProcessor implements DataSetPreProcessor {
|
public class CropAndResizeDataSetPreProcessor implements DataSetPreProcessor {
|
||||||
|
|
||||||
public enum ResizeMethod {
|
public enum ResizeMethod {
|
||||||
Bilinear,
|
Bilinear,
|
||||||
NearestNeighbor
|
NearestNeighbor
|
||||||
}
|
}
|
||||||
|
|
||||||
private final long[] resizedShape;
|
private final long[] resizedShape;
|
||||||
private final INDArray indices;
|
private final INDArray indices;
|
||||||
private final INDArray resize;
|
private final INDArray resize;
|
||||||
private final INDArray boxes;
|
private final INDArray boxes;
|
||||||
private final int method;
|
private final int method;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param originalHeight Height of the input datasets
|
* @param originalHeight Height of the input datasets
|
||||||
* @param originalWidth Width of the input datasets
|
* @param originalWidth Width of the input datasets
|
||||||
* @param cropYStart y coord of the starting point on the input datasets
|
* @param cropYStart y coord of the starting point on the input datasets
|
||||||
* @param cropXStart x coord of the starting point on the input datasets
|
* @param cropXStart x coord of the starting point on the input datasets
|
||||||
* @param resizedHeight Height of the output dataset
|
* @param resizedHeight Height of the output dataset
|
||||||
* @param resizedWidth Width of the output dataset
|
* @param resizedWidth Width of the output dataset
|
||||||
* @param numChannels
|
* @param numChannels
|
||||||
* @param resizeMethod
|
* @param resizeMethod
|
||||||
*/
|
*/
|
||||||
public CropAndResizeDataSetPreProcessor(int originalHeight, int originalWidth, int cropYStart, int cropXStart, int resizedHeight, int resizedWidth, int numChannels, ResizeMethod resizeMethod) {
|
public CropAndResizeDataSetPreProcessor(int originalHeight, int originalWidth, int cropYStart, int cropXStart, int resizedHeight, int resizedWidth, int numChannels, ResizeMethod resizeMethod) {
|
||||||
Preconditions.checkArgument(originalHeight > 0, "originalHeight must be greater than 0, got %s", originalHeight);
|
Preconditions.checkArgument(originalHeight > 0, "originalHeight must be greater than 0, got %s", originalHeight);
|
||||||
Preconditions.checkArgument(originalWidth > 0, "originalWidth must be greater than 0, got %s", originalWidth);
|
Preconditions.checkArgument(originalWidth > 0, "originalWidth must be greater than 0, got %s", originalWidth);
|
||||||
Preconditions.checkArgument(cropYStart >= 0, "cropYStart must be greater or equal to 0, got %s", cropYStart);
|
Preconditions.checkArgument(cropYStart >= 0, "cropYStart must be greater or equal to 0, got %s", cropYStart);
|
||||||
Preconditions.checkArgument(cropXStart >= 0, "cropXStart must be greater or equal to 0, got %s", cropXStart);
|
Preconditions.checkArgument(cropXStart >= 0, "cropXStart must be greater or equal to 0, got %s", cropXStart);
|
||||||
Preconditions.checkArgument(resizedHeight > 0, "resizedHeight must be greater than 0, got %s", resizedHeight);
|
Preconditions.checkArgument(resizedHeight > 0, "resizedHeight must be greater than 0, got %s", resizedHeight);
|
||||||
Preconditions.checkArgument(resizedWidth > 0, "resizedWidth must be greater than 0, got %s", resizedWidth);
|
Preconditions.checkArgument(resizedWidth > 0, "resizedWidth must be greater than 0, got %s", resizedWidth);
|
||||||
Preconditions.checkArgument(numChannels > 0, "numChannels must be greater than 0, got %s", numChannels);
|
Preconditions.checkArgument(numChannels > 0, "numChannels must be greater than 0, got %s", numChannels);
|
||||||
|
|
||||||
resizedShape = new long[] { 1, resizedHeight, resizedWidth, numChannels };
|
resizedShape = new long[] { 1, resizedHeight, resizedWidth, numChannels };
|
||||||
|
|
||||||
boxes = Nd4j.create(new float[] {
|
boxes = Nd4j.create(new float[] {
|
||||||
(float)cropYStart / (float)originalHeight,
|
(float)cropYStart / (float)originalHeight,
|
||||||
(float)cropXStart / (float)originalWidth,
|
(float)cropXStart / (float)originalWidth,
|
||||||
(float)(cropYStart + resizedHeight) / (float)originalHeight,
|
(float)(cropYStart + resizedHeight) / (float)originalHeight,
|
||||||
(float)(cropXStart + resizedWidth) / (float)originalWidth
|
(float)(cropXStart + resizedWidth) / (float)originalWidth
|
||||||
}, new long[] { 1, 4 }, DataType.FLOAT);
|
}, new long[] { 1, 4 }, DataType.FLOAT);
|
||||||
indices = Nd4j.create(new int[] { 0 }, new long[] { 1, 1 }, DataType.INT);
|
indices = Nd4j.create(new int[] { 0 }, new long[] { 1, 1 }, DataType.INT);
|
||||||
|
|
||||||
resize = Nd4j.create(new int[] { resizedHeight, resizedWidth }, new long[] { 1, 2 }, DataType.INT);
|
resize = Nd4j.create(new int[] { resizedHeight, resizedWidth }, new long[] { 1, 2 }, DataType.INT);
|
||||||
method = resizeMethod == ResizeMethod.Bilinear ? 0 : 1;
|
method = resizeMethod == ResizeMethod.Bilinear ? 0 : 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* NOTE: The data format must be NHWC
|
* NOTE: The data format must be NHWC
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void preProcess(DataSet dataSet) {
|
public void preProcess(DataSet dataSet) {
|
||||||
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||||
|
|
||||||
if(dataSet.isEmpty()) {
|
if(dataSet.isEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray input = dataSet.getFeatures();
|
INDArray input = dataSet.getFeatures();
|
||||||
INDArray output = Nd4j.create(LongShapeDescriptor.fromShape(resizedShape, input.dataType()), false);
|
INDArray output = Nd4j.create(LongShapeDescriptor.fromShape(resizedShape, input.dataType()), false);
|
||||||
|
|
||||||
CustomOp op = DynamicCustomOp.builder("crop_and_resize")
|
CustomOp op = DynamicCustomOp.builder("crop_and_resize")
|
||||||
.addInputs(input, boxes, indices, resize)
|
.addInputs(input, boxes, indices, resize)
|
||||||
.addIntegerArguments(method)
|
.addIntegerArguments(method)
|
||||||
.addOutputs(output)
|
.addOutputs(output)
|
||||||
.build();
|
.build();
|
||||||
Nd4j.getExecutioner().exec(op);
|
Nd4j.getExecutioner().exec(op);
|
||||||
|
|
||||||
dataSet.setFeatures(output);
|
dataSet.setFeatures(output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,87 +1,87 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* License for the specific language governing permissions and limitations
|
* License for the specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The PermuteDataSetPreProcessor will rearrange the dimensions.
|
* The PermuteDataSetPreProcessor will rearrange the dimensions.
|
||||||
* There are two pre-defined permutation types:
|
* There are two pre-defined permutation types:
|
||||||
* - from NCHW to NHWC
|
* - from NCHW to NHWC
|
||||||
* - from NHWC to NCHW
|
* - from NHWC to NCHW
|
||||||
*
|
*
|
||||||
* Or, pass the new order to the ctor. For example PermuteDataSetPreProcessor(1, 2, 0) will rearrange the middle dimension first, the last one in the middle and the first one last.
|
* Or, pass the new order to the ctor. For example PermuteDataSetPreProcessor(1, 2, 0) will rearrange the middle dimension first, the last one in the middle and the first one last.
|
||||||
*
|
*
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
public class PermuteDataSetPreProcessor implements DataSetPreProcessor {
|
public class PermuteDataSetPreProcessor implements DataSetPreProcessor {
|
||||||
|
|
||||||
private final PermutationTypes permutationType;
|
private final PermutationTypes permutationType;
|
||||||
private final int[] rearrange;
|
private final int[] rearrange;
|
||||||
|
|
||||||
public enum PermutationTypes { NCHWtoNHWC, NHWCtoNCHW, Custom }
|
public enum PermutationTypes { NCHWtoNHWC, NHWCtoNCHW, Custom }
|
||||||
|
|
||||||
public PermuteDataSetPreProcessor(PermutationTypes permutationType) {
|
public PermuteDataSetPreProcessor(PermutationTypes permutationType) {
|
||||||
Preconditions.checkArgument(permutationType != PermutationTypes.Custom, "Use the ctor PermuteDataSetPreProcessor(int... rearrange) for custom permutations.");
|
Preconditions.checkArgument(permutationType != PermutationTypes.Custom, "Use the ctor PermuteDataSetPreProcessor(int... rearrange) for custom permutations.");
|
||||||
|
|
||||||
this.permutationType = permutationType;
|
this.permutationType = permutationType;
|
||||||
rearrange = null;
|
rearrange = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param rearrange The new order. For example PermuteDataSetPreProcessor(1, 2, 0) will rearrange the middle dimension first, the last one in the middle and the first one last.
|
* @param rearrange The new order. For example PermuteDataSetPreProcessor(1, 2, 0) will rearrange the middle dimension first, the last one in the middle and the first one last.
|
||||||
*/
|
*/
|
||||||
public PermuteDataSetPreProcessor(int... rearrange) {
|
public PermuteDataSetPreProcessor(int... rearrange) {
|
||||||
|
|
||||||
this.permutationType = PermutationTypes.Custom;
|
this.permutationType = PermutationTypes.Custom;
|
||||||
this.rearrange = rearrange;
|
this.rearrange = rearrange;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void preProcess(DataSet dataSet) {
|
public void preProcess(DataSet dataSet) {
|
||||||
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||||
|
|
||||||
if(dataSet.isEmpty()) {
|
if(dataSet.isEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray input = dataSet.getFeatures();
|
INDArray input = dataSet.getFeatures();
|
||||||
INDArray output;
|
INDArray output;
|
||||||
switch (permutationType) {
|
switch (permutationType) {
|
||||||
case NCHWtoNHWC:
|
case NCHWtoNHWC:
|
||||||
output = input.permute(0, 2, 3, 1);
|
output = input.permute(0, 2, 3, 1);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case NHWCtoNCHW:
|
case NHWCtoNCHW:
|
||||||
output = input.permute(0, 3, 1, 2);
|
output = input.permute(0, 3, 1, 2);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case Custom:
|
case Custom:
|
||||||
output = input.permute(rearrange);
|
output = input.permute(rearrange);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
output = input;
|
output = input;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
dataSet.setFeatures(output);
|
dataSet.setFeatures(output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,70 +1,70 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* License for the specific language governing permissions and limitations
|
* License for the specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The RGBtoGrayscaleDataSetPreProcessor will turn a DataSet of a RGB image into a grayscale one.
|
* The RGBtoGrayscaleDataSetPreProcessor will turn a DataSet of a RGB image into a grayscale one.
|
||||||
* NOTE: Expects data format to be NCHW. After processing, the channel dimension is eliminated. (NCHW -> NHW)
|
* NOTE: Expects data format to be NCHW. After processing, the channel dimension is eliminated. (NCHW -> NHW)
|
||||||
*
|
*
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
public class RGBtoGrayscaleDataSetPreProcessor implements DataSetPreProcessor {
|
public class RGBtoGrayscaleDataSetPreProcessor implements DataSetPreProcessor {
|
||||||
|
|
||||||
private static final float RED_RATIO = 0.3f;
|
private static final float RED_RATIO = 0.3f;
|
||||||
private static final float GREEN_RATIO = 0.59f;
|
private static final float GREEN_RATIO = 0.59f;
|
||||||
private static final float BLUE_RATIO = 0.11f;
|
private static final float BLUE_RATIO = 0.11f;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void preProcess(DataSet dataSet) {
|
public void preProcess(DataSet dataSet) {
|
||||||
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||||
|
|
||||||
if(dataSet.isEmpty()) {
|
if(dataSet.isEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray originalFeatures = dataSet.getFeatures();
|
INDArray originalFeatures = dataSet.getFeatures();
|
||||||
long[] originalShape = originalFeatures.shape();
|
long[] originalShape = originalFeatures.shape();
|
||||||
|
|
||||||
// result shape is NHW
|
// result shape is NHW
|
||||||
INDArray result = Nd4j.create(originalShape[0], originalShape[2], originalShape[3]);
|
INDArray result = Nd4j.create(originalShape[0], originalShape[2], originalShape[3]);
|
||||||
|
|
||||||
for(long n = 0, numExamples = originalShape[0]; n < numExamples; ++n) {
|
for(long n = 0, numExamples = originalShape[0]; n < numExamples; ++n) {
|
||||||
// Extract channels
|
// Extract channels
|
||||||
INDArray itemFeatures = originalFeatures.slice(n, 0); // shape is CHW
|
INDArray itemFeatures = originalFeatures.slice(n, 0); // shape is CHW
|
||||||
INDArray R = itemFeatures.slice(0, 0); // shape is HW
|
INDArray R = itemFeatures.slice(0, 0); // shape is HW
|
||||||
INDArray G = itemFeatures.slice(1, 0);
|
INDArray G = itemFeatures.slice(1, 0);
|
||||||
INDArray B = itemFeatures.slice(2, 0);
|
INDArray B = itemFeatures.slice(2, 0);
|
||||||
|
|
||||||
// Convert
|
// Convert
|
||||||
R.muli(RED_RATIO);
|
R.muli(RED_RATIO);
|
||||||
G.muli(GREEN_RATIO);
|
G.muli(GREEN_RATIO);
|
||||||
B.muli(BLUE_RATIO);
|
B.muli(BLUE_RATIO);
|
||||||
R.addi(G).addi(B);
|
R.addi(G).addi(B);
|
||||||
|
|
||||||
// FIXME: int cast
|
// FIXME: int cast
|
||||||
result.putSlice((int)n, R);
|
result.putSlice((int)n, R);
|
||||||
}
|
}
|
||||||
|
|
||||||
dataSet.setFeatures(result);
|
dataSet.setFeatures(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1120,6 +1120,8 @@ public abstract class NativeOps extends Pointer {
|
||||||
// GraphState creation
|
// GraphState creation
|
||||||
public abstract Pointer getGraphState(long id);
|
public abstract Pointer getGraphState(long id);
|
||||||
|
|
||||||
|
public abstract void deleteShapeBuffer(Pointer state);
|
||||||
|
|
||||||
public abstract void deleteGraphState(Pointer state);
|
public abstract void deleteGraphState(Pointer state);
|
||||||
|
|
||||||
public abstract int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold);
|
public abstract int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold);
|
||||||
|
|
|
@ -2586,7 +2586,11 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
||||||
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||||
|
|
||||||
return new CudaLongDataBuffer(dbf.primary(), dbf.special(), Shape.shapeInfoLength(shape.length));
|
val result = new CudaLongDataBuffer(dbf.primary(), dbf.special(), Shape.shapeInfoLength(shape.length));
|
||||||
|
|
||||||
|
nativeOps.deleteShapeBuffer(dbf);
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -3047,6 +3047,7 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
|
||||||
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
||||||
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
||||||
|
|
||||||
|
public native void deleteShapeBuffer(@Cast("Nd4jPointer") Pointer ptr);
|
||||||
|
|
||||||
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
|
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||||
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
|
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||||
|
@ -3698,18 +3699,18 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
|
||||||
private native void allocate(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data);
|
private native void allocate(byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @StdVector double[] data);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this constructor creates new array using given buffer (without memory allocating) and shape information stored in shape
|
* this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape
|
||||||
*/
|
*/
|
||||||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context); }
|
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); }
|
||||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/);
|
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/);
|
||||||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
||||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("nd4j::DataType") int dtype);
|
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongPointer shape, @Cast("nd4j::DataType") int dtype);
|
||||||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context); }
|
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); }
|
||||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/);
|
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/);
|
||||||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
||||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("nd4j::DataType") int dtype);
|
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector LongBuffer shape, @Cast("nd4j::DataType") int dtype);
|
||||||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context); }
|
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/) { super((Pointer)null); allocate(buffer, order, shape, dtype, context, isBuffAlloc); }
|
||||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/);
|
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype, LaunchContext context/*=nd4j::LaunchContext::defaultContext()*/, @Cast("const bool") boolean isBuffAlloc/*=false*/);
|
||||||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
||||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
|
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
|
||||||
|
|
||||||
|
@ -8034,9 +8035,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||||
|
|
||||||
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||||
|
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
|
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||||
|
@ -8946,6 +8948,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
// #endif /* SHAPE_H_ */
|
// #endif /* SHAPE_H_ */
|
||||||
|
|
||||||
|
|
||||||
// Parsed from array/ShapeList.h
|
// Parsed from array/ShapeList.h
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
|
|
|
@ -2162,7 +2162,11 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
||||||
val dbf = (Nd4jCpu.ConstantDataBuffer) loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
val dbf = (Nd4jCpu.ConstantDataBuffer) loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||||
|
|
||||||
return new LongBuffer(dbf.primary(), Shape.shapeInfoLength(shape.length));
|
val result = new LongBuffer(dbf.primary(), Shape.shapeInfoLength(shape.length));
|
||||||
|
|
||||||
|
loop.deleteShapeBuffer(dbf);
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue