Nd4j refactoring (#101)

* cleanup

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>

* wip

Signed-off-by: Robert Altena <Rob@Ra-ai.com>
master
Robert Altena 2019-08-06 16:40:08 +09:00 committed by GitHub
parent b8846113bd
commit a438434b1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 68 additions and 168 deletions

View File

@ -54,7 +54,6 @@ import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans; import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.shape.Diag; import org.nd4j.linalg.api.ops.impl.shape.Diag;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart; import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
@ -90,7 +89,6 @@ import org.nd4j.tools.PropertyParser;
import org.nd4j.versioncheck.VersionCheck; import org.nd4j.versioncheck.VersionCheck;
import java.io.*; import java.io.*;
import java.lang.ref.ReferenceQueue;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -114,29 +112,29 @@ import java.util.logging.Logger;
*/ */
public class Nd4j { public class Nd4j {
public final static String DATA_BUFFER_OPS = "databufferfactory"; private final static String DATA_BUFFER_OPS = "databufferfactory";
public final static String CONVOLUTION_OPS = "convops"; private final static String CONVOLUTION_OPS = "convops";
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/ /**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
@Deprecated @Deprecated
public final static String DTYPE = ND4JSystemProperties.DTYPE; public final static String DTYPE = ND4JSystemProperties.DTYPE;
public final static String BLAS_OPS = "blas.ops"; private final static String BLAS_OPS = "blas.ops";
public final static String SPARSE_BLAS_OPS = "sparseblas.ops"; private final static String SPARSE_BLAS_OPS = "sparseblas.ops";
public final static String NATIVE_OPS = "native.ops"; public final static String NATIVE_OPS = "native.ops";
public final static String ORDER_KEY = "ndarray.order"; private final static String ORDER_KEY = "ndarray.order";
public final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class"; private final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
public final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class"; private final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class";
public final static String OP_EXECUTIONER = "opexec"; private final static String OP_EXECUTIONER = "opexec";
public final static String OP_FACTORY = "opfactory";
public final static String DISTRIBUTION = "dist"; public final static String DISTRIBUTION = "dist";
public final static String SHAPEINFO_PROVIDER = "shapeinfoprovider"; private final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
public final static String SPARSEINFO_PROVIDER = "sparseinfoprovider"; private final static String SPARSEINFO_PROVIDER = "sparseinfoprovider";
public final static String CONSTANT_PROVIDER = "constantsprovider"; private final static String CONSTANT_PROVIDER = "constantsprovider";
public final static String AFFINITY_MANAGER = "affinitymanager"; private final static String AFFINITY_MANAGER = "affinitymanager";
//disable toString() on compressed arrays for debugging. Should be off by default. //disable toString() on compressed arrays for debugging. Should be off by default.
public final static String COMPRESSION_DEBUG = "compressiondebug"; private final static String COMPRESSION_DEBUG = "compressiondebug";
public final static String MEMORY_MANAGER = "memorymanager"; private final static String MEMORY_MANAGER = "memorymanager";
public final static String WORKSPACE_MANAGER = "workspacemanager"; private final static String WORKSPACE_MANAGER = "workspacemanager";
public final static String RANDOM_PROVIDER = "random"; private final static String RANDOM_PROVIDER = "random";
/**@deprecated Use {@link ND4JSystemProperties#LOG_INITIALIZATION}*/ /**@deprecated Use {@link ND4JSystemProperties#LOG_INITIALIZATION}*/
@Deprecated @Deprecated
public static final String LOG_INIT_ENV_PROPERTY = ND4JSystemProperties.LOG_INITIALIZATION; public static final String LOG_INIT_ENV_PROPERTY = ND4JSystemProperties.LOG_INITIALIZATION;
@ -144,8 +142,7 @@ public class Nd4j {
//the datatype used for allocating buffers //the datatype used for allocating buffers
protected static DataType dtype = DataType.FLOAT; protected static DataType dtype = DataType.FLOAT;
//the allocation mode for the heap //the allocation mode for the heap
public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.HEAP; public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.MIXED_DATA_TYPES;
public static char ORDER = 'c';
public static double EPS_THRESHOLD = 1e-5; public static double EPS_THRESHOLD = 1e-5;
private static boolean allowsOrder = false; private static boolean allowsOrder = false;
public static boolean compressDebug = false; public static boolean compressDebug = false;
@ -157,45 +154,27 @@ public class Nd4j {
private static final AtomicInteger numThreads = new AtomicInteger(-1); private static final AtomicInteger numThreads = new AtomicInteger(-1);
private static AtomicReference<DataType> defaultFloatingPointDataType; private static AtomicReference<DataType> defaultFloatingPointDataType;
protected static Class<? extends MemoryWorkspaceManager> workspaceManagerClazz; private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
protected static Class<? extends BlasWrapper> blasWrapperClazz; private static BlasWrapper BLAS_WRAPPER_INSTANCE;
protected static Class<? extends BlasWrapper> sparseBlasWrapperClazz; private static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
protected static Class<? extends NDArrayFactory> ndArrayFactoryClazz;
protected static Class<? extends NDArrayFactory> sparseNDArrayClazz;
protected static Class<? extends ConvolutionInstance> convolutionInstanceClazz;
protected static Class<? extends DataBufferFactory> dataBufferFactoryClazz;
protected static Class<? extends OpExecutioner> opExecutionerClazz;
protected static Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz;
protected static Class<? extends DistributionFactory> distributionFactoryClazz;
protected static Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz;
protected static Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz;
protected static Class<? extends BasicConstantHandler> constantProviderClazz;
protected static Class<? extends BasicAffinityManager> affinityManagerClazz;
protected static Class<? extends BasicMemoryManager> memoryManagerClazz;
protected static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
protected static BlasWrapper BLAS_WRAPPER_INSTANCE;
protected static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
protected static NDArrayFactory INSTANCE; protected static NDArrayFactory INSTANCE;
protected static NDArrayFactory SPARSE_INSTANCE; private static NDArrayFactory SPARSE_INSTANCE;
protected static ConvolutionInstance CONVOLUTION_INSTANCE; private static ConvolutionInstance CONVOLUTION_INSTANCE;
protected static OpExecutioner OP_EXECUTIONER_INSTANCE; private static OpExecutioner OP_EXECUTIONER_INSTANCE;
protected static DistributionFactory DISTRIBUTION_FACTORY; private static DistributionFactory DISTRIBUTION_FACTORY;
protected static ShapeInfoProvider shapeInfoProvider; private static ShapeInfoProvider shapeInfoProvider;
protected static SparseInfoProvider sparseInfoProvider; private static SparseInfoProvider sparseInfoProvider;
protected static ConstantHandler constantHandler; private static ConstantHandler constantHandler;
protected static AffinityManager affinityManager; private static AffinityManager affinityManager;
protected static MemoryManager memoryManager; private static MemoryManager memoryManager;
protected static AtomicBoolean fallbackMode; private static AtomicBoolean fallbackMode;
protected static Properties props = new Properties(); protected static Properties props = new Properties();
protected static ReferenceQueue<INDArray> referenceQueue = new ReferenceQueue<>();
protected static ReferenceQueue<DataBuffer> bufferQueue = new ReferenceQueue<>();
private final static Logger logger = Logger.getLogger(Nd4j.class.getName()); private final static Logger logger = Logger.getLogger(Nd4j.class.getName());
protected static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length]; private static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length];
static { static {
fallbackMode = new AtomicBoolean(false); fallbackMode = new AtomicBoolean(false);
@ -385,7 +364,6 @@ public class Nd4j {
* @param toShuffle the ndarray to shuffle * @param toShuffle the ndarray to shuffle
* @param random the random to use * @param random the random to use
* @param dimension the dimension to do the shuffle * @param dimension the dimension to do the shuffle
* @return
*/ */
public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) { public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) {
INSTANCE.shuffle(toShuffle, random, dimension); INSTANCE.shuffle(toShuffle, random, dimension);
@ -396,10 +374,8 @@ public class Nd4j {
* along a specified set of dimensions * along a specified set of dimensions
* @param toShuffle the ndarray to shuffle * @param toShuffle the ndarray to shuffle
* @param dimension the dimension to do the shuffle * @param dimension the dimension to do the shuffle
* @return
*/ */
public static void shuffle(INDArray toShuffle, @NonNull int... dimension) { public static void shuffle(INDArray toShuffle, @NonNull int... dimension) {
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, new Random(), dimension); INSTANCE.shuffle(toShuffle, new Random(), dimension);
} }
@ -408,10 +384,8 @@ public class Nd4j {
* along a specified set of dimensions * along a specified set of dimensions
* @param toShuffle the ndarray to shuffle * @param toShuffle the ndarray to shuffle
* @param dimension the dimension to do the shuffle * @param dimension the dimension to do the shuffle
* @return
*/ */
public static void shuffle(Collection<INDArray> toShuffle, @NonNull int... dimension) { public static void shuffle(Collection<INDArray> toShuffle, @NonNull int... dimension) {
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, new Random(), dimension); INSTANCE.shuffle(toShuffle, new Random(), dimension);
} }
@ -420,10 +394,8 @@ public class Nd4j {
* along a specified set of dimensions * along a specified set of dimensions
* @param toShuffle the ndarray to shuffle * @param toShuffle the ndarray to shuffle
* @param dimension the dimension to do the shuffle * @param dimension the dimension to do the shuffle
* @return
*/ */
public static void shuffle(Collection<INDArray> toShuffle, Random rnd, @NonNull int... dimension) { public static void shuffle(Collection<INDArray> toShuffle, Random rnd, @NonNull int... dimension) {
//shuffle(toShuffle, new Random(), dimension);
INSTANCE.shuffle(toShuffle, rnd, dimension); INSTANCE.shuffle(toShuffle, rnd, dimension);
} }
@ -433,33 +405,11 @@ public class Nd4j {
* *
* @param toShuffle the ndarray to shuffle * @param toShuffle the ndarray to shuffle
* @param dimensions the dimension to do the shuffle. Please note - order matters here. * @param dimensions the dimension to do the shuffle. Please note - order matters here.
* @return
*/ */
public static void shuffle(List<INDArray> toShuffle, Random rnd, List<int[]> dimensions) { public static void shuffle(List<INDArray> toShuffle, Random rnd, List<int[]> dimensions) {
INSTANCE.shuffle(toShuffle, rnd, dimensions); INSTANCE.shuffle(toShuffle, rnd, dimensions);
} }
/**
* The reference queue used for cleaning up
* ndarrays
*
* @return the reference queue for cleaning up ndarrays
*/
public static ReferenceQueue<INDArray> refQueue() {
return referenceQueue;
}
/**
* The reference queue used for cleaning up
* databuffers
*
* @return the reference queue for cleaning up databuffers
*/
public static ReferenceQueue<DataBuffer> bufferRefQueue() {
return bufferQueue;
}
/** /**
* Get the primary distributions * Get the primary distributions
* factory * factory
@ -480,9 +430,9 @@ public class Nd4j {
} }
/** /**
* This method returns RandomFactory instance * Get the RandomFactory instance
* *
* @return * @return the RandomFactory instance
*/ */
public static RandomFactory getRandomFactory() { public static RandomFactory getRandomFactory() {
return randomFactory; return randomFactory;
@ -500,7 +450,7 @@ public class Nd4j {
/** /**
* Set a convolution instance * Set a convolution instance
* *
* @param convolutionInstance * @param convolutionInstance the new convolution instance
*/ */
public static void setConvolution(ConvolutionInstance convolutionInstance) { public static void setConvolution(ConvolutionInstance convolutionInstance) {
if (convolutionInstance == null) if (convolutionInstance == null)
@ -526,7 +476,6 @@ public class Nd4j {
* slice is the specified shape * slice is the specified shape
*/ */
public static INDArray create(int[] sliceShape, float[]... arrays) { public static INDArray create(int[] sliceShape, float[]... arrays) {
//TODO: Remove duplicate code.
int slices = arrays.length; int slices = arrays.length;
INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape))); INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
for (int i = 0; i < ret.slices(); i++) for (int i = 0; i < ret.slices(); i++)
@ -586,30 +535,6 @@ public class Nd4j {
return DATA_BUFFER_FACTORY_INSTANCE; return DATA_BUFFER_FACTORY_INSTANCE;
} }
/**
* Given a sequence of Iterators over a transform of matrices, fill in all of
* the matrices with the entries in the theta vector. Errors are
* thrown if the theta vector does not exactly fill the matrices.
*
* TODO: unused method.
*/
public static void setParams(INDArray theta, Collection<INDArray>... matrices) {
int index = 0;
for (Collection<INDArray> matrixCollection : matrices) {
for (INDArray matrix : matrixCollection) {
INDArray linear = matrix.reshape(-1);
for (int i = 0; i < matrix.length(); i++) {
linear.putScalar(i, theta.getDouble(index));
index++;
}
}
}
if (index != theta.length()) {
throw new AssertionError("Did not entirely use the theta vector");
}
}
/** /**
* Roll the specified axis backwards, * Roll the specified axis backwards,
* until it lies in a given position. * until it lies in a given position.
@ -681,7 +606,7 @@ public class Nd4j {
* @param b the right tensor * @param b the right tensor
* @param result the result array * @param result the result array
* @param axes the axes for each array to do matrix multiply along * @param axes the axes for each array to do matrix multiply along
* @return * @return the result array
*/ */
public static INDArray tensorMmul(INDArray a, INDArray b,INDArray result, int[][] axes) { public static INDArray tensorMmul(INDArray a, INDArray b,INDArray result, int[][] axes) {
int validationLength = Math.min(axes[0].length, axes[1].length); int validationLength = Math.min(axes[0].length, axes[1].length);
@ -720,15 +645,7 @@ public class Nd4j {
//if listA and listB are empty these donot initialize. //if listA and listB are empty these donot initialize.
//so initializing with {1} which will then get overriden if not empty //so initializing with {1} which will then get overriden if not empty
long[] newShapeA = {-1, n2}; long[] newShapeA = {-1, n2};
//TODO: remove duplicate code. long[] oldShapeA = getOldShape(listA, a);
long[] oldShapeA;
if (listA.size() == 0) {
oldShapeA = new long[] {1};
} else {
oldShapeA = Longs.toArray(listA);
for (int i = 0; i < oldShapeA.length; i++)
oldShapeA[i] = a.size((int) oldShapeA[i]);
}
int n3 = 1; int n3 = 1;
int bNax = Math.min(b.rank(), axes[1].length); int bNax = Math.min(b.rank(), axes[1].length);
@ -737,14 +654,7 @@ public class Nd4j {
} }
long[] newShapeB = {n3, -1}; long[] newShapeB = {n3, -1};
long[] oldShapeB; long[] oldShapeB = getOldShape(listB, b);
if (listB.size() == 0) {
oldShapeB = new long[] {1};
} else {
oldShapeB = Longs.toArray(listB);
for (int i = 0; i < oldShapeB.length; i++)
oldShapeB[i] = b.size((int) oldShapeB[i]);
}
INDArray at = a.permute(newAxesA).reshape(newShapeA); INDArray at = a.permute(newAxesA).reshape(newShapeA);
INDArray bt = b.permute(newAxesB).reshape(newShapeB); INDArray bt = b.permute(newAxesB).reshape(newShapeB);
@ -754,6 +664,19 @@ public class Nd4j {
return ret.reshape(aPlusB); return ret.reshape(aPlusB);
} }
// Some duplicate code that refactored out:
private static long[] getOldShape(List<Integer> list, INDArray x){
long[] res;
if (list.size() == 0) {
res = new long[] {1};
} else {
res= Longs.toArray(list);
for (int i = 0; i < res.length; i++)
res[i] = x.size((int) res[i]);
}
return res;
}
/** /**
* Tensor matrix multiplication. * Tensor matrix multiplication.
* Both tensors must be the same rank * Both tensors must be the same rank
@ -761,7 +684,7 @@ public class Nd4j {
* @param a the left tensor * @param a the left tensor
* @param b the right tensor * @param b the right tensor
* @param axes the axes for each array to do matrix multiply along * @param axes the axes for each array to do matrix multiply along
* @return * @return the multiplication result.
*/ */
public static INDArray tensorMmul(INDArray a, INDArray b, int[][] axes) { public static INDArray tensorMmul(INDArray a, INDArray b, int[][] axes) {
CustomOp op = DynamicCustomOp.builder("tensordot") CustomOp op = DynamicCustomOp.builder("tensordot")
@ -892,29 +815,6 @@ public class Nd4j {
return matmul(a,b, null); return matmul(a,b, null);
} }
/**
* Given a sequence of Iterators over a transform of matrices, fill in all of
* the matrices with the entries in the theta vector. Errors are
* thrown if the theta vector does not exactly fill the matrices.
* TODO: unused method.
*/
public static void setParams(INDArray theta, Iterator<? extends INDArray>... matrices) {
int index = 0;
for (Iterator<? extends INDArray> matrixIterator : matrices) {
while (matrixIterator.hasNext()) {
INDArray matrix = matrixIterator.next().reshape(-1);
for (int i = 0; i < matrix.length(); i++) {
matrix.putScalar(i, theta.getDouble(index));
index++;
}
}
}
if (index != theta.length()) {
throw new AssertionError("Did not entirely use the theta vector");
}
}
/** /**
* The factory used for creating ndarrays * The factory used for creating ndarrays
* *
@ -5766,45 +5666,45 @@ public class Nd4j {
} }
compressDebug = pp.toBoolean(COMPRESSION_DEBUG); compressDebug = pp.toBoolean(COMPRESSION_DEBUG);
ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C); char ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C);
affinityManagerClazz = (Class<? extends BasicAffinityManager>) Class Class<? extends BasicAffinityManager> affinityManagerClazz = (Class<? extends BasicAffinityManager>) Class
.forName(pp.toString(AFFINITY_MANAGER)); .forName(pp.toString(AFFINITY_MANAGER));
affinityManager = affinityManagerClazz.newInstance(); affinityManager = affinityManagerClazz.newInstance();
ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName( Class<? extends NDArrayFactory> ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
pp.toString(NDARRAY_FACTORY_CLASS)); pp.toString(NDARRAY_FACTORY_CLASS));
sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName( Class<? extends NDArrayFactory> sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName(
pp.toString(SPARSE_NDARRAY_FACTORY_CLASS)); pp.toString(SPARSE_NDARRAY_FACTORY_CLASS));
convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName())); .forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName()); String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
.forName(pp.toString(DATA_BUFFER_OPS, defaultName)); .forName(pp.toString(DATA_BUFFER_OPS, defaultName));
shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
.forName(pp.toString(SHAPEINFO_PROVIDER)); .forName(pp.toString(SHAPEINFO_PROVIDER));
sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName( Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName(
pp.toString(SPARSEINFO_PROVIDER)); pp.toString(SPARSEINFO_PROVIDER));
constantProviderClazz = (Class<? extends BasicConstantHandler>) Class Class<? extends BasicConstantHandler> constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
.forName(pp.toString(CONSTANT_PROVIDER)); .forName(pp.toString(CONSTANT_PROVIDER));
memoryManagerClazz = (Class<? extends BasicMemoryManager>) Class Class<? extends BasicMemoryManager> memoryManagerClazz = (Class<? extends BasicMemoryManager>) Class
.forName(pp.toString(MEMORY_MANAGER)); .forName(pp.toString(MEMORY_MANAGER));
allowsOrder = backend.allowsOrder(); allowsOrder = backend.allowsOrder();
String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName()); String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName());
randomClazz = (Class<? extends org.nd4j.linalg.api.rng.Random>) Class.forName(rand); Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz = (Class<? extends org.nd4j.linalg.api.rng.Random>) Class.forName(rand);
randomFactory = new RandomFactory(randomClazz); randomFactory = new RandomFactory(randomClazz);
workspaceManagerClazz = (Class<? extends MemoryWorkspaceManager>) Class Class<? extends MemoryWorkspaceManager> workspaceManagerClazz = (Class<? extends MemoryWorkspaceManager>) Class
.forName(pp.toString(WORKSPACE_MANAGER)); .forName(pp.toString(WORKSPACE_MANAGER));
blasWrapperClazz = (Class<? extends BlasWrapper>) Class Class<? extends BlasWrapper> blasWrapperClazz = (Class<? extends BlasWrapper>) Class
.forName(pp.toString(BLAS_OPS)); .forName(pp.toString(BLAS_OPS));
sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class Class<? extends BlasWrapper> sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class
.forName(pp.toString(SPARSE_BLAS_OPS)); .forName(pp.toString(SPARSE_BLAS_OPS));
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName()); String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName); Class<? extends DistributionFactory> distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
memoryManager = memoryManagerClazz.newInstance(); memoryManager = memoryManagerClazz.newInstance();
@ -5813,7 +5713,7 @@ public class Nd4j {
sparseInfoProvider = sparseInfoProviderClazz.newInstance(); sparseInfoProvider = sparseInfoProviderClazz.newInstance();
workspaceManager = workspaceManagerClazz.newInstance(); workspaceManager = workspaceManagerClazz.newInstance();
opExecutionerClazz = (Class<? extends OpExecutioner>) Class Class<? extends OpExecutioner> opExecutionerClazz = (Class<? extends OpExecutioner>) Class
.forName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName())); .forName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance(); OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();