Nd4j refactoring (#101)
* cleanup Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
b8846113bd
commit
a438434b1f
|
@ -54,7 +54,6 @@ import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
|||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul;
|
||||
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Diag;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
|
||||
|
@ -90,7 +89,6 @@ import org.nd4j.tools.PropertyParser;
|
|||
import org.nd4j.versioncheck.VersionCheck;
|
||||
|
||||
import java.io.*;
|
||||
import java.lang.ref.ReferenceQueue;
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -114,29 +112,29 @@ import java.util.logging.Logger;
|
|||
*/
|
||||
public class Nd4j {
|
||||
|
||||
public final static String DATA_BUFFER_OPS = "databufferfactory";
|
||||
public final static String CONVOLUTION_OPS = "convops";
|
||||
private final static String DATA_BUFFER_OPS = "databufferfactory";
|
||||
private final static String CONVOLUTION_OPS = "convops";
|
||||
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
|
||||
@Deprecated
|
||||
public final static String DTYPE = ND4JSystemProperties.DTYPE;
|
||||
public final static String BLAS_OPS = "blas.ops";
|
||||
public final static String SPARSE_BLAS_OPS = "sparseblas.ops";
|
||||
private final static String BLAS_OPS = "blas.ops";
|
||||
private final static String SPARSE_BLAS_OPS = "sparseblas.ops";
|
||||
public final static String NATIVE_OPS = "native.ops";
|
||||
public final static String ORDER_KEY = "ndarray.order";
|
||||
public final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
|
||||
public final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class";
|
||||
public final static String OP_EXECUTIONER = "opexec";
|
||||
public final static String OP_FACTORY = "opfactory";
|
||||
private final static String ORDER_KEY = "ndarray.order";
|
||||
private final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
|
||||
private final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class";
|
||||
private final static String OP_EXECUTIONER = "opexec";
|
||||
|
||||
public final static String DISTRIBUTION = "dist";
|
||||
public final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
|
||||
public final static String SPARSEINFO_PROVIDER = "sparseinfoprovider";
|
||||
public final static String CONSTANT_PROVIDER = "constantsprovider";
|
||||
public final static String AFFINITY_MANAGER = "affinitymanager";
|
||||
private final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
|
||||
private final static String SPARSEINFO_PROVIDER = "sparseinfoprovider";
|
||||
private final static String CONSTANT_PROVIDER = "constantsprovider";
|
||||
private final static String AFFINITY_MANAGER = "affinitymanager";
|
||||
//disable toString() on compressed arrays for debugging. Should be off by default.
|
||||
public final static String COMPRESSION_DEBUG = "compressiondebug";
|
||||
public final static String MEMORY_MANAGER = "memorymanager";
|
||||
public final static String WORKSPACE_MANAGER = "workspacemanager";
|
||||
public final static String RANDOM_PROVIDER = "random";
|
||||
private final static String COMPRESSION_DEBUG = "compressiondebug";
|
||||
private final static String MEMORY_MANAGER = "memorymanager";
|
||||
private final static String WORKSPACE_MANAGER = "workspacemanager";
|
||||
private final static String RANDOM_PROVIDER = "random";
|
||||
/**@deprecated Use {@link ND4JSystemProperties#LOG_INITIALIZATION}*/
|
||||
@Deprecated
|
||||
public static final String LOG_INIT_ENV_PROPERTY = ND4JSystemProperties.LOG_INITIALIZATION;
|
||||
|
@ -144,8 +142,7 @@ public class Nd4j {
|
|||
//the datatype used for allocating buffers
|
||||
protected static DataType dtype = DataType.FLOAT;
|
||||
//the allocation mode for the heap
|
||||
public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.HEAP;
|
||||
public static char ORDER = 'c';
|
||||
public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.MIXED_DATA_TYPES;
|
||||
public static double EPS_THRESHOLD = 1e-5;
|
||||
private static boolean allowsOrder = false;
|
||||
public static boolean compressDebug = false;
|
||||
|
@ -157,45 +154,27 @@ public class Nd4j {
|
|||
private static final AtomicInteger numThreads = new AtomicInteger(-1);
|
||||
private static AtomicReference<DataType> defaultFloatingPointDataType;
|
||||
|
||||
protected static Class<? extends MemoryWorkspaceManager> workspaceManagerClazz;
|
||||
protected static Class<? extends BlasWrapper> blasWrapperClazz;
|
||||
protected static Class<? extends BlasWrapper> sparseBlasWrapperClazz;
|
||||
protected static Class<? extends NDArrayFactory> ndArrayFactoryClazz;
|
||||
protected static Class<? extends NDArrayFactory> sparseNDArrayClazz;
|
||||
protected static Class<? extends ConvolutionInstance> convolutionInstanceClazz;
|
||||
protected static Class<? extends DataBufferFactory> dataBufferFactoryClazz;
|
||||
protected static Class<? extends OpExecutioner> opExecutionerClazz;
|
||||
protected static Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz;
|
||||
protected static Class<? extends DistributionFactory> distributionFactoryClazz;
|
||||
protected static Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz;
|
||||
protected static Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz;
|
||||
protected static Class<? extends BasicConstantHandler> constantProviderClazz;
|
||||
protected static Class<? extends BasicAffinityManager> affinityManagerClazz;
|
||||
protected static Class<? extends BasicMemoryManager> memoryManagerClazz;
|
||||
|
||||
protected static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
|
||||
protected static BlasWrapper BLAS_WRAPPER_INSTANCE;
|
||||
protected static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
|
||||
private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
|
||||
private static BlasWrapper BLAS_WRAPPER_INSTANCE;
|
||||
private static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
|
||||
protected static NDArrayFactory INSTANCE;
|
||||
protected static NDArrayFactory SPARSE_INSTANCE;
|
||||
protected static ConvolutionInstance CONVOLUTION_INSTANCE;
|
||||
protected static OpExecutioner OP_EXECUTIONER_INSTANCE;
|
||||
protected static DistributionFactory DISTRIBUTION_FACTORY;
|
||||
protected static ShapeInfoProvider shapeInfoProvider;
|
||||
protected static SparseInfoProvider sparseInfoProvider;
|
||||
protected static ConstantHandler constantHandler;
|
||||
protected static AffinityManager affinityManager;
|
||||
protected static MemoryManager memoryManager;
|
||||
private static NDArrayFactory SPARSE_INSTANCE;
|
||||
private static ConvolutionInstance CONVOLUTION_INSTANCE;
|
||||
private static OpExecutioner OP_EXECUTIONER_INSTANCE;
|
||||
private static DistributionFactory DISTRIBUTION_FACTORY;
|
||||
private static ShapeInfoProvider shapeInfoProvider;
|
||||
private static SparseInfoProvider sparseInfoProvider;
|
||||
private static ConstantHandler constantHandler;
|
||||
private static AffinityManager affinityManager;
|
||||
private static MemoryManager memoryManager;
|
||||
|
||||
protected static AtomicBoolean fallbackMode;
|
||||
private static AtomicBoolean fallbackMode;
|
||||
|
||||
protected static Properties props = new Properties();
|
||||
protected static ReferenceQueue<INDArray> referenceQueue = new ReferenceQueue<>();
|
||||
protected static ReferenceQueue<DataBuffer> bufferQueue = new ReferenceQueue<>();
|
||||
|
||||
private final static Logger logger = Logger.getLogger(Nd4j.class.getName());
|
||||
|
||||
protected static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length];
|
||||
private static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length];
|
||||
|
||||
static {
|
||||
fallbackMode = new AtomicBoolean(false);
|
||||
|
@ -385,7 +364,6 @@ public class Nd4j {
|
|||
* @param toShuffle the ndarray to shuffle
|
||||
* @param random the random to use
|
||||
* @param dimension the dimension to do the shuffle
|
||||
* @return
|
||||
*/
|
||||
public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) {
|
||||
INSTANCE.shuffle(toShuffle, random, dimension);
|
||||
|
@ -396,10 +374,8 @@ public class Nd4j {
|
|||
* along a specified set of dimensions
|
||||
* @param toShuffle the ndarray to shuffle
|
||||
* @param dimension the dimension to do the shuffle
|
||||
* @return
|
||||
*/
|
||||
public static void shuffle(INDArray toShuffle, @NonNull int... dimension) {
|
||||
//shuffle(toShuffle, new Random(), dimension);
|
||||
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
||||
}
|
||||
|
||||
|
@ -408,10 +384,8 @@ public class Nd4j {
|
|||
* along a specified set of dimensions
|
||||
* @param toShuffle the ndarray to shuffle
|
||||
* @param dimension the dimension to do the shuffle
|
||||
* @return
|
||||
*/
|
||||
public static void shuffle(Collection<INDArray> toShuffle, @NonNull int... dimension) {
|
||||
//shuffle(toShuffle, new Random(), dimension);
|
||||
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
||||
}
|
||||
|
||||
|
@ -420,10 +394,8 @@ public class Nd4j {
|
|||
* along a specified set of dimensions
|
||||
* @param toShuffle the ndarray to shuffle
|
||||
* @param dimension the dimension to do the shuffle
|
||||
* @return
|
||||
*/
|
||||
public static void shuffle(Collection<INDArray> toShuffle, Random rnd, @NonNull int... dimension) {
|
||||
//shuffle(toShuffle, new Random(), dimension);
|
||||
INSTANCE.shuffle(toShuffle, rnd, dimension);
|
||||
}
|
||||
|
||||
|
@ -433,33 +405,11 @@ public class Nd4j {
|
|||
*
|
||||
* @param toShuffle the ndarray to shuffle
|
||||
* @param dimensions the dimension to do the shuffle. Please note - order matters here.
|
||||
* @return
|
||||
*/
|
||||
public static void shuffle(List<INDArray> toShuffle, Random rnd, List<int[]> dimensions) {
|
||||
|
||||
INSTANCE.shuffle(toShuffle, rnd, dimensions);
|
||||
}
|
||||
|
||||
/**
|
||||
* The reference queue used for cleaning up
|
||||
* ndarrays
|
||||
*
|
||||
* @return the reference queue for cleaning up ndarrays
|
||||
*/
|
||||
public static ReferenceQueue<INDArray> refQueue() {
|
||||
return referenceQueue;
|
||||
}
|
||||
|
||||
/**
|
||||
* The reference queue used for cleaning up
|
||||
* databuffers
|
||||
*
|
||||
* @return the reference queue for cleaning up databuffers
|
||||
*/
|
||||
public static ReferenceQueue<DataBuffer> bufferRefQueue() {
|
||||
return bufferQueue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the primary distributions
|
||||
* factory
|
||||
|
@ -480,9 +430,9 @@ public class Nd4j {
|
|||
}
|
||||
|
||||
/**
|
||||
* This method returns RandomFactory instance
|
||||
* Get the RandomFactory instance
|
||||
*
|
||||
* @return
|
||||
* @return the RandomFactory instance
|
||||
*/
|
||||
public static RandomFactory getRandomFactory() {
|
||||
return randomFactory;
|
||||
|
@ -500,7 +450,7 @@ public class Nd4j {
|
|||
/**
|
||||
* Set a convolution instance
|
||||
*
|
||||
* @param convolutionInstance
|
||||
* @param convolutionInstance the new convolution instance
|
||||
*/
|
||||
public static void setConvolution(ConvolutionInstance convolutionInstance) {
|
||||
if (convolutionInstance == null)
|
||||
|
@ -526,7 +476,6 @@ public class Nd4j {
|
|||
* slice is the specified shape
|
||||
*/
|
||||
public static INDArray create(int[] sliceShape, float[]... arrays) {
|
||||
//TODO: Remove duplicate code.
|
||||
int slices = arrays.length;
|
||||
INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
|
||||
for (int i = 0; i < ret.slices(); i++)
|
||||
|
@ -586,30 +535,6 @@ public class Nd4j {
|
|||
return DATA_BUFFER_FACTORY_INSTANCE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a sequence of Iterators over a transform of matrices, fill in all of
|
||||
* the matrices with the entries in the theta vector. Errors are
|
||||
* thrown if the theta vector does not exactly fill the matrices.
|
||||
*
|
||||
* TODO: unused method.
|
||||
*/
|
||||
public static void setParams(INDArray theta, Collection<INDArray>... matrices) {
|
||||
int index = 0;
|
||||
for (Collection<INDArray> matrixCollection : matrices) {
|
||||
for (INDArray matrix : matrixCollection) {
|
||||
INDArray linear = matrix.reshape(-1);
|
||||
for (int i = 0; i < matrix.length(); i++) {
|
||||
linear.putScalar(i, theta.getDouble(index));
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (index != theta.length()) {
|
||||
throw new AssertionError("Did not entirely use the theta vector");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Roll the specified axis backwards,
|
||||
* until it lies in a given position.
|
||||
|
@ -681,7 +606,7 @@ public class Nd4j {
|
|||
* @param b the right tensor
|
||||
* @param result the result array
|
||||
* @param axes the axes for each array to do matrix multiply along
|
||||
* @return
|
||||
* @return the result array
|
||||
*/
|
||||
public static INDArray tensorMmul(INDArray a, INDArray b,INDArray result, int[][] axes) {
|
||||
int validationLength = Math.min(axes[0].length, axes[1].length);
|
||||
|
@ -720,15 +645,7 @@ public class Nd4j {
|
|||
//if listA and listB are empty these donot initialize.
|
||||
//so initializing with {1} which will then get overriden if not empty
|
||||
long[] newShapeA = {-1, n2};
|
||||
//TODO: remove duplicate code.
|
||||
long[] oldShapeA;
|
||||
if (listA.size() == 0) {
|
||||
oldShapeA = new long[] {1};
|
||||
} else {
|
||||
oldShapeA = Longs.toArray(listA);
|
||||
for (int i = 0; i < oldShapeA.length; i++)
|
||||
oldShapeA[i] = a.size((int) oldShapeA[i]);
|
||||
}
|
||||
long[] oldShapeA = getOldShape(listA, a);
|
||||
|
||||
int n3 = 1;
|
||||
int bNax = Math.min(b.rank(), axes[1].length);
|
||||
|
@ -737,14 +654,7 @@ public class Nd4j {
|
|||
}
|
||||
|
||||
long[] newShapeB = {n3, -1};
|
||||
long[] oldShapeB;
|
||||
if (listB.size() == 0) {
|
||||
oldShapeB = new long[] {1};
|
||||
} else {
|
||||
oldShapeB = Longs.toArray(listB);
|
||||
for (int i = 0; i < oldShapeB.length; i++)
|
||||
oldShapeB[i] = b.size((int) oldShapeB[i]);
|
||||
}
|
||||
long[] oldShapeB = getOldShape(listB, b);
|
||||
|
||||
INDArray at = a.permute(newAxesA).reshape(newShapeA);
|
||||
INDArray bt = b.permute(newAxesB).reshape(newShapeB);
|
||||
|
@ -754,6 +664,19 @@ public class Nd4j {
|
|||
return ret.reshape(aPlusB);
|
||||
}
|
||||
|
||||
// Some duplicate code that refactored out:
|
||||
private static long[] getOldShape(List<Integer> list, INDArray x){
|
||||
long[] res;
|
||||
if (list.size() == 0) {
|
||||
res = new long[] {1};
|
||||
} else {
|
||||
res= Longs.toArray(list);
|
||||
for (int i = 0; i < res.length; i++)
|
||||
res[i] = x.size((int) res[i]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Tensor matrix multiplication.
|
||||
* Both tensors must be the same rank
|
||||
|
@ -761,7 +684,7 @@ public class Nd4j {
|
|||
* @param a the left tensor
|
||||
* @param b the right tensor
|
||||
* @param axes the axes for each array to do matrix multiply along
|
||||
* @return
|
||||
* @return the multiplication result.
|
||||
*/
|
||||
public static INDArray tensorMmul(INDArray a, INDArray b, int[][] axes) {
|
||||
CustomOp op = DynamicCustomOp.builder("tensordot")
|
||||
|
@ -892,29 +815,6 @@ public class Nd4j {
|
|||
return matmul(a,b, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a sequence of Iterators over a transform of matrices, fill in all of
|
||||
* the matrices with the entries in the theta vector. Errors are
|
||||
* thrown if the theta vector does not exactly fill the matrices.
|
||||
* TODO: unused method.
|
||||
*/
|
||||
public static void setParams(INDArray theta, Iterator<? extends INDArray>... matrices) {
|
||||
int index = 0;
|
||||
for (Iterator<? extends INDArray> matrixIterator : matrices) {
|
||||
while (matrixIterator.hasNext()) {
|
||||
INDArray matrix = matrixIterator.next().reshape(-1);
|
||||
for (int i = 0; i < matrix.length(); i++) {
|
||||
matrix.putScalar(i, theta.getDouble(index));
|
||||
index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (index != theta.length()) {
|
||||
throw new AssertionError("Did not entirely use the theta vector");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The factory used for creating ndarrays
|
||||
*
|
||||
|
@ -5766,45 +5666,45 @@ public class Nd4j {
|
|||
}
|
||||
|
||||
compressDebug = pp.toBoolean(COMPRESSION_DEBUG);
|
||||
ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C);
|
||||
char ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C);
|
||||
|
||||
affinityManagerClazz = (Class<? extends BasicAffinityManager>) Class
|
||||
Class<? extends BasicAffinityManager> affinityManagerClazz = (Class<? extends BasicAffinityManager>) Class
|
||||
.forName(pp.toString(AFFINITY_MANAGER));
|
||||
affinityManager = affinityManagerClazz.newInstance();
|
||||
ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
||||
Class<? extends NDArrayFactory> ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
||||
pp.toString(NDARRAY_FACTORY_CLASS));
|
||||
sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
||||
Class<? extends NDArrayFactory> sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
||||
pp.toString(SPARSE_NDARRAY_FACTORY_CLASS));
|
||||
convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
|
||||
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
|
||||
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
|
||||
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
|
||||
dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
|
||||
Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
|
||||
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
|
||||
shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
|
||||
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
|
||||
.forName(pp.toString(SHAPEINFO_PROVIDER));
|
||||
sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName(
|
||||
Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName(
|
||||
pp.toString(SPARSEINFO_PROVIDER));
|
||||
|
||||
constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
|
||||
Class<? extends BasicConstantHandler> constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
|
||||
.forName(pp.toString(CONSTANT_PROVIDER));
|
||||
|
||||
memoryManagerClazz = (Class<? extends BasicMemoryManager>) Class
|
||||
Class<? extends BasicMemoryManager> memoryManagerClazz = (Class<? extends BasicMemoryManager>) Class
|
||||
.forName(pp.toString(MEMORY_MANAGER));
|
||||
|
||||
allowsOrder = backend.allowsOrder();
|
||||
String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName());
|
||||
randomClazz = (Class<? extends org.nd4j.linalg.api.rng.Random>) Class.forName(rand);
|
||||
Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz = (Class<? extends org.nd4j.linalg.api.rng.Random>) Class.forName(rand);
|
||||
randomFactory = new RandomFactory(randomClazz);
|
||||
|
||||
workspaceManagerClazz = (Class<? extends MemoryWorkspaceManager>) Class
|
||||
Class<? extends MemoryWorkspaceManager> workspaceManagerClazz = (Class<? extends MemoryWorkspaceManager>) Class
|
||||
.forName(pp.toString(WORKSPACE_MANAGER));
|
||||
|
||||
blasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
||||
Class<? extends BlasWrapper> blasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
||||
.forName(pp.toString(BLAS_OPS));
|
||||
sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
||||
Class<? extends BlasWrapper> sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
||||
.forName(pp.toString(SPARSE_BLAS_OPS));
|
||||
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
|
||||
distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
|
||||
Class<? extends DistributionFactory> distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
|
||||
|
||||
|
||||
memoryManager = memoryManagerClazz.newInstance();
|
||||
|
@ -5813,7 +5713,7 @@ public class Nd4j {
|
|||
sparseInfoProvider = sparseInfoProviderClazz.newInstance();
|
||||
workspaceManager = workspaceManagerClazz.newInstance();
|
||||
|
||||
opExecutionerClazz = (Class<? extends OpExecutioner>) Class
|
||||
Class<? extends OpExecutioner> opExecutionerClazz = (Class<? extends OpExecutioner>) Class
|
||||
.forName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
|
||||
|
||||
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
|
||||
|
|
Loading…
Reference in New Issue