createUninitializedDetached refactoring. (#94)

* wip

* update interface, add null implementations.

* Breaking one test in a weird way.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* createUninitializedDetached refactored.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-02 14:13:00 +09:00 committed by AlexDBlack
parent cd41c3540d
commit dfec54242d
6 changed files with 27 additions and 42 deletions

View File

@ -135,7 +135,7 @@ public class OCNNOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.
if(conf.getLastEpochSinceRUpdated() == 0 && epochCount == 0) {
INDArray currentR = doOutput(false,workspaceMgr);
if(window == null) {
window = Nd4j.createUninitializedDetached(conf.getWindowSize()).assign(0.0);
window = Nd4j.createUninitializedDetached(preOut.dataType(), conf.getWindowSize()).assign(0.0);
}
if(batchWindowSizeIndex < window.length() - currentR.length()) {

View File

@ -1096,14 +1096,13 @@ public interface NDArrayFactory {
INDArray createUninitialized(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace);
/**
* Cretes uninitialized INDArray detached from any (if any) workspace
* @param shape
* @param ordering
* @return
* Create an uninitialized ndArray. Detached from workspace.
* @param dataType data type. Exceptions will be thrown for UTF8, COMPRESSED and UNKNOWN data types.
* @param ordering Fortran 'f' or C/C++ 'c' ordering.
* @param shape the shape of the array.
* @return the created detached array.
*/
INDArray createUninitializedDetached(int[] shape, char ordering);
INDArray createUninitializedDetached(long[] shape, char ordering);
INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape);
/**
*

View File

@ -4907,17 +4907,6 @@ public class Nd4j {
return INSTANCE.createUninitialized(shape, ordering);
}
/**
* See {@link #createUninitialized(long[], char)}
*/
public static INDArray createUninitializedDetached(int[] shape, char ordering) {
if (shape.length == 0)
return scalar(dataType(), 0.0);
checkShapeValues(shape);
return INSTANCE.createUninitializedDetached(shape, ordering);
}
/**
* See {@link #createUninitialized(long[])}
*/
@ -4963,13 +4952,24 @@ public class Nd4j {
}
/**
* See {@link #createUninitialized(long)}
* Create an uninitialized ndArray. Detached from workspace.
* @param dataType data type. Exceptions will be thrown for UTF8, COMPRESSED and UNKNOWN data types.
* @param ordering Fortran 'f' or C/C++ 'c' ordering.
* @param shape the shape of the array.
* @return the created detached array.
*/
public static INDArray createUninitializedDetached(int length) {
long[] shape = new long[] {length};
return INSTANCE.createUninitializedDetached(shape, order());
public static INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
return INSTANCE.createUninitializedDetached(dataType, ordering, shape);
}
/**
* See {@link #createUninitializedDetached(DataType, char, long...)} with default ordering.
*/
public static INDArray createUninitializedDetached(DataType dataType, long... shape){
return createUninitializedDetached(dataType, order(), shape);
}
////////////////////// OTHER ///////////////////////////////
/**

View File

@ -199,19 +199,10 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
}
@Override
public INDArray createUninitializedDetached(int[] shape, char ordering) {
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
INDArray ret = new NDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
return ret;
}
@Override
public INDArray createUninitializedDetached(long[] shape, char ordering) {
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
INDArray ret = new NDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
INDArray ret = new NDArray(dataType, shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
return ret;
}

View File

@ -222,7 +222,7 @@ public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
}
@Override
public INDArray createUninitializedDetached(long[] shape, char ordering) {
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
return null;
}
@ -427,11 +427,6 @@ public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
return null;
}
@Override
public INDArray createUninitializedDetached(int[] shape, char ordering) {
return null;
}
@Override
public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) {
return null;

View File

@ -312,7 +312,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
INDArray array1 = Nd4j.create(new double[] {1f, 2f, 3f, 4f, 5f});
INDArray array2 = Nd4j.createUninitializedDetached(5);
INDArray array2 = Nd4j.createUninitializedDetached(DOUBLE, 5);
array2.assign(array1);