Nd4j refactoring (#101)
* cleanup Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com> * wip Signed-off-by: Robert Altena <Rob@Ra-ai.com>master
parent
b8846113bd
commit
a438434b1f
|
@ -54,7 +54,6 @@ import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
|
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.Diag;
|
import org.nd4j.linalg.api.ops.impl.shape.Diag;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
|
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
|
||||||
|
@ -90,7 +89,6 @@ import org.nd4j.tools.PropertyParser;
|
||||||
import org.nd4j.versioncheck.VersionCheck;
|
import org.nd4j.versioncheck.VersionCheck;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.lang.ref.ReferenceQueue;
|
|
||||||
import java.lang.reflect.Constructor;
|
import java.lang.reflect.Constructor;
|
||||||
import java.math.BigDecimal;
|
import java.math.BigDecimal;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
@ -114,29 +112,29 @@ import java.util.logging.Logger;
|
||||||
*/
|
*/
|
||||||
public class Nd4j {
|
public class Nd4j {
|
||||||
|
|
||||||
public final static String DATA_BUFFER_OPS = "databufferfactory";
|
private final static String DATA_BUFFER_OPS = "databufferfactory";
|
||||||
public final static String CONVOLUTION_OPS = "convops";
|
private final static String CONVOLUTION_OPS = "convops";
|
||||||
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
|
/**@deprecated Use {@link ND4JSystemProperties#DTYPE}*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public final static String DTYPE = ND4JSystemProperties.DTYPE;
|
public final static String DTYPE = ND4JSystemProperties.DTYPE;
|
||||||
public final static String BLAS_OPS = "blas.ops";
|
private final static String BLAS_OPS = "blas.ops";
|
||||||
public final static String SPARSE_BLAS_OPS = "sparseblas.ops";
|
private final static String SPARSE_BLAS_OPS = "sparseblas.ops";
|
||||||
public final static String NATIVE_OPS = "native.ops";
|
public final static String NATIVE_OPS = "native.ops";
|
||||||
public final static String ORDER_KEY = "ndarray.order";
|
private final static String ORDER_KEY = "ndarray.order";
|
||||||
public final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
|
private final static String NDARRAY_FACTORY_CLASS = "ndarrayfactory.class";
|
||||||
public final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class";
|
private final static String SPARSE_NDARRAY_FACTORY_CLASS = "sparsendarrayfactory.class";
|
||||||
public final static String OP_EXECUTIONER = "opexec";
|
private final static String OP_EXECUTIONER = "opexec";
|
||||||
public final static String OP_FACTORY = "opfactory";
|
|
||||||
public final static String DISTRIBUTION = "dist";
|
public final static String DISTRIBUTION = "dist";
|
||||||
public final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
|
private final static String SHAPEINFO_PROVIDER = "shapeinfoprovider";
|
||||||
public final static String SPARSEINFO_PROVIDER = "sparseinfoprovider";
|
private final static String SPARSEINFO_PROVIDER = "sparseinfoprovider";
|
||||||
public final static String CONSTANT_PROVIDER = "constantsprovider";
|
private final static String CONSTANT_PROVIDER = "constantsprovider";
|
||||||
public final static String AFFINITY_MANAGER = "affinitymanager";
|
private final static String AFFINITY_MANAGER = "affinitymanager";
|
||||||
//disable toString() on compressed arrays for debugging. Should be off by default.
|
//disable toString() on compressed arrays for debugging. Should be off by default.
|
||||||
public final static String COMPRESSION_DEBUG = "compressiondebug";
|
private final static String COMPRESSION_DEBUG = "compressiondebug";
|
||||||
public final static String MEMORY_MANAGER = "memorymanager";
|
private final static String MEMORY_MANAGER = "memorymanager";
|
||||||
public final static String WORKSPACE_MANAGER = "workspacemanager";
|
private final static String WORKSPACE_MANAGER = "workspacemanager";
|
||||||
public final static String RANDOM_PROVIDER = "random";
|
private final static String RANDOM_PROVIDER = "random";
|
||||||
/**@deprecated Use {@link ND4JSystemProperties#LOG_INITIALIZATION}*/
|
/**@deprecated Use {@link ND4JSystemProperties#LOG_INITIALIZATION}*/
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public static final String LOG_INIT_ENV_PROPERTY = ND4JSystemProperties.LOG_INITIALIZATION;
|
public static final String LOG_INIT_ENV_PROPERTY = ND4JSystemProperties.LOG_INITIALIZATION;
|
||||||
|
@ -144,8 +142,7 @@ public class Nd4j {
|
||||||
//the datatype used for allocating buffers
|
//the datatype used for allocating buffers
|
||||||
protected static DataType dtype = DataType.FLOAT;
|
protected static DataType dtype = DataType.FLOAT;
|
||||||
//the allocation mode for the heap
|
//the allocation mode for the heap
|
||||||
public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.HEAP;
|
public static DataBuffer.AllocationMode alloc = DataBuffer.AllocationMode.MIXED_DATA_TYPES;
|
||||||
public static char ORDER = 'c';
|
|
||||||
public static double EPS_THRESHOLD = 1e-5;
|
public static double EPS_THRESHOLD = 1e-5;
|
||||||
private static boolean allowsOrder = false;
|
private static boolean allowsOrder = false;
|
||||||
public static boolean compressDebug = false;
|
public static boolean compressDebug = false;
|
||||||
|
@ -157,45 +154,27 @@ public class Nd4j {
|
||||||
private static final AtomicInteger numThreads = new AtomicInteger(-1);
|
private static final AtomicInteger numThreads = new AtomicInteger(-1);
|
||||||
private static AtomicReference<DataType> defaultFloatingPointDataType;
|
private static AtomicReference<DataType> defaultFloatingPointDataType;
|
||||||
|
|
||||||
protected static Class<? extends MemoryWorkspaceManager> workspaceManagerClazz;
|
private static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
|
||||||
protected static Class<? extends BlasWrapper> blasWrapperClazz;
|
private static BlasWrapper BLAS_WRAPPER_INSTANCE;
|
||||||
protected static Class<? extends BlasWrapper> sparseBlasWrapperClazz;
|
private static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
|
||||||
protected static Class<? extends NDArrayFactory> ndArrayFactoryClazz;
|
|
||||||
protected static Class<? extends NDArrayFactory> sparseNDArrayClazz;
|
|
||||||
protected static Class<? extends ConvolutionInstance> convolutionInstanceClazz;
|
|
||||||
protected static Class<? extends DataBufferFactory> dataBufferFactoryClazz;
|
|
||||||
protected static Class<? extends OpExecutioner> opExecutionerClazz;
|
|
||||||
protected static Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz;
|
|
||||||
protected static Class<? extends DistributionFactory> distributionFactoryClazz;
|
|
||||||
protected static Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz;
|
|
||||||
protected static Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz;
|
|
||||||
protected static Class<? extends BasicConstantHandler> constantProviderClazz;
|
|
||||||
protected static Class<? extends BasicAffinityManager> affinityManagerClazz;
|
|
||||||
protected static Class<? extends BasicMemoryManager> memoryManagerClazz;
|
|
||||||
|
|
||||||
protected static DataBufferFactory DATA_BUFFER_FACTORY_INSTANCE;
|
|
||||||
protected static BlasWrapper BLAS_WRAPPER_INSTANCE;
|
|
||||||
protected static BlasWrapper SPARSE_BLAS_WRAPPER_INSTANCE;
|
|
||||||
protected static NDArrayFactory INSTANCE;
|
protected static NDArrayFactory INSTANCE;
|
||||||
protected static NDArrayFactory SPARSE_INSTANCE;
|
private static NDArrayFactory SPARSE_INSTANCE;
|
||||||
protected static ConvolutionInstance CONVOLUTION_INSTANCE;
|
private static ConvolutionInstance CONVOLUTION_INSTANCE;
|
||||||
protected static OpExecutioner OP_EXECUTIONER_INSTANCE;
|
private static OpExecutioner OP_EXECUTIONER_INSTANCE;
|
||||||
protected static DistributionFactory DISTRIBUTION_FACTORY;
|
private static DistributionFactory DISTRIBUTION_FACTORY;
|
||||||
protected static ShapeInfoProvider shapeInfoProvider;
|
private static ShapeInfoProvider shapeInfoProvider;
|
||||||
protected static SparseInfoProvider sparseInfoProvider;
|
private static SparseInfoProvider sparseInfoProvider;
|
||||||
protected static ConstantHandler constantHandler;
|
private static ConstantHandler constantHandler;
|
||||||
protected static AffinityManager affinityManager;
|
private static AffinityManager affinityManager;
|
||||||
protected static MemoryManager memoryManager;
|
private static MemoryManager memoryManager;
|
||||||
|
|
||||||
protected static AtomicBoolean fallbackMode;
|
private static AtomicBoolean fallbackMode;
|
||||||
|
|
||||||
protected static Properties props = new Properties();
|
protected static Properties props = new Properties();
|
||||||
protected static ReferenceQueue<INDArray> referenceQueue = new ReferenceQueue<>();
|
|
||||||
protected static ReferenceQueue<DataBuffer> bufferQueue = new ReferenceQueue<>();
|
|
||||||
|
|
||||||
private final static Logger logger = Logger.getLogger(Nd4j.class.getName());
|
private final static Logger logger = Logger.getLogger(Nd4j.class.getName());
|
||||||
|
|
||||||
protected static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length];
|
private static final INDArray[] EMPTY_ARRAYS = new INDArray[DataType.values().length];
|
||||||
|
|
||||||
static {
|
static {
|
||||||
fallbackMode = new AtomicBoolean(false);
|
fallbackMode = new AtomicBoolean(false);
|
||||||
|
@ -385,7 +364,6 @@ public class Nd4j {
|
||||||
* @param toShuffle the ndarray to shuffle
|
* @param toShuffle the ndarray to shuffle
|
||||||
* @param random the random to use
|
* @param random the random to use
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) {
|
public static void shuffle(INDArray toShuffle, Random random, @NonNull int... dimension) {
|
||||||
INSTANCE.shuffle(toShuffle, random, dimension);
|
INSTANCE.shuffle(toShuffle, random, dimension);
|
||||||
|
@ -396,10 +374,8 @@ public class Nd4j {
|
||||||
* along a specified set of dimensions
|
* along a specified set of dimensions
|
||||||
* @param toShuffle the ndarray to shuffle
|
* @param toShuffle the ndarray to shuffle
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public static void shuffle(INDArray toShuffle, @NonNull int... dimension) {
|
public static void shuffle(INDArray toShuffle, @NonNull int... dimension) {
|
||||||
//shuffle(toShuffle, new Random(), dimension);
|
|
||||||
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -408,10 +384,8 @@ public class Nd4j {
|
||||||
* along a specified set of dimensions
|
* along a specified set of dimensions
|
||||||
* @param toShuffle the ndarray to shuffle
|
* @param toShuffle the ndarray to shuffle
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public static void shuffle(Collection<INDArray> toShuffle, @NonNull int... dimension) {
|
public static void shuffle(Collection<INDArray> toShuffle, @NonNull int... dimension) {
|
||||||
//shuffle(toShuffle, new Random(), dimension);
|
|
||||||
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
INSTANCE.shuffle(toShuffle, new Random(), dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -420,10 +394,8 @@ public class Nd4j {
|
||||||
* along a specified set of dimensions
|
* along a specified set of dimensions
|
||||||
* @param toShuffle the ndarray to shuffle
|
* @param toShuffle the ndarray to shuffle
|
||||||
* @param dimension the dimension to do the shuffle
|
* @param dimension the dimension to do the shuffle
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public static void shuffle(Collection<INDArray> toShuffle, Random rnd, @NonNull int... dimension) {
|
public static void shuffle(Collection<INDArray> toShuffle, Random rnd, @NonNull int... dimension) {
|
||||||
//shuffle(toShuffle, new Random(), dimension);
|
|
||||||
INSTANCE.shuffle(toShuffle, rnd, dimension);
|
INSTANCE.shuffle(toShuffle, rnd, dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -433,33 +405,11 @@ public class Nd4j {
|
||||||
*
|
*
|
||||||
* @param toShuffle the ndarray to shuffle
|
* @param toShuffle the ndarray to shuffle
|
||||||
* @param dimensions the dimension to do the shuffle. Please note - order matters here.
|
* @param dimensions the dimension to do the shuffle. Please note - order matters here.
|
||||||
* @return
|
|
||||||
*/
|
*/
|
||||||
public static void shuffle(List<INDArray> toShuffle, Random rnd, List<int[]> dimensions) {
|
public static void shuffle(List<INDArray> toShuffle, Random rnd, List<int[]> dimensions) {
|
||||||
|
|
||||||
INSTANCE.shuffle(toShuffle, rnd, dimensions);
|
INSTANCE.shuffle(toShuffle, rnd, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* The reference queue used for cleaning up
|
|
||||||
* ndarrays
|
|
||||||
*
|
|
||||||
* @return the reference queue for cleaning up ndarrays
|
|
||||||
*/
|
|
||||||
public static ReferenceQueue<INDArray> refQueue() {
|
|
||||||
return referenceQueue;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The reference queue used for cleaning up
|
|
||||||
* databuffers
|
|
||||||
*
|
|
||||||
* @return the reference queue for cleaning up databuffers
|
|
||||||
*/
|
|
||||||
public static ReferenceQueue<DataBuffer> bufferRefQueue() {
|
|
||||||
return bufferQueue;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the primary distributions
|
* Get the primary distributions
|
||||||
* factory
|
* factory
|
||||||
|
@ -480,9 +430,9 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns RandomFactory instance
|
* Get the RandomFactory instance
|
||||||
*
|
*
|
||||||
* @return
|
* @return the RandomFactory instance
|
||||||
*/
|
*/
|
||||||
public static RandomFactory getRandomFactory() {
|
public static RandomFactory getRandomFactory() {
|
||||||
return randomFactory;
|
return randomFactory;
|
||||||
|
@ -500,7 +450,7 @@ public class Nd4j {
|
||||||
/**
|
/**
|
||||||
* Set a convolution instance
|
* Set a convolution instance
|
||||||
*
|
*
|
||||||
* @param convolutionInstance
|
* @param convolutionInstance the new convolution instance
|
||||||
*/
|
*/
|
||||||
public static void setConvolution(ConvolutionInstance convolutionInstance) {
|
public static void setConvolution(ConvolutionInstance convolutionInstance) {
|
||||||
if (convolutionInstance == null)
|
if (convolutionInstance == null)
|
||||||
|
@ -526,7 +476,6 @@ public class Nd4j {
|
||||||
* slice is the specified shape
|
* slice is the specified shape
|
||||||
*/
|
*/
|
||||||
public static INDArray create(int[] sliceShape, float[]... arrays) {
|
public static INDArray create(int[] sliceShape, float[]... arrays) {
|
||||||
//TODO: Remove duplicate code.
|
|
||||||
int slices = arrays.length;
|
int slices = arrays.length;
|
||||||
INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
|
INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
|
||||||
for (int i = 0; i < ret.slices(); i++)
|
for (int i = 0; i < ret.slices(); i++)
|
||||||
|
@ -586,30 +535,6 @@ public class Nd4j {
|
||||||
return DATA_BUFFER_FACTORY_INSTANCE;
|
return DATA_BUFFER_FACTORY_INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Given a sequence of Iterators over a transform of matrices, fill in all of
|
|
||||||
* the matrices with the entries in the theta vector. Errors are
|
|
||||||
* thrown if the theta vector does not exactly fill the matrices.
|
|
||||||
*
|
|
||||||
* TODO: unused method.
|
|
||||||
*/
|
|
||||||
public static void setParams(INDArray theta, Collection<INDArray>... matrices) {
|
|
||||||
int index = 0;
|
|
||||||
for (Collection<INDArray> matrixCollection : matrices) {
|
|
||||||
for (INDArray matrix : matrixCollection) {
|
|
||||||
INDArray linear = matrix.reshape(-1);
|
|
||||||
for (int i = 0; i < matrix.length(); i++) {
|
|
||||||
linear.putScalar(i, theta.getDouble(index));
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (index != theta.length()) {
|
|
||||||
throw new AssertionError("Did not entirely use the theta vector");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Roll the specified axis backwards,
|
* Roll the specified axis backwards,
|
||||||
* until it lies in a given position.
|
* until it lies in a given position.
|
||||||
|
@ -681,7 +606,7 @@ public class Nd4j {
|
||||||
* @param b the right tensor
|
* @param b the right tensor
|
||||||
* @param result the result array
|
* @param result the result array
|
||||||
* @param axes the axes for each array to do matrix multiply along
|
* @param axes the axes for each array to do matrix multiply along
|
||||||
* @return
|
* @return the result array
|
||||||
*/
|
*/
|
||||||
public static INDArray tensorMmul(INDArray a, INDArray b,INDArray result, int[][] axes) {
|
public static INDArray tensorMmul(INDArray a, INDArray b,INDArray result, int[][] axes) {
|
||||||
int validationLength = Math.min(axes[0].length, axes[1].length);
|
int validationLength = Math.min(axes[0].length, axes[1].length);
|
||||||
|
@ -720,15 +645,7 @@ public class Nd4j {
|
||||||
//if listA and listB are empty these donot initialize.
|
//if listA and listB are empty these donot initialize.
|
||||||
//so initializing with {1} which will then get overriden if not empty
|
//so initializing with {1} which will then get overriden if not empty
|
||||||
long[] newShapeA = {-1, n2};
|
long[] newShapeA = {-1, n2};
|
||||||
//TODO: remove duplicate code.
|
long[] oldShapeA = getOldShape(listA, a);
|
||||||
long[] oldShapeA;
|
|
||||||
if (listA.size() == 0) {
|
|
||||||
oldShapeA = new long[] {1};
|
|
||||||
} else {
|
|
||||||
oldShapeA = Longs.toArray(listA);
|
|
||||||
for (int i = 0; i < oldShapeA.length; i++)
|
|
||||||
oldShapeA[i] = a.size((int) oldShapeA[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
int n3 = 1;
|
int n3 = 1;
|
||||||
int bNax = Math.min(b.rank(), axes[1].length);
|
int bNax = Math.min(b.rank(), axes[1].length);
|
||||||
|
@ -737,14 +654,7 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
long[] newShapeB = {n3, -1};
|
long[] newShapeB = {n3, -1};
|
||||||
long[] oldShapeB;
|
long[] oldShapeB = getOldShape(listB, b);
|
||||||
if (listB.size() == 0) {
|
|
||||||
oldShapeB = new long[] {1};
|
|
||||||
} else {
|
|
||||||
oldShapeB = Longs.toArray(listB);
|
|
||||||
for (int i = 0; i < oldShapeB.length; i++)
|
|
||||||
oldShapeB[i] = b.size((int) oldShapeB[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray at = a.permute(newAxesA).reshape(newShapeA);
|
INDArray at = a.permute(newAxesA).reshape(newShapeA);
|
||||||
INDArray bt = b.permute(newAxesB).reshape(newShapeB);
|
INDArray bt = b.permute(newAxesB).reshape(newShapeB);
|
||||||
|
@ -754,6 +664,19 @@ public class Nd4j {
|
||||||
return ret.reshape(aPlusB);
|
return ret.reshape(aPlusB);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Some duplicate code that refactored out:
|
||||||
|
private static long[] getOldShape(List<Integer> list, INDArray x){
|
||||||
|
long[] res;
|
||||||
|
if (list.size() == 0) {
|
||||||
|
res = new long[] {1};
|
||||||
|
} else {
|
||||||
|
res= Longs.toArray(list);
|
||||||
|
for (int i = 0; i < res.length; i++)
|
||||||
|
res[i] = x.size((int) res[i]);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tensor matrix multiplication.
|
* Tensor matrix multiplication.
|
||||||
* Both tensors must be the same rank
|
* Both tensors must be the same rank
|
||||||
|
@ -761,7 +684,7 @@ public class Nd4j {
|
||||||
* @param a the left tensor
|
* @param a the left tensor
|
||||||
* @param b the right tensor
|
* @param b the right tensor
|
||||||
* @param axes the axes for each array to do matrix multiply along
|
* @param axes the axes for each array to do matrix multiply along
|
||||||
* @return
|
* @return the multiplication result.
|
||||||
*/
|
*/
|
||||||
public static INDArray tensorMmul(INDArray a, INDArray b, int[][] axes) {
|
public static INDArray tensorMmul(INDArray a, INDArray b, int[][] axes) {
|
||||||
CustomOp op = DynamicCustomOp.builder("tensordot")
|
CustomOp op = DynamicCustomOp.builder("tensordot")
|
||||||
|
@ -892,29 +815,6 @@ public class Nd4j {
|
||||||
return matmul(a,b, null);
|
return matmul(a,b, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Given a sequence of Iterators over a transform of matrices, fill in all of
|
|
||||||
* the matrices with the entries in the theta vector. Errors are
|
|
||||||
* thrown if the theta vector does not exactly fill the matrices.
|
|
||||||
* TODO: unused method.
|
|
||||||
*/
|
|
||||||
public static void setParams(INDArray theta, Iterator<? extends INDArray>... matrices) {
|
|
||||||
int index = 0;
|
|
||||||
for (Iterator<? extends INDArray> matrixIterator : matrices) {
|
|
||||||
while (matrixIterator.hasNext()) {
|
|
||||||
INDArray matrix = matrixIterator.next().reshape(-1);
|
|
||||||
for (int i = 0; i < matrix.length(); i++) {
|
|
||||||
matrix.putScalar(i, theta.getDouble(index));
|
|
||||||
index++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (index != theta.length()) {
|
|
||||||
throw new AssertionError("Did not entirely use the theta vector");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The factory used for creating ndarrays
|
* The factory used for creating ndarrays
|
||||||
*
|
*
|
||||||
|
@ -5766,45 +5666,45 @@ public class Nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
compressDebug = pp.toBoolean(COMPRESSION_DEBUG);
|
compressDebug = pp.toBoolean(COMPRESSION_DEBUG);
|
||||||
ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C);
|
char ORDER = pp.toChar(ORDER_KEY, NDArrayFactory.C);
|
||||||
|
|
||||||
affinityManagerClazz = (Class<? extends BasicAffinityManager>) Class
|
Class<? extends BasicAffinityManager> affinityManagerClazz = (Class<? extends BasicAffinityManager>) Class
|
||||||
.forName(pp.toString(AFFINITY_MANAGER));
|
.forName(pp.toString(AFFINITY_MANAGER));
|
||||||
affinityManager = affinityManagerClazz.newInstance();
|
affinityManager = affinityManagerClazz.newInstance();
|
||||||
ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
Class<? extends NDArrayFactory> ndArrayFactoryClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
||||||
pp.toString(NDARRAY_FACTORY_CLASS));
|
pp.toString(NDARRAY_FACTORY_CLASS));
|
||||||
sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
Class<? extends NDArrayFactory> sparseNDArrayClazz = (Class<? extends NDArrayFactory>) Class.forName(
|
||||||
pp.toString(SPARSE_NDARRAY_FACTORY_CLASS));
|
pp.toString(SPARSE_NDARRAY_FACTORY_CLASS));
|
||||||
convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
|
Class<? extends ConvolutionInstance> convolutionInstanceClazz = (Class<? extends ConvolutionInstance>) Class
|
||||||
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
|
.forName(pp.toString(CONVOLUTION_OPS, DefaultConvolutionInstance.class.getName()));
|
||||||
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
|
String defaultName = pp.toString(DATA_BUFFER_OPS, DefaultDataBufferFactory.class.getName());
|
||||||
dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
|
Class<? extends DataBufferFactory> dataBufferFactoryClazz = (Class<? extends DataBufferFactory>) Class
|
||||||
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
|
.forName(pp.toString(DATA_BUFFER_OPS, defaultName));
|
||||||
shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
|
Class<? extends BaseShapeInfoProvider> shapeInfoProviderClazz = (Class<? extends BaseShapeInfoProvider>) Class
|
||||||
.forName(pp.toString(SHAPEINFO_PROVIDER));
|
.forName(pp.toString(SHAPEINFO_PROVIDER));
|
||||||
sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName(
|
Class<? extends BaseSparseInfoProvider> sparseInfoProviderClazz = (Class<? extends BaseSparseInfoProvider>) Class.forName(
|
||||||
pp.toString(SPARSEINFO_PROVIDER));
|
pp.toString(SPARSEINFO_PROVIDER));
|
||||||
|
|
||||||
constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
|
Class<? extends BasicConstantHandler> constantProviderClazz = (Class<? extends BasicConstantHandler>) Class
|
||||||
.forName(pp.toString(CONSTANT_PROVIDER));
|
.forName(pp.toString(CONSTANT_PROVIDER));
|
||||||
|
|
||||||
memoryManagerClazz = (Class<? extends BasicMemoryManager>) Class
|
Class<? extends BasicMemoryManager> memoryManagerClazz = (Class<? extends BasicMemoryManager>) Class
|
||||||
.forName(pp.toString(MEMORY_MANAGER));
|
.forName(pp.toString(MEMORY_MANAGER));
|
||||||
|
|
||||||
allowsOrder = backend.allowsOrder();
|
allowsOrder = backend.allowsOrder();
|
||||||
String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName());
|
String rand = pp.toString(RANDOM_PROVIDER, DefaultRandom.class.getName());
|
||||||
randomClazz = (Class<? extends org.nd4j.linalg.api.rng.Random>) Class.forName(rand);
|
Class<? extends org.nd4j.linalg.api.rng.Random> randomClazz = (Class<? extends org.nd4j.linalg.api.rng.Random>) Class.forName(rand);
|
||||||
randomFactory = new RandomFactory(randomClazz);
|
randomFactory = new RandomFactory(randomClazz);
|
||||||
|
|
||||||
workspaceManagerClazz = (Class<? extends MemoryWorkspaceManager>) Class
|
Class<? extends MemoryWorkspaceManager> workspaceManagerClazz = (Class<? extends MemoryWorkspaceManager>) Class
|
||||||
.forName(pp.toString(WORKSPACE_MANAGER));
|
.forName(pp.toString(WORKSPACE_MANAGER));
|
||||||
|
|
||||||
blasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
Class<? extends BlasWrapper> blasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
||||||
.forName(pp.toString(BLAS_OPS));
|
.forName(pp.toString(BLAS_OPS));
|
||||||
sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
Class<? extends BlasWrapper> sparseBlasWrapperClazz = (Class<? extends BlasWrapper>) Class
|
||||||
.forName(pp.toString(SPARSE_BLAS_OPS));
|
.forName(pp.toString(SPARSE_BLAS_OPS));
|
||||||
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
|
String clazzName = pp.toString(DISTRIBUTION, DefaultDistributionFactory.class.getName());
|
||||||
distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
|
Class<? extends DistributionFactory> distributionFactoryClazz = (Class<? extends DistributionFactory>) Class.forName(clazzName);
|
||||||
|
|
||||||
|
|
||||||
memoryManager = memoryManagerClazz.newInstance();
|
memoryManager = memoryManagerClazz.newInstance();
|
||||||
|
@ -5813,7 +5713,7 @@ public class Nd4j {
|
||||||
sparseInfoProvider = sparseInfoProviderClazz.newInstance();
|
sparseInfoProvider = sparseInfoProviderClazz.newInstance();
|
||||||
workspaceManager = workspaceManagerClazz.newInstance();
|
workspaceManager = workspaceManagerClazz.newInstance();
|
||||||
|
|
||||||
opExecutionerClazz = (Class<? extends OpExecutioner>) Class
|
Class<? extends OpExecutioner> opExecutionerClazz = (Class<? extends OpExecutioner>) Class
|
||||||
.forName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
|
.forName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName()));
|
||||||
|
|
||||||
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
|
OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance();
|
||||||
|
|
Loading…
Reference in New Issue