createUninitializedDetached refactoring. (#94)
* wip * update interface, add null implementations. * Breaking one test in a weird way. Signed-off-by: Robert Altena <Rob@Ra-ai.com> * createUninitializedDetached refactored. Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
cd41c3540d
commit
dfec54242d
|
@ -135,7 +135,7 @@ public class OCNNOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.
|
|||
if(conf.getLastEpochSinceRUpdated() == 0 && epochCount == 0) {
|
||||
INDArray currentR = doOutput(false,workspaceMgr);
|
||||
if(window == null) {
|
||||
window = Nd4j.createUninitializedDetached(conf.getWindowSize()).assign(0.0);
|
||||
window = Nd4j.createUninitializedDetached(preOut.dataType(), conf.getWindowSize()).assign(0.0);
|
||||
}
|
||||
|
||||
if(batchWindowSizeIndex < window.length() - currentR.length()) {
|
||||
|
|
|
@ -1096,14 +1096,13 @@ public interface NDArrayFactory {
|
|||
INDArray createUninitialized(DataType dataType, long[] shape, char ordering, MemoryWorkspace workspace);
|
||||
|
||||
/**
|
||||
* Cretes uninitialized INDArray detached from any (if any) workspace
|
||||
* @param shape
|
||||
* @param ordering
|
||||
* @return
|
||||
* Create an uninitialized ndArray. Detached from workspace.
|
||||
* @param dataType data type. Exceptions will be thrown for UTF8, COMPRESSED and UNKNOWN data types.
|
||||
* @param ordering Fortran 'f' or C/C++ 'c' ordering.
|
||||
* @param shape the shape of the array.
|
||||
* @return the created detached array.
|
||||
*/
|
||||
INDArray createUninitializedDetached(int[] shape, char ordering);
|
||||
|
||||
INDArray createUninitializedDetached(long[] shape, char ordering);
|
||||
INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape);
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
|
@ -4907,17 +4907,6 @@ public class Nd4j {
|
|||
return INSTANCE.createUninitialized(shape, ordering);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createUninitialized(long[], char)}
|
||||
*/
|
||||
public static INDArray createUninitializedDetached(int[] shape, char ordering) {
|
||||
if (shape.length == 0)
|
||||
return scalar(dataType(), 0.0);
|
||||
|
||||
checkShapeValues(shape);
|
||||
return INSTANCE.createUninitializedDetached(shape, ordering);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createUninitialized(long[])}
|
||||
*/
|
||||
|
@ -4963,13 +4952,24 @@ public class Nd4j {
|
|||
}
|
||||
|
||||
/**
|
||||
* See {@link #createUninitialized(long)}
|
||||
* Create an uninitialized ndArray. Detached from workspace.
|
||||
* @param dataType data type. Exceptions will be thrown for UTF8, COMPRESSED and UNKNOWN data types.
|
||||
* @param ordering Fortran 'f' or C/C++ 'c' ordering.
|
||||
* @param shape the shape of the array.
|
||||
* @return the created detached array.
|
||||
*/
|
||||
public static INDArray createUninitializedDetached(int length) {
|
||||
long[] shape = new long[] {length};
|
||||
return INSTANCE.createUninitializedDetached(shape, order());
|
||||
public static INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
|
||||
return INSTANCE.createUninitializedDetached(dataType, ordering, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #createUninitializedDetached(DataType, char, long...)} with default ordering.
|
||||
*/
|
||||
public static INDArray createUninitializedDetached(DataType dataType, long... shape){
|
||||
return createUninitializedDetached(dataType, order(), shape);
|
||||
}
|
||||
|
||||
|
||||
////////////////////// OTHER ///////////////////////////////
|
||||
|
||||
/**
|
||||
|
|
|
@ -199,19 +199,10 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
|
|||
}
|
||||
|
||||
@Override
|
||||
public INDArray createUninitializedDetached(int[] shape, char ordering) {
|
||||
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
|
||||
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(null);
|
||||
INDArray ret = new NDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray createUninitializedDetached(long[] shape, char ordering) {
|
||||
MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(null);
|
||||
INDArray ret = new NDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
|
||||
INDArray ret = new NDArray(dataType, shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -222,7 +222,7 @@ public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
|
|||
}
|
||||
|
||||
@Override
|
||||
public INDArray createUninitializedDetached(long[] shape, char ordering) {
|
||||
public INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){
|
||||
return null;
|
||||
}
|
||||
|
||||
|
@ -427,11 +427,6 @@ public class CpuSparseNDArrayFactory extends BaseSparseNDArrayFactory {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray createUninitializedDetached(int[] shape, char ordering) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray create(DataBuffer data, int[] newShape, int[] newStride, long offset, char ordering) {
|
||||
return null;
|
||||
|
|
|
@ -312,7 +312,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
|
|||
|
||||
INDArray array1 = Nd4j.create(new double[] {1f, 2f, 3f, 4f, 5f});
|
||||
|
||||
INDArray array2 = Nd4j.createUninitializedDetached(5);
|
||||
INDArray array2 = Nd4j.createUninitializedDetached(DOUBLE, 5);
|
||||
|
||||
array2.assign(array1);
|
||||
|
||||
|
|
Loading…
Reference in New Issue