createUninitializedDetached refactoring. (#94)

* wip

* update interface, add null implementations.

* Breaking one test in a weird way.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* createUninitializedDetached refactored.

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-02 14:13:00 +09:00 committed by AlexDBlack
parent cd41c3540d
commit dfec54242d
6 changed files with 27 additions and 42 deletions

View File

@ -135,7 +135,7 @@ public class OCNNOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.
if(conf.getLastEpochSinceRUpdated() == 0 && epochCount == 0) { if(conf.getLastEpochSinceRUpdated() == 0 && epochCount == 0) {
INDArray currentR = doOutput(false,workspaceMgr); INDArray currentR = doOutput(false,workspaceMgr);
if(window == null) { if(window == null) {
window = Nd4j.createUninitializedDetached(conf.getWindowSize()).assign(0.0); window = Nd4j.createUninitializedDetached(preOut.dataType(), conf.getWindowSize()).assign(0.0);
} }
if(batchWindowSizeIndex < window.length() - currentR.length()) { if(batchWindowSizeIndex < window.length() - currentR.length()) {

View File

@ -1096,14 +1096,13 @@ public interface NDArrayFactory {
INDArray createUninitialized(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace); INDArray createUninitialized(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace);
/** /**
* Cretes uninitialized INDArray detached from any (if any) workspace * Create an uninitialized ndArray. Detached from workspace.
* @param shape * @param dataType data type. Exceptions will be thrown for UTF8, COMPRESSED and UNKNOWN data types.
* @param ordering * @param ordering Fortran 'f' or C/C++ 'c' ordering.
* @return * @param shape the shape of the array.
* @return the created detached array.
*/ */
INDArray createUninitializedDetached(int[] shape, char ordering); INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape);
INDArray createUninitializedDetached(long[] shape, char ordering);
/** /**
* *

View File

@ -4907,17 +4907,6 @@ public class Nd4j {
return INSTANCE.createUninitialized(shape, ordering); return INSTANCE.createUninitialized(shape, ordering);
} }
/**
* See {@link #createUninitialized(long[], char)}
*/
public static INDArray createUninitializedDetached(int[] shape, char ordering) {
if (shape.length == 0)
return scalar(dataType(), 0.0);
checkShapeValues(shape);
return INSTANCE.createUninitializedDetached(shape, ordering);
}
/** /**
* See {@link #createUninitialized(long[])} * See {@link #createUninitialized(long[])}
*/ */
@ -4963,13 +4952,24 @@ public class Nd4j {
} }
/** /**
* See {@link #createUninitialized(long)} * Create an uninitialized ndArray. Detached from workspace.
* @param dataType data type. Exceptions will be thrown for UTF8, COMPRESSED and UNKNOWN data types.
* @param ordering Fortran 'f' or C/C++ 'c' ordering.
* @param shape the shape of the array.
* @return the created detached array.
*/ */
public static INDArray createUninitializedDetached(int length) { public static INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
long[] shape = new long[] {length}; return INSTANCE.createUninitializedDetached(dataType, ordering, shape);
return INSTANCE.createUninitializedDetached(shape, order());
} }
/**
* See {@link #createUninitializedDetached(DataType, char, long...)} with default ordering.
*/
public static INDArray createUninitializedDetached(DataType dataType, long... shape){
return createUninitializedDetached(dataType, order(), shape);
}
////////////////////// OTHER /////////////////////////////// ////////////////////// OTHER ///////////////////////////////
/** /**

View File

@ -199,19 +199,10 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
} }
@Override @Override
public INDArray createUninitializedDetached(int[] shape, char ordering) { public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null); Nd4j.getMemoryManager().setCurrentWorkspace(null);
INDArray ret = new NDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false); INDArray ret = new NDArray(dataType, shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
return ret;
}
@Override
public INDArray createUninitializedDetached(long[] shape, char ordering) {
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
INDArray ret = new NDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
Nd4j.getMemoryManager().setCurrentWorkspace(workspace); Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
return ret; return ret;
} }

View File

@ -222,7 +222,7 @@ public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
} }
@Override @Override
public INDArray createUninitializedDetached(long[] shape, char ordering) { public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
return null; return null;
} }
@ -427,11 +427,6 @@ public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
return null; return null;
} }
@Override
public INDArray createUninitializedDetached(int[] shape, char ordering) {
return null;
}
@Override @Override
public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) { public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) {
return null; return null;

View File

@ -312,7 +312,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
INDArray array1 = Nd4j.create(new double[] {1f, 2f, 3f, 4f, 5f}); INDArray array1 = Nd4j.create(new double[] {1f, 2f, 3f, 4f, 5f});
INDArray array2 = Nd4j.createUninitializedDetached(5); INDArray array2 = Nd4j.createUninitializedDetached(DOUBLE, 5);
array2.assign(array1); array2.assign(array1);