From a438434b1f939a11f947d7726157afb7a7a1ffa3 Mon Sep 17 00:00:00 2001 From: Robert Altena Date: Tue, 6 Aug 2019 16:40:08 +0900 Subject: [PATCH] Nd4j refactoring (#101) * cleanup Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena * wip Signed-off-by: Robert Altena --- .../java/org/nd4j/linalg/factory/Nd4j.java | 236 +++++------------- 1 file changed, 68 insertions(+), 168 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 5c8275ef3..07aa9cdc3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -54,7 +54,6 @@ import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans; -import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.shape.Diag; import org.nd4j.linalg.api.ops.impl.shape.DiagPart; @@ -90,7 +89,6 @@ import org.nd4j.tools.PropertyParser; import org.nd4j.versioncheck.VersionCheck; import java.io.*; -import java.lang.ref.ReferenceQueue; import java.lang.reflect.Constructor; import java.math.BigDecimal; import java.nio.ByteBuffer; @@ -114,29 +112,29 @@ import java.util.logging.Logger; */ public class Nd4j { - public final static String DATA_BUFFER_OPS = "databufferfactory"; - public final static String CONVOLUTION_OPS = "convops"; + private final static String DATA_BUFFER_OPS = "databufferfactory"; + private final static String CONVOLUTION_OPS = "convops"; /**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/ @Deprecated public final static String DTYPE = ND4JSystemProperties.DTYPE; - public final static String BLAS_OPS = "blas.ops"; - public final static String SPARSE_BLAS_OPS = "sparseblas.ops"; + private final static String BLAS_OPS = "blas.ops"; + private final static String SPARSE_BLAS_OPS = "sparseblas.ops"; public final static String NATIVE_OPS = "native.ops"; - public final static String ORDER_KEY = "ndarray.order"; - public final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class"; - public final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class"; - public final static String OP_EXECUTIONER = "opexec"; - public final static String OP_FACTORY = "opfactory"; + private final static String ORDER_KEY = "ndarray.order"; + private final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class"; + private final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class"; + private final static String OP_EXECUTIONER = "opexec"; + public final static String DISTRIBUTION = "dist"; - public final static String SHAPEINFO_PROVIDER = "shapeinfoprovider"; - public final static String SPARSEINFO_PROVIDER = "sparseinfoprovider"; - public final static String CONSTANT_PROVIDER = "constantsprovider"; - public final static String AFFINITY_MANAGER = "affinitymanager"; + private final static String SHAPEINFO_PROVIDER = "shapeinfoprovider"; + private final static String SPARSEINFO_PROVIDER = "sparseinfoprovider"; + private final static String CONSTANT_PROVIDER = "constantsprovider"; + private final static String AFFINITY_MANAGER = "affinitymanager"; //disable toString() on compressed arrays for debugging. Should be off by default. - public final static String COMPRESSION_DEBUG = "compressiondebug"; - public final static String MEMORY_MANAGER = "memorymanager"; - public final static String WORKSPACE_MANAGER = "workspacemanager"; - public final static String RANDOM_PROVIDER = "random"; + private final static String COMPRESSION_DEBUG = "compressiondebug"; + private final static String MEMORY_MANAGER = "memorymanager"; + private final static String WORKSPACE_MANAGER = "workspacemanager"; + private final static String RANDOM_PROVIDER = "random"; /**@deprecated Use {@link ND4JSystemProperties#LOG_INITIALIZATION}*/ @Deprecated public static final String LOG_INIT_ENV_PROPERTY = ND4JSystemProperties.LOG_INITIALIZATION; @@ -144,8 +142,7 @@ public class Nd4j { //the datatype used for allocating buffers protected static DataType dtype = DataType.FLOAT; //the allocation mode for the heap - public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.HEAP; - public static char ORDER = 'c'; + public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.MIXED_DATA_TYPES; public static double EPS_THRESHOLD = 1e-5; private static boolean allowsOrder = false; public static boolean compressDebug = false; @@ -157,45 +154,27 @@ public class Nd4j { private static final AtomicInteger numThreads = new AtomicInteger(-1); private static AtomicReference defaultFloatingPointDataType; - protected static Class workspaceManagerClazz; - protected static Class blasWrapperClazz; - protected static Class sparseBlasWrapperClazz; - protected static Class ndArrayFactoryClazz; - protected static Class sparseNDArrayClazz; - protected static Class convolutionInstanceClazz; - protected static Class dataBufferFactoryClazz; - protected static Class opExecutionerClazz; - protected static Class randomClazz; - protected static Class distributionFactoryClazz; - protected static Class shapeInfoProviderClazz; - protected static Class sparseInfoProviderClazz; - protected static Class constantProviderClazz; - protected static Class affinityManagerClazz; - protected static Class memoryManagerClazz; - - protected static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE; - protected static BlasWrapper BLAS_WRAPPER_INSTANCE; - protected static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE; + private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE; + private static BlasWrapper BLAS_WRAPPER_INSTANCE; + private static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE; protected static NDArrayFactory INSTANCE; - protected static NDArrayFactory SPARSE_INSTANCE; - protected static ConvolutionInstance CONVOLUTION_INSTANCE; - protected static OpExecutioner OP_EXECUTIONER_INSTANCE; - protected static DistributionFactory DISTRIBUTION_FACTORY; - protected static ShapeInfoProvider shapeInfoProvider; - protected static SparseInfoProvider sparseInfoProvider; - protected static ConstantHandler constantHandler; - protected static AffinityManager affinityManager; - protected static MemoryManager memoryManager; + private static NDArrayFactory SPARSE_INSTANCE; + private static ConvolutionInstance CONVOLUTION_INSTANCE; + private static OpExecutioner OP_EXECUTIONER_INSTANCE; + private static DistributionFactory DISTRIBUTION_FACTORY; + private static ShapeInfoProvider shapeInfoProvider; + private static SparseInfoProvider sparseInfoProvider; + private static ConstantHandler constantHandler; + private static AffinityManager affinityManager; + private static MemoryManager memoryManager; - protected static AtomicBoolean fallbackMode; + private static AtomicBoolean fallbackMode; protected static Properties props = new Properties(); - protected static ReferenceQueue referenceQueue = new ReferenceQueue<>(); - protected static ReferenceQueue bufferQueue = new ReferenceQueue<>(); private final static Logger logger = Logger.getLogger(Nd4j.class.getName()); - protected static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length]; + private static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length]; static { fallbackMode = new AtomicBoolean(false); @@ -385,7 +364,6 @@ public class Nd4j { * @param toShuffle the ndarray to shuffle * @param random the random to use * @param dimension the dimension to do the shuffle - * @return */ public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) { INSTANCE.shuffle(toShuffle, random, dimension); @@ -396,10 +374,8 @@ public class Nd4j { * along a specified set of dimensions * @param toShuffle the ndarray to shuffle * @param dimension the dimension to do the shuffle - * @return */ public static void shuffle(INDArray toShuffle, @NonNull int... dimension) { - //shuffle(toShuffle, new Random(), dimension); INSTANCE.shuffle(toShuffle, new Random(), dimension); } @@ -408,10 +384,8 @@ public class Nd4j { * along a specified set of dimensions * @param toShuffle the ndarray to shuffle * @param dimension the dimension to do the shuffle - * @return */ public static void shuffle(Collection toShuffle, @NonNull int... dimension) { - //shuffle(toShuffle, new Random(), dimension); INSTANCE.shuffle(toShuffle, new Random(), dimension); } @@ -420,10 +394,8 @@ public class Nd4j { * along a specified set of dimensions * @param toShuffle the ndarray to shuffle * @param dimension the dimension to do the shuffle - * @return */ public static void shuffle(Collection toShuffle, Random rnd, @NonNull int... dimension) { - //shuffle(toShuffle, new Random(), dimension); INSTANCE.shuffle(toShuffle, rnd, dimension); } @@ -433,33 +405,11 @@ public class Nd4j { * * @param toShuffle the ndarray to shuffle * @param dimensions the dimension to do the shuffle. Please note - order matters here. - * @return */ public static void shuffle(List toShuffle, Random rnd, List dimensions) { - INSTANCE.shuffle(toShuffle, rnd, dimensions); } - /** - * The reference queue used for cleaning up - * ndarrays - * - * @return the reference queue for cleaning up ndarrays - */ - public static ReferenceQueue refQueue() { - return referenceQueue; - } - - /** - * The reference queue used for cleaning up - * databuffers - * - * @return the reference queue for cleaning up databuffers - */ - public static ReferenceQueue bufferRefQueue() { - return bufferQueue; - } - /** * Get the primary distributions * factory @@ -480,9 +430,9 @@ public class Nd4j { } /** - * This method returns RandomFactory instance + * Get the RandomFactory instance * - * @return + * @return the RandomFactory instance */ public static RandomFactory getRandomFactory() { return randomFactory; @@ -500,7 +450,7 @@ public class Nd4j { /** * Set a convolution instance * - * @param convolutionInstance + * @param convolutionInstance the new convolution instance */ public static void setConvolution(ConvolutionInstance convolutionInstance) { if (convolutionInstance == null) @@ -526,7 +476,6 @@ public class Nd4j { * slice is the specified shape */ public static INDArray create(int[] sliceShape, float[]... arrays) { - //TODO: Remove duplicate code. int slices = arrays.length; INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape))); for (int i = 0; i < ret.slices(); i++) @@ -586,30 +535,6 @@ public class Nd4j { return DATA_BUFFER_FACTORY_INSTANCE; } - /** - * Given a sequence of Iterators over a transform of matrices, fill in all of - * the matrices with the entries in the theta vector. Errors are - * thrown if the theta vector does not exactly fill the matrices. - * - * TODO: unused method. - */ - public static void setParams(INDArray theta, Collection... matrices) { - int index = 0; - for (Collection matrixCollection : matrices) { - for (INDArray matrix : matrixCollection) { - INDArray linear = matrix.reshape(-1); - for (int i = 0; i < matrix.length(); i++) { - linear.putScalar(i, theta.getDouble(index)); - index++; - } - } - } - - if (index != theta.length()) { - throw new AssertionError("Did not entirely use the theta vector"); - } - } - /** * Roll the specified axis backwards, * until it lies in a given position. @@ -681,7 +606,7 @@ public class Nd4j { * @param b the right tensor * @param result the result array * @param axes the axes for each array to do matrix multiply along - * @return + * @return the result array */ public static INDArray tensorMmul(INDArray a, INDArray b,INDArray result, int[][] axes) { int validationLength = Math.min(axes[0].length, axes[1].length); @@ -720,15 +645,7 @@ public class Nd4j { //if listA and listB are empty these donot initialize. //so initializing with {1} which will then get overriden if not empty long[] newShapeA = {-1, n2}; - //TODO: remove duplicate code. - long[] oldShapeA; - if (listA.size() == 0) { - oldShapeA = new long[] {1}; - } else { - oldShapeA = Longs.toArray(listA); - for (int i = 0; i < oldShapeA.length; i++) - oldShapeA[i] = a.size((int) oldShapeA[i]); - } + long[] oldShapeA = getOldShape(listA, a); int n3 = 1; int bNax = Math.min(b.rank(), axes[1].length); @@ -737,14 +654,7 @@ public class Nd4j { } long[] newShapeB = {n3, -1}; - long[] oldShapeB; - if (listB.size() == 0) { - oldShapeB = new long[] {1}; - } else { - oldShapeB = Longs.toArray(listB); - for (int i = 0; i < oldShapeB.length; i++) - oldShapeB[i] = b.size((int) oldShapeB[i]); - } + long[] oldShapeB = getOldShape(listB, b); INDArray at = a.permute(newAxesA).reshape(newShapeA); INDArray bt = b.permute(newAxesB).reshape(newShapeB); @@ -754,6 +664,19 @@ public class Nd4j { return ret.reshape(aPlusB); } + // Some duplicate code that refactored out: + private static long[] getOldShape(List list, INDArray x){ + long[] res; + if (list.size() == 0) { + res = new long[] {1}; + } else { + res= Longs.toArray(list); + for (int i = 0; i < res.length; i++) + res[i] = x.size((int) res[i]); + } + return res; + } + /** * Tensor matrix multiplication. * Both tensors must be the same rank @@ -761,7 +684,7 @@ public class Nd4j { * @param a the left tensor * @param b the right tensor * @param axes the axes for each array to do matrix multiply along - * @return + * @return the multiplication result. */ public static INDArray tensorMmul(INDArray a, INDArray b, int[][] axes) { CustomOp op = DynamicCustomOp.builder("tensordot") @@ -892,29 +815,6 @@ public class Nd4j { return matmul(a,b, null); } - /** - * Given a sequence of Iterators over a transform of matrices, fill in all of - * the matrices with the entries in the theta vector. Errors are - * thrown if the theta vector does not exactly fill the matrices. - * TODO: unused method. - */ - public static void setParams(INDArray theta, Iterator... matrices) { - int index = 0; - for (Iterator matrixIterator : matrices) { - while (matrixIterator.hasNext()) { - INDArray matrix = matrixIterator.next().reshape(-1); - for (int i = 0; i < matrix.length(); i++) { - matrix.putScalar(i, theta.getDouble(index)); - index++; - } - } - } - - if (index != theta.length()) { - throw new AssertionError("Did not entirely use the theta vector"); - } - } - /** * The factory used for creating ndarrays * @@ -5766,45 +5666,45 @@ public class Nd4j { } compressDebug = pp.toBoolean(COMPRESSION_DEBUG); - ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C); + char ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C); - affinityManagerClazz = (Class) Class + Class affinityManagerClazz = (Class) Class .forName(pp.toString(AFFINITY_MANAGER)); affinityManager = affinityManagerClazz.newInstance(); - ndArrayFactoryClazz = (Class) Class.forName( + Class ndArrayFactoryClazz = (Class) Class.forName( pp.toString(NDARRAY_FACTORY_CLASS)); - sparseNDArrayClazz = (Class) Class.forName( + Class sparseNDArrayClazz = (Class) Class.forName( pp.toString(SPARSE_NDARRAY_FACTORY_CLASS)); - convolutionInstanceClazz = (Class) Class + Class convolutionInstanceClazz = (Class) Class .forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName())); String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName()); - dataBufferFactoryClazz = (Class) Class + Class dataBufferFactoryClazz = (Class) Class .forName(pp.toString(DATA_BUFFER_OPS, defaultName)); - shapeInfoProviderClazz = (Class) Class + Class shapeInfoProviderClazz = (Class) Class .forName(pp.toString(SHAPEINFO_PROVIDER)); - sparseInfoProviderClazz = (Class) Class.forName( + Class sparseInfoProviderClazz = (Class) Class.forName( pp.toString(SPARSEINFO_PROVIDER)); - constantProviderClazz = (Class) Class + Class constantProviderClazz = (Class) Class .forName(pp.toString(CONSTANT_PROVIDER)); - memoryManagerClazz = (Class) Class + Class memoryManagerClazz = (Class) Class .forName(pp.toString(MEMORY_MANAGER)); allowsOrder = backend.allowsOrder(); String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName()); - randomClazz = (Class) Class.forName(rand); + Class randomClazz = (Class) Class.forName(rand); randomFactory = new RandomFactory(randomClazz); - workspaceManagerClazz = (Class) Class + Class workspaceManagerClazz = (Class) Class .forName(pp.toString(WORKSPACE_MANAGER)); - blasWrapperClazz = (Class) Class + Class blasWrapperClazz = (Class) Class .forName(pp.toString(BLAS_OPS)); - sparseBlasWrapperClazz = (Class) Class + Class sparseBlasWrapperClazz = (Class) Class .forName(pp.toString(SPARSE_BLAS_OPS)); String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName()); - distributionFactoryClazz = (Class) Class.forName(clazzName); + Class distributionFactoryClazz = (Class) Class.forName(clazzName); memoryManager = memoryManagerClazz.newInstance(); @@ -5813,7 +5713,7 @@ public class Nd4j { sparseInfoProvider = sparseInfoProviderClazz.newInstance(); workspaceManager = workspaceManagerClazz.newInstance(); - opExecutionerClazz = (Class) Class + Class opExecutionerClazz = (Class) Class .forName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName())); OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();