Nd4j refactoring (#101)

* cleanup

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-06 16:40:08 +09:00 committed by GitHub
parent b8846113bd
commit a438434b1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 68 additions and 168 deletions

View File

@ -54,7 +54,6 @@ import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.shape.Diag;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
@ -90,7 +89,6 @@ import org.nd4j.tools.PropertyParser;
import org.nd4j.versioncheck.VersionCheck;
import java.io.*;
import java.lang.ref.ReferenceQueue;
import java.lang.reflect.Constructor;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
@ -114,29 +112,29 @@ import java.util.logging.Logger;
*/
public class Nd4j {
public final static String DATA_BUFFER_OPS = "databufferfactory";
public final static String CONVOLUTION_OPS = "convops";
private final static String DATA_BUFFER_OPS = "databufferfactory";
private final static String CONVOLUTION_OPS = "convops";
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
@Deprecated
public final static String DTYPE = ND4JSystemProperties.DTYPE;
public final static String BLAS_OPS = "blas.ops";
public final static String SPARSE_BLAS_OPS = "sparseblas.ops";
private final static String BLAS_OPS = "blas.ops";
private final static String SPARSE_BLAS_OPS = "sparseblas.ops";
public final static String NATIVE_OPS = "native.ops";
public final static String ORDER_KEY = "ndarray.order";
public final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
public final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class";
public final static String OP_EXECUTIONER = "opexec";
public final static String OP_FACTORY = "opfactory";
private final static String ORDER_KEY = "ndarray.order";
private final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
private final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class";
private final static String OP_EXECUTIONER = "opexec";
public final static String DISTRIBUTION = "dist";
public final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
public final static String SPARSEINFO_PROVIDER = "sparseinfoprovider";
public final static String CONSTANT_PROVIDER = "constantsprovider";
public final static String AFFINITY_MANAGER = "affinitymanager";
private final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
private final static String SPARSEINFO_PROVIDER = "sparseinfoprovider";
private final static String CONSTANT_PROVIDER = "constantsprovider";
private final static String AFFINITY_MANAGER = "affinitymanager";
//disable toString() on compressed arrays for debugging. Should be off by default.
public final static String COMPRESSION_DEBUG = "compressiondebug";
public final static String MEMORY_MANAGER = "memorymanager";
public final static String WORKSPACE_MANAGER = "workspacemanager";
public final static String RANDOM_PROVIDER = "random";
private final static String COMPRESSION_DEBUG = "compressiondebug";
private final static String MEMORY_MANAGER = "memorymanager";
private final static String WORKSPACE_MANAGER = "workspacemanager";
private final static String RANDOM_PROVIDER = "random";
/**@deprecated Use {@link ND4JSystemProperties#LOG_INITIALIZATION}*/
@Deprecated
public static final String LOG_INIT_ENV_PROPERTY = ND4JSystemProperties.LOG_INITIALIZATION;
@ -144,8 +142,7 @@ public class Nd4j {
//the datatype used for allocating buffers
protected static DataType dtype = DataType.FLOAT;
//the allocation mode for the heap
public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.HEAP;
public static char ORDER = 'c';
public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.MIXED_DATA_TYPES;
public static double EPS_THRESHOLD = 1e-5;
private static boolean allowsOrder = false;
public static boolean compressDebug = false;
@ -157,45 +154,27 @@ public class Nd4j {
private static final AtomicInteger numThreads = new AtomicInteger(-1);
private static AtomicReference<DataType> defaultFloatingPointDataType;
protected static Class<? extends MemoryWorkspaceManager> workspaceManagerClazz;
protected static Class<? extends BlasWrapper> blasWrapperClazz;
protected static Class<? extends BlasWrapper> sparseBlasWrapperClazz;
protected static Class<? extends NDArrayFactory> ndArrayFactoryClazz;
protected static Class<? extends NDArrayFactory> sparseNDArrayClazz;
protected static Class<? extends ConvolutionInstance> convolutionInstanceClazz;
protected static Class<? extends DataBufferFactory> dataBufferFactoryClazz;
protected static Class<? extends OpExecutioner> opExecutionerClazz;
protected static Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz;
protected static Class<? extends DistributionFactory> distributionFactoryClazz;
protected static Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz;
protected static Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz;
protected static Class<? extends BasicConstantHandler> constantProviderClazz;
protected static Class<? extends BasicAffinityManager> affinityManagerClazz;
protected static Class<? extends BasicMemoryManager> memoryManagerClazz;
protected static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
protected static BlasWrapper BLAS_WRAPPER_INSTANCE;
protected static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
private static BlasWrapper BLAS_WRAPPER_INSTANCE;
private static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
protected static NDArrayFactory INSTANCE;
protected static NDArrayFactory SPARSE_INSTANCE;
protected static ConvolutionInstance CONVOLUTION_INSTANCE;
protected static OpExecutioner OP_EXECUTIONER_INSTANCE;
protected static DistributionFactory DISTRIBUTION_FACTORY;
protected static ShapeInfoProvider shapeInfoProvider;
protected static SparseInfoProvider sparseInfoProvider;
protected static ConstantHandler constantHandler;
protected static AffinityManager affinityManager;
protected static MemoryManager memoryManager;
private static NDArrayFactory SPARSE_INSTANCE;
private static ConvolutionInstance CONVOLUTION_INSTANCE;
private static OpExecutioner OP_EXECUTIONER_INSTANCE;
private static DistributionFactory DISTRIBUTION_FACTORY;
private static ShapeInfoProvider shapeInfoProvider;
private static SparseInfoProvider sparseInfoProvider;
private static ConstantHandler constantHandler;
private static AffinityManager affinityManager;
private static MemoryManager memoryManager;
protected static AtomicBoolean fallbackMode;
private static AtomicBoolean fallbackMode;
protected static Properties props = new Properties();
protected static ReferenceQueue<INDArray> referenceQueue = new ReferenceQueue<>();
protected static ReferenceQueue<DataBuffer> bufferQueue = new ReferenceQueue<>();
private final static Logger logger = Logger.getLogger(Nd4j.class.getName());
protected static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length];
private static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length];
static {
fallbackMode = new AtomicBoolean(false);
@ -385,7 +364,6 @@ public class Nd4j {
* @param toShuffle the ndarray to shuffle
* @param random the random to use
* @param dimension the dimension to do the shuffle
* @return
*/
public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) {
INSTANCE.shuffle(toShuffle, random, dimension);
@ -396,10 +374,8 @@ public class Nd4j {
* along a specified set of dimensions
* @param toShuffle the ndarray to shuffle
* @param dimension the dimension to do the shuffle
* @return
*/
public static void shuffle(INDArray toShuffle, @NonNull int... dimension) {
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, new Random(), dimension);
}
@ -408,10 +384,8 @@ public class Nd4j {
* along a specified set of dimensions
* @param toShuffle the ndarray to shuffle
* @param dimension the dimension to do the shuffle
* @return
*/
public static void shuffle(Collection<INDArray> toShuffle, @NonNull int... dimension) {
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, new Random(), dimension);
}
@ -420,10 +394,8 @@ public class Nd4j {
* along a specified set of dimensions
* @param toShuffle the ndarray to shuffle
* @param dimension the dimension to do the shuffle
* @return
*/
public static void shuffle(Collection<INDArray> toShuffle, Random rnd, @NonNull int... dimension) {
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, rnd, dimension);
}
@ -433,33 +405,11 @@ public class Nd4j {
*
* @param toShuffle the ndarray to shuffle
* @param dimensions the dimension to do the shuffle. Please note - order matters here.
* @return
*/
public static void shuffle(List<INDArray> toShuffle, Random rnd, List<int[]> dimensions) {
INSTANCE.shuffle(toShuffle, rnd, dimensions);
}
/**
* The reference queue used for cleaning up
* ndarrays
*
* @return the reference queue for cleaning up ndarrays
*/
public static ReferenceQueue<INDArray> refQueue() {
return referenceQueue;
}
/**
* The reference queue used for cleaning up
* databuffers
*
* @return the reference queue for cleaning up databuffers
*/
public static ReferenceQueue<DataBuffer> bufferRefQueue() {
return bufferQueue;
}
/**
* Get the primary distributions
* factory
@ -480,9 +430,9 @@ public class Nd4j {
}
/**
* This method returns RandomFactory instance
* Get the RandomFactory instance
*
* @return
* @return the RandomFactory instance
*/
public static RandomFactory getRandomFactory() {
return randomFactory;
@ -500,7 +450,7 @@ public class Nd4j {
/**
* Set a convolution instance
*
* @param convolutionInstance
* @param convolutionInstance the new convolution instance
*/
public static void setConvolution(ConvolutionInstance convolutionInstance) {
if (convolutionInstance == null)
@ -526,7 +476,6 @@ public class Nd4j {
* slice is the specified shape
*/
public static INDArray create(int[] sliceShape, float[]... arrays) {
//TODO: Remove duplicate code.
int slices = arrays.length;
INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
for (int i = 0; i < ret.slices(); i++)
@ -586,30 +535,6 @@ public class Nd4j {
return DATA_BUFFER_FACTORY_INSTANCE;
}
/**
* Given a sequence of Iterators over a transform of matrices, fill in all of
* the matrices with the entries in the theta vector. Errors are
* thrown if the theta vector does not exactly fill the matrices.
*
* TODO: unused method.
*/
public static void setParams(INDArray theta, Collection<INDArray>... matrices) {
int index = 0;
for (Collection<INDArray> matrixCollection : matrices) {
for (INDArray matrix : matrixCollection) {
INDArray linear = matrix.reshape(-1);
for (int i = 0; i < matrix.length(); i++) {
linear.putScalar(i, theta.getDouble(index));
index++;
}
}
}
if (index != theta.length()) {
throw new AssertionError("Did not entirely use the theta vector");
}
}
/**
* Roll the specified axis backwards,
* until it lies in a given position.
@ -681,7 +606,7 @@ public class Nd4j {
* @param b the right tensor
* @param result the result array
* @param axes the axes for each array to do matrix multiply along
* @return
* @return the result array
*/
public static INDArray tensorMmul(INDArray a, INDArray b,INDArray result, int[][] axes) {
int validationLength = Math.min(axes[0].length, axes[1].length);
@ -720,15 +645,7 @@ public class Nd4j {
//if listA and listB are empty these do not initialize.
//so initializing with {1} which will then get overridden if not empty
long[] newShapeA = {-1, n2};
//TODO: remove duplicate code.
long[] oldShapeA;
if (listA.size() == 0) {
oldShapeA = new long[] {1};
} else {
oldShapeA = Longs.toArray(listA);
for (int i = 0; i < oldShapeA.length; i++)
oldShapeA[i] = a.size((int) oldShapeA[i]);
}
long[] oldShapeA = getOldShape(listA, a);
int n3 = 1;
int bNax = Math.min(b.rank(), axes[1].length);
@ -737,14 +654,7 @@ public class Nd4j {
}
long[] newShapeB = {n3, -1};
long[] oldShapeB;
if (listB.size() == 0) {
oldShapeB = new long[] {1};
} else {
oldShapeB = Longs.toArray(listB);
for (int i = 0; i < oldShapeB.length; i++)
oldShapeB[i] = b.size((int) oldShapeB[i]);
}
long[] oldShapeB = getOldShape(listB, b);
INDArray at = a.permute(newAxesA).reshape(newShapeA);
INDArray bt = b.permute(newAxesB).reshape(newShapeB);
@ -754,6 +664,19 @@ public class Nd4j {
return ret.reshape(aPlusB);
}
// Some duplicate code that was refactored out:
/**
 * Resolves the original shape of {@code x} along the given dimension indices.
 * Shared helper for the tensorMmul shape computation (used for both operands).
 *
 * @param list the dimension indices to look up in {@code x}; may be empty
 * @param x    the array whose sizes are queried
 * @return the sizes of {@code x} at each index in {@code list}, or {@code {1}}
 *         when {@code list} is empty (placeholder so reshape still works)
 */
private static long[] getOldShape(List<Integer> list, INDArray x) {
    if (list.isEmpty()) {
        // No dimensions requested: return a length-1 placeholder shape.
        return new long[] {1};
    }
    long[] res = Longs.toArray(list);
    // Replace each dimension index with the actual size of x along that dimension.
    for (int i = 0; i < res.length; i++)
        res[i] = x.size((int) res[i]);
    return res;
}
/**
* Tensor matrix multiplication.
* Both tensors must be the same rank
@ -761,7 +684,7 @@ public class Nd4j {
* @param a the left tensor
* @param b the right tensor
* @param axes the axes for each array to do matrix multiply along
* @return
* @return the multiplication result.
*/
public static INDArray tensorMmul(INDArray a, INDArray b, int[][] axes) {
CustomOp op = DynamicCustomOp.builder("tensordot")
@ -892,29 +815,6 @@ public class Nd4j {
return matmul(a,b, null);
}
/**
* Given a sequence of Iterators over a transform of matrices, fill in all of
* the matrices with the entries in the theta vector. Errors are
* thrown if the theta vector does not exactly fill the matrices.
* TODO: unused method.
*/
public static void setParams(INDArray theta, Iterator<? extends INDArray>... matrices) {
int index = 0;
for (Iterator<? extends INDArray> matrixIterator : matrices) {
while (matrixIterator.hasNext()) {
INDArray matrix = matrixIterator.next().reshape(-1);
for (int i = 0; i < matrix.length(); i++) {
matrix.putScalar(i, theta.getDouble(index));
index++;
}
}
}
if (index != theta.length()) {
throw new AssertionError("Did not entirely use the theta vector");
}
}
/**
* The factory used for creating ndarrays
*
@ -5766,45 +5666,45 @@ public class Nd4j {
}
compressDebug = pp.toBoolean(COMPRESSION_DEBUG);
ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C);
char ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C);
affinityManagerClazz = (Class<? extends BasicAffinityManager>) Class
Class<? extends BasicAffinityManager> affinityManagerClazz = (Class<? extends BasicAffinityManager>) Class
.forName(pp.toString(AFFINITY_MANAGER));
affinityManager = affinityManagerClazz.newInstance();
ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
Class<? extends NDArrayFactory> ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
pp.toString(NDARRAY_FACTORY_CLASS));
sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName(
Class<? extends NDArrayFactory> sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName(
pp.toString(SPARSE_NDARRAY_FACTORY_CLASS));
convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
.forName(pp.toString(SHAPEINFO_PROVIDER));
sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName(
Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName(
pp.toString(SPARSEINFO_PROVIDER));
constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
Class<? extends BasicConstantHandler> constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
.forName(pp.toString(CONSTANT_PROVIDER));
memoryManagerClazz = (Class<? extends BasicMemoryManager>) Class
Class<? extends BasicMemoryManager> memoryManagerClazz = (Class<? extends BasicMemoryManager>) Class
.forName(pp.toString(MEMORY_MANAGER));
allowsOrder = backend.allowsOrder();
String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName());
randomClazz = (Class<? extends org.nd4j.linalg.api.rng.Random>) Class.forName(rand);
Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz = (Class<? extends org.nd4j.linalg.api.rng.Random>) Class.forName(rand);
randomFactory = new RandomFactory(randomClazz);
workspaceManagerClazz = (Class<? extends MemoryWorkspaceManager>) Class
Class<? extends MemoryWorkspaceManager> workspaceManagerClazz = (Class<? extends MemoryWorkspaceManager>) Class
.forName(pp.toString(WORKSPACE_MANAGER));
blasWrapperClazz = (Class<? extends BlasWrapper>) Class
Class<? extends BlasWrapper> blasWrapperClazz = (Class<? extends BlasWrapper>) Class
.forName(pp.toString(BLAS_OPS));
sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class
Class<? extends BlasWrapper> sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class
.forName(pp.toString(SPARSE_BLAS_OPS));
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
Class<? extends DistributionFactory> distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
memoryManager = memoryManagerClazz.newInstance();
@ -5813,7 +5713,7 @@ public class Nd4j {
sparseInfoProvider = sparseInfoProviderClazz.newInstance();
workspaceManager = workspaceManagerClazz.newInstance();
opExecutionerClazz = (Class<? extends OpExecutioner>) Class
Class<? extends OpExecutioner> opExecutionerClazz = (Class<? extends OpExecutioner>) Class
.forName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();