Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax loss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test fix
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction - stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo when the input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple more broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack work correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays - few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed check for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays - one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using uninitialized objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13) - fixes lstmBlock issue - changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer - CheckNumerics op - fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7) New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13) - fixes lstmBlock issue - changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer - CheckNumerics op - fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18) Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types) Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue) Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915) The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator commas in them.
* Add thousand-separator commas to MultiLayerNetwork - corresponding change to MultiLayerNetwork Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934) Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942) - fix link to AdaDelta paper hosted on matthewzeiler.com Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943) Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13) - fixes lstmBlock issue - changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer - CheckNumerics op - fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18) Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into an arbitrary number of parts
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/ Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31) Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA Signed-off-by: raver119 <raver119@gmail.com>
* less memory Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation function Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh Signed-off-by: raver119 <raver119@gmail.com>
* topK concept Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWidth of 1 Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now Signed-off-by: raver119 <raver119@gmail.com>
* percentile op Signed-off-by: raver119 <raver119@gmail.com>
* group tests for maxpool2d Signed-off-by: Yurii <yurii@skymind.io>
* special test for george Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda Signed-off-by: Yurii <yurii@skymind.io>
* remove author in sort tad kernel code Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax - null check for special use Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda Signed-off-by: Yurii <yurii@skymind.io>
* std cuda Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform to the TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements Signed-off-by: raver119 <raver119@gmail.com>
* more minor lstm rearrangements Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn Signed-off-by: raver119 <raver119@gmail.com>
* templates init order Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda - provide conv2d_bp for cuda Signed-off-by: Yurii <yurii@skymind.io>
* missed signature Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val Signed-off-by: raver119 <raver119@gmail.com>
* start providing backprop for pooling2d/3d Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA Signed-off-by: raver119 <raver119@gmail.com>
* important comment Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl - dynamic_partition TAD CPU fix Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for upsampling2d/3d - write cuda code for upsampling2d/3d Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsampling2d/3d backprop Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix - cpu/NDArrayLambda.hpp excluded from CUDA Signed-off-by: raver119 <raver119@gmail.com>
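Among the NLP fixes listed above is a KMeans++ implementation. The seeding idea behind KMeans++ (each new center is sampled with probability proportional to the squared distance from its nearest already-chosen center) can be sketched in plain Java; the class and method names here are illustrative, not the library's actual API:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class KMeansPlusPlusSeed {
    // Choose k initial centers via k-means++ seeding: the first center is
    // uniform random, each subsequent center is sampled proportionally to
    // its squared distance from the nearest existing center.
    static List<double[]> seed(double[][] points, int k, Random rng) {
        List<double[]> centers = new ArrayList<>();
        centers.add(points[rng.nextInt(points.length)]);
        while (centers.size() < k) {
            double[] d2 = new double[points.length];
            double sum = 0;
            for (int i = 0; i < points.length; i++) {
                double best = Double.MAX_VALUE;
                for (double[] c : centers)
                    best = Math.min(best, sqDist(points[i], c));
                d2[i] = best;
                sum += best;
            }
            // sample an index with probability d2[i] / sum
            double r = rng.nextDouble() * sum;
            int chosen = points.length - 1;
            for (int i = 0; i < points.length; i++) {
                r -= d2[i];
                if (r <= 0) { chosen = i; break; }
            }
            centers.add(points[chosen]);
        }
        return centers;
    }

    static double sqDist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) { double d = a[i] - b[i]; s += d * d; }
        return s;
    }
}
```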
parent cae4fc9760
commit 1170827c18
@@ -31,7 +31,7 @@ public class TaskCreatorProvider {
             }
             return c.newInstance();
         } catch (Exception e){
-            throw new RuntimeException("Could not create new instance of task creator class: " + c, e);
+            throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
         }
     }
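The hunk above replaces a bare reflective failure with a message naming the usual cause: no accessible zero-arg constructor. The same pattern in a minimal, self-contained sketch (the `ReflectiveFactory` name is mine, not the project's):

```java
// Instantiate a class reflectively, failing with a message that points at
// the likely cause instead of rethrowing a bare exception.
public class ReflectiveFactory {
    public static <T> T create(Class<? extends T> c) {
        try {
            // getDeclaredConstructor() throws if there is no no-arg constructor
            return c.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException("Could not create new instance of class: " + c
                    + " - missing no-arg constructor?", e);
        }
    }
}
```

Note that `Class.newInstance()` (used in the diff) is deprecated in modern JDKs in favor of `getDeclaredConstructor().newInstance()`, which does not swallow checked exceptions.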
@@ -83,7 +83,7 @@ public class DataSetIteratorFactoryProvider implements DataProvider {
                             (Class<? extends DataSetIteratorFactory>) Class.forName(value);
             return clazz.newInstance();
         } catch (Exception e) {
-            throw new RuntimeException(e);
+            throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
         }
     }
 }
@@ -79,7 +79,7 @@ public class DataSetIteratorFactoryProvider implements DataProvider {
                             (Class<? extends DataSetIteratorFactory>) Class.forName(value);
             return clazz.newInstance();
         } catch (Exception e) {
-            throw new RuntimeException(e);
+            throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
         }
     }
 }
@@ -54,7 +54,7 @@ public abstract class BaseNetScoreFunction implements ScoreFunction {
                 ds.configure(dataSourceProperties);
             }
         } catch (Exception e){
-            throw new RuntimeException(e);
+            throw new RuntimeException("Error creating DataSource instance - missing no-arg constructor?", e);
         }
         return score(model, ds.testData());
     }
@@ -188,10 +188,15 @@ public class ComputationGraphTaskCreator implements TaskCreator {
             //For DataSetIterator: wraps in a MultiDataSetIterator, hence method can be used for both
             MultiDataSetIterator iterator;
             if(dataSource != null){
-                DataSource dsInstance = dataSource.newInstance();
-                if(dataSourceProperties != null)
-                    dsInstance.configure(dataSourceProperties);
-                iterator = ScoreUtil.getMultiIterator(dsInstance.trainData());
+                try {
+                    DataSource dsInstance = dataSource.newInstance();
+                    if (dataSourceProperties != null)
+                        dsInstance.configure(dataSourceProperties);
+                    iterator = ScoreUtil.getMultiIterator(dsInstance.trainData());
+                } catch (Exception e){
+                    throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName() +
+                            " - no zero-arg constructor?",e);
+                }
             } else {
                 iterator = ScoreUtil.getMultiIterator(dataProvider.trainData(candidate.getDataParameters()));
             }
@@ -190,7 +190,8 @@ public class MultiLayerNetworkTaskCreator implements TaskCreator {
             try{
                 dsInstance = dataSource.newInstance();
             } catch (Exception e){
-                throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName());
+                throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName() +
+                        " - no zero-arg constructor?",e);
             }
             if(dataSourceProperties != null)
                 dsInstance.configure(dataSourceProperties);
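The two task-creator hunks above guard the same instantiate-then-configure flow: reflectively create a data source, optionally configure it from properties, and fail loudly when the zero-arg constructor is missing. A hedged, self-contained sketch of that flow (the `ConfigurableSource`/`SourceLoader` names are stand-ins, not the Arbiter API):

```java
import java.util.Properties;

// Stand-in for a DataSource-like type that can be configured after creation.
interface ConfigurableSource {
    void configure(Properties props);
}

// Sample implementation that records the properties it was given.
class PropsSource implements ConfigurableSource {
    Properties seen;
    public void configure(Properties props) { this.seen = props; }
}

class SourceLoader {
    // Reflectively instantiate, then configure only if properties were supplied.
    static ConfigurableSource load(Class<? extends ConfigurableSource> cls, Properties props) {
        ConfigurableSource instance;
        try {
            instance = cls.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException("Error instantiating instance of DataSource for class "
                    + cls.getName() + " - no zero-arg constructor?", e);
        }
        if (props != null)
            instance.configure(props);   // configuration is optional
        return instance;
    }
}
```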
@@ -26,6 +26,7 @@ import org.datavec.api.writable.NDArrayWritable;
 import org.datavec.api.writable.Text;
 import org.datavec.api.writable.Writable;
 import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
@@ -78,14 +79,14 @@ public class TestNDArrayWritableTransforms {
         assertEquals(expColNames, tp.getFinalSchema().getColumnNames());
 
 
-        List<Writable> in = Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)),
-                        new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0)));
+        List<Writable> in = Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)),
+                        new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE)));
         List<Writable> out = tp.execute(in);
 
         List<Writable> exp =
-                        Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)),
-                                        new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0)),
-                                        new NDArrayWritable(Nd4j.linspace(0, 9, 10).addi(2.0)));
+                        Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)),
+                                        new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE)),
+                                        new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE, 0, 10, 1).addi(2.0).reshape(1,10)));
 
         assertEquals(exp, out);
     }
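The test change above swaps `Nd4j.linspace(0, 9, 10)` for `Nd4j.linspace(DataType.DOUBLE, 0, 10, 1)`. If, as the diff suggests, the new arguments are (dtype, start, number-of-values, step), both calls produce the same ten values 0..9, with the dtype now explicit. A plain-Java sketch of that start/num/step semantics (the helper name is mine, not the Nd4j API):

```java
// Generate `num` values start, start+step, start+2*step, ... — the
// semantics the updated test appears to rely on: linspace(DOUBLE, 0, 10, 1)
// yields 0,1,...,9, same values as the old linspace(0, 9, 10).
public class LinspaceSketch {
    static double[] linspace(double start, long num, double step) {
        double[] out = new double[(int) num];
        for (int i = 0; i < num; i++)
            out[i] = start + i * step;
        return out;
    }
}
```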
@@ -20,9 +20,15 @@ import lombok.val;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
 import org.junit.Test;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
 
-import static org.junit.Assert.assertEquals;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+import static org.junit.Assert.*;
 
 public class DataSetSplitterTests extends BaseDL4JTest {
     @Test
@@ -39,7 +45,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
         int gcntTest = 0;
         int global = 0;
         // emulating epochs here
-        for (int e = 0; e < numEpochs; e++){
+        for (int e = 0; e < numEpochs; e++) {
             int cnt = 0;
             while (train.hasNext()) {
                 val data = train.next().getFeatures();
@@ -79,7 +85,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
         int gcntTest = 0;
         int global = 0;
         // emulating epochs here
-        for (int e = 0; e < numEpochs; e++){
+        for (int e = 0; e < numEpochs; e++) {
             int cnt = 0;
             while (train.hasNext()) {
                 val data = train.next().getFeatures();
@@ -117,7 +123,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
         int gcntTest = 0;
         int global = 0;
         // emulating epochs here
-        for (int e = 0; e < numEpochs; e++){
+        for (int e = 0; e < numEpochs; e++) {
             int cnt = 0;
             while (train.hasNext()) {
                 val data = train.next().getFeatures();
@@ -144,4 +150,245 @@ public class DataSetSplitterTests extends BaseDL4JTest {
 
         assertEquals(1000 * numEpochs, global);
     }
+
+    @Test
+    public void testSplitter_4() {
+        val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
+
+        val splitter = new DataSetIteratorSplitter(back, 1000, new double[]{0.5, 0.3, 0.2});
+        List<DataSetIterator> iteratorList = splitter.getIterators();
+        val numEpochs = 10;
+        int global = 0;
+        // emulating epochs here
+        for (int e = 0; e < numEpochs; e++) {
+            int iterNo = 0;
+            int perEpoch = 0;
+            for (val partIterator : iteratorList) {
+                int cnt = 0;
+                partIterator.reset();
+                while (partIterator.hasNext()) {
+                    val data = partIterator.next().getFeatures();
+                    assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
+                            (float) perEpoch, data.getFloat(0), 1e-5);
+                    //gcntTrain++;
+                    global++;
+                    cnt++;
+                    ++perEpoch;
+                }
+                ++iterNo;
+            }
+        }
+
+        assertEquals(1000 * numEpochs, global);
+    }
+
+    @Test
+    public void testSplitter_5() {
+        val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
+
+        val splitter = new DataSetIteratorSplitter(back, new int[]{900, 100});
+
+        List<DataSetIterator> iteratorList = splitter.getIterators();
+        val numEpochs = 10;
+
+        int global = 0;
+        // emulating epochs here
+        for (int e = 0; e < numEpochs; e++) {
+            int iterNo = 0;
+            int perEpoch = 0;
+            for (val partIterator : iteratorList) {
+                partIterator.reset();
+                while (partIterator.hasNext()) {
+                    int cnt = 0;
+                    val data = partIterator.next().getFeatures();
+
+                    assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
+                            (float) perEpoch, data.getFloat(0), 1e-5);
+                    //gcntTrain++;
+                    global++;
+                    cnt++;
+                    ++perEpoch;
+                }
+                ++iterNo;
+            }
+        }
+
+        assertEquals(1000 * numEpochs, global);
+    }
+
+    @Test
+    public void testSplitter_6() {
+        val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
+
+        // we're going to mimic train+test+validation split
+        val splitter = new DataSetIteratorSplitter(back, new int[]{800, 100, 100});
+
+        assertEquals(3, splitter.getIterators().size());
+
+        val trainIter = splitter.getIterators().get(0);
+        val testIter = splitter.getIterators().get(1);
+        val validationIter = splitter.getIterators().get(2);
+
+        // we're going to have multiple epochs
+        int numEpochs = 10;
+        for (int e = 0; e < numEpochs; e++) {
+            int globalIter = 0;
+            trainIter.reset();
+            testIter.reset();
+            validationIter.reset();
+
+            boolean trained = false;
+            while (trainIter.hasNext()) {
+                trained = true;
+                val ds = trainIter.next();
+                assertNotNull(ds);
+
+                assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
+                globalIter++;
+            }
+            assertTrue("Failed at epoch [" + e + "]", trained);
+            assertEquals(800, globalIter);
+
+
+            // test set is used every epoch
+            boolean tested = false;
+            //testIter.reset();
+            while (testIter.hasNext()) {
+                tested = true;
+                val ds = testIter.next();
+                assertNotNull(ds);
+
+                assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
+                globalIter++;
+            }
+            assertTrue("Failed at epoch [" + e + "]", tested);
+            assertEquals(900, globalIter);
+
+            // validation set is used every 5 epochs
+            if (e % 5 == 0) {
+                boolean validated = false;
+                //validationIter.reset();
+                while (validationIter.hasNext()) {
+                    validated = true;
+                    val ds = validationIter.next();
+                    assertNotNull(ds);
+
+                    assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
+                    globalIter++;
+                }
+                assertTrue("Failed at epoch [" + e + "]", validated);
+            }
+
+            // all 3 iterators have exactly 1000 elements combined
+            if (e % 5 == 0)
+                assertEquals(1000, globalIter);
+            else
+                assertEquals(900, globalIter);
+            trainIter.reset();
+        }
+    }
+
+    @Test
+    public void testUnorderedSplitter_1() {
+        val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
+
+        val splitter = new DataSetIteratorSplitter(back, new int[]{500, 500});
+
+        List<DataSetIterator> iteratorList = splitter.getIterators();
+        val numEpochs = 10;
+
+        int global = 0;
+        // emulating epochs here
+        for (int e = 0; e < numEpochs; e++) {
+
+            // Get data from second part, then rewind for the first one.
+            int cnt = 0;
+            int partNumber = 1;
+            while (iteratorList.get(partNumber).hasNext()) {
+                int farCnt = (1000 / 2) * (partNumber) + cnt;
+                val data = iteratorList.get(partNumber).next().getFeatures();
+
+                assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data.getFloat(0), 1e-5);
+                cnt++;
+                global++;
+            }
+            iteratorList.get(partNumber).reset();
+            partNumber = 0;
+            cnt = 0;
+            while (iteratorList.get(0).hasNext()) {
+                val data = iteratorList.get(0).next().getFeatures();
+
+                assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
+                global++;
+            }
+        }
+    }
+
+    @Test
+    public void testUnorderedSplitter_2() {
+        val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
+
+        val splitter = new DataSetIteratorSplitter(back, new int[]{2});
+
+        List<DataSetIterator> iteratorList = splitter.getIterators();
+
+        for (int partNumber = 0 ; partNumber < iteratorList.size(); ++partNumber) {
+            int cnt = 0;
+            while (iteratorList.get(partNumber).hasNext()) {
+                val data = iteratorList.get(partNumber).next().getFeatures();
+
+                assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
+                cnt++;
+            }
+        }
+    }
+
+    @Test
+    public void testUnorderedSplitter_3() {
+        val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
+
+        val splitter = new DataSetIteratorSplitter(back, new int[]{10});
+
+        List<DataSetIterator> iteratorList = splitter.getIterators();
+        Random random = new Random();
+        int[] indexes = new int[iteratorList.size()];
+        for (int i = 0; i < indexes.length; ++i) {
+            indexes[i] = random.nextInt(iteratorList.size());
+        }
+
+        for (int partNumber : indexes) {
+            int cnt = 0;
+            while (iteratorList.get(partNumber).hasNext()) {
+                val data = iteratorList.get(partNumber).next().getFeatures();
+
+                assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
+                cnt++;
+            }
+        }
+    }
+
+    @Test
+    public void testUnorderedSplitter_4() {
+        val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
+
+        // we're going to mimic train+test+validation split
+        val splitter = new DataSetIteratorSplitter(back, new int[]{80, 10, 5});
+
+        assertEquals(3, splitter.getIterators().size());
+
+        val trainIter = splitter.getIterators().get(0); // 0..79
+        val testIter = splitter.getIterators().get(1); // 80..89
+        val validationIter = splitter.getIterators().get(2); // 90..94
+
+        // we're skipping train/test and go for validation first. we're that crazy, right.
+        int valCnt = 0;
+        while (validationIter.hasNext()) {
+            val ds = validationIter.next();
+            assertNotNull(ds);
+
+            assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5);
+            valCnt++;
+        }
+        assertEquals(5, valCnt);
+    }
 }
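The tests above exercise DataSetIteratorSplitter, which carves one iterator of N minibatches into consecutive partitions (e.g. 800/100/100 for train/test/validation). The index bookkeeping the assertions rely on can be sketched in plain Java; the `SplitSketch` helper is illustrative only, not the DL4J class:

```java
import java.util.ArrayList;
import java.util.List;

// Split the index range [0, total) into consecutive half-open partitions,
// mirroring how the splitter hands out minibatches 0..799, 800..899, 900..999.
public class SplitSketch {
    static List<int[]> partitions(int total, int[] sizes) {
        List<int[]> out = new ArrayList<>();
        int start = 0;
        for (int size : sizes) {
            out.add(new int[]{start, start + size});   // {from, to}
            start += size;
        }
        if (start != total)
            throw new IllegalArgumentException("Partition sizes must sum to " + total);
        return out;
    }
}
```

This also explains the `farCnt = (1000 / 2) * partNumber + cnt` arithmetic in testUnorderedSplitter_1: partition `p` of equal size `s` starts at global minibatch index `p * s`.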
@@ -18,11 +18,17 @@ package org.deeplearning4j.datasets.iterator;
 
 import lombok.val;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
+import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator;
 import org.junit.Test;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
+import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 
-import static org.junit.Assert.assertEquals;
+import java.util.List;
+import java.util.Random;
+
+import static org.junit.Assert.*;
 
 /**
  *
@@ -150,4 +156,309 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {

        assertEquals(1000 * numEpochs, global);
    }

    @Test
    public void testMultiSplitter_1() {
        val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});

        // we're going to mimic a train+test+validation split
        val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100});

        assertEquals(3, splitter.getIterators().size());

        val trainIter = splitter.getIterators().get(0);
        val testIter = splitter.getIterators().get(1);
        val validationIter = splitter.getIterators().get(2);

        // we're going to have multiple epochs
        int numEpochs = 10;
        for (int e = 0; e < numEpochs; e++) {
            int globalIter = 0;
            trainIter.reset();
            testIter.reset();
            validationIter.reset();

            boolean trained = false;
            while (trainIter.hasNext()) {
                trained = true;
                val ds = trainIter.next();
                assertNotNull(ds);

                for (int i = 0; i < ds.getFeatures().length; ++i) {
                    assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
                }
                globalIter++;
            }
            assertTrue("Failed at epoch [" + e + "]", trained);
            assertEquals(800, globalIter);

            // the test set is used every epoch
            boolean tested = false;
            while (testIter.hasNext()) {
                tested = true;
                val ds = testIter.next();
                assertNotNull(ds);

                for (int i = 0; i < ds.getFeatures().length; ++i) {
                    assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
                }
                globalIter++;
            }
            assertTrue("Failed at epoch [" + e + "]", tested);
            assertEquals(900, globalIter);

            // the validation set is used every 5th epoch
            if (e % 5 == 0) {
                boolean validated = false;
                while (validationIter.hasNext()) {
                    validated = true;
                    val ds = validationIter.next();
                    assertNotNull(ds);

                    for (int i = 0; i < ds.getFeatures().length; ++i) {
                        assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
                    }
                    globalIter++;
                }
                assertTrue("Failed at epoch [" + e + "]", validated);
            }

            // all 3 iterators hold exactly 1000 elements combined
            if (e % 5 == 0)
                assertEquals(1000, globalIter);
            else
                assertEquals(900, globalIter);
            trainIter.reset();
        }
    }

    @Test
    public void testSplitter_5() {
        val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});

        val splitter = new MultiDataSetIteratorSplitter(back, new int[]{900, 100});

        List<MultiDataSetIterator> iteratorList = splitter.getIterators();
        val numEpochs = 10;

        int global = 0;
        // emulating epochs here
        for (int e = 0; e < numEpochs; e++) {
            int iterNo = 0;
            int perEpoch = 0;
            for (val partIterator : iteratorList) {
                partIterator.reset();
                while (partIterator.hasNext()) {
                    int cnt = 0;
                    val data = partIterator.next().getFeatures();

                    for (int i = 0; i < data.length; ++i) {
                        assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
                                (float) perEpoch, data[i].getFloat(0), 1e-5);
                    }
                    global++;
                    cnt++;
                    ++perEpoch;
                }
                ++iterNo;
            }
        }

        assertEquals(1000 * numEpochs, global);
    }

    @Test
    public void testSplitter_6() {
        val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});

        // we're going to mimic a train+test+validation split
        val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100});

        assertEquals(3, splitter.getIterators().size());

        val trainIter = splitter.getIterators().get(0);
        val testIter = splitter.getIterators().get(1);
        val validationIter = splitter.getIterators().get(2);

        // we're going to have multiple epochs
        int numEpochs = 10;
        for (int e = 0; e < numEpochs; e++) {
            int globalIter = 0;
            trainIter.reset();
            testIter.reset();
            validationIter.reset();

            boolean trained = false;
            while (trainIter.hasNext()) {
                trained = true;
                val ds = trainIter.next();
                assertNotNull(ds);

                for (int i = 0; i < ds.getFeatures().length; ++i) {
                    assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
                            ds.getFeatures()[i].getDouble(0), 1e-5f);
                }
                globalIter++;
            }
            assertTrue("Failed at epoch [" + e + "]", trained);
            assertEquals(800, globalIter);

            // the test set is used every epoch
            boolean tested = false;
            while (testIter.hasNext()) {
                tested = true;
                val ds = testIter.next();
                assertNotNull(ds);
                for (int i = 0; i < ds.getFeatures().length; ++i) {
                    assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
                }
                globalIter++;
            }
            assertTrue("Failed at epoch [" + e + "]", tested);
            assertEquals(900, globalIter);

            // the validation set is used every 5th epoch
            if (e % 5 == 0) {
                boolean validated = false;
                while (validationIter.hasNext()) {
                    validated = true;
                    val ds = validationIter.next();
                    assertNotNull(ds);

                    for (int i = 0; i < ds.getFeatures().length; ++i) {
                        assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
                                ds.getFeatures()[i].getDouble(0), 1e-5f);
                    }
                    globalIter++;
                }
                assertTrue("Failed at epoch [" + e + "]", validated);
            }

            // all 3 iterators hold exactly 1000 elements combined
            if (e % 5 == 0)
                assertEquals(1000, globalIter);
            else
                assertEquals(900, globalIter);
            trainIter.reset();
        }
    }

    @Test
    public void testUnorderedSplitter_1() {
        val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});

        val splitter = new MultiDataSetIteratorSplitter(back, new int[]{500, 500});

        List<MultiDataSetIterator> iteratorList = splitter.getIterators();
        val numEpochs = 10;

        int global = 0;
        // emulating epochs here
        for (int e = 0; e < numEpochs; e++) {

            // Get data from second part, then rewind for the first one.
            int cnt = 0;
            int partNumber = 1;
            while (iteratorList.get(partNumber).hasNext()) {
                int farCnt = (1000 / 2) * (partNumber) + cnt;
                val data = iteratorList.get(partNumber).next().getFeatures();
                for (int i = 0; i < data.length; ++i) {
                    assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data[i].getFloat(0), 1e-5);
                }
                cnt++;
                global++;
            }
            iteratorList.get(partNumber).reset();
            partNumber = 0;
            cnt = 0;
            while (iteratorList.get(0).hasNext()) {
                val data = iteratorList.get(0).next().getFeatures();
                for (int i = 0; i < data.length; ++i) {
                    assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++,
                            data[i].getFloat(0), 1e-5);
                }
                global++;
            }
        }
    }

    @Test
    public void testUnorderedSplitter_2() {
        val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});

        val splitter = new MultiDataSetIteratorSplitter(back, new int[]{2});

        List<MultiDataSetIterator> iteratorList = splitter.getIterators();

        for (int partNumber = 0; partNumber < iteratorList.size(); ++partNumber) {
            int cnt = 0;
            while (iteratorList.get(partNumber).hasNext()) {
                val data = iteratorList.get(partNumber).next().getFeatures();
                for (int i = 0; i < data.length; ++i) {
                    assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5);
                }
                cnt++;
            }
        }
    }

    @Test
    public void testUnorderedSplitter_3() {
        val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});

        val splitter = new MultiDataSetIteratorSplitter(back, new int[]{10});

        List<MultiDataSetIterator> iteratorList = splitter.getIterators();
        Random random = new Random();
        int[] indexes = new int[iteratorList.size()];
        for (int i = 0; i < indexes.length; ++i) {
            indexes[i] = random.nextInt(iteratorList.size());
        }

        for (int partNumber : indexes) {
            int cnt = 0;
            while (iteratorList.get(partNumber).hasNext()) {
                val data = iteratorList.get(partNumber).next().getFeatures();
                for (int i = 0; i < data.length; ++i) {
                    assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt),
                            data[i].getFloat(0), 1e-5);
                }
                cnt++;
            }
        }
    }

    @Test
    public void testUnorderedSplitter_4() {
        val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});

        // we're going to mimic a train+test+validation split
        val splitter = new MultiDataSetIteratorSplitter(back, new int[]{80, 10, 5});

        assertEquals(3, splitter.getIterators().size());

        val trainIter = splitter.getIterators().get(0);      // 0..79
        val testIter = splitter.getIterators().get(1);       // 80..89
        val validationIter = splitter.getIterators().get(2); // 90..94

        // we skip train/test and read validation first
        int valCnt = 0;
        while (validationIter.hasNext()) {
            val ds = validationIter.next();
            assertNotNull(ds);
            for (int i = 0; i < ds.getFeatures().length; ++i) {
                assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90,
                        ds.getFeatures()[i].getFloat(0), 1e-5);
            }
            valCnt++;
        }
        assertEquals(5, valCnt);
    }
}

@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.dropout.TestDropout;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

@@ -196,4 +197,43 @@ public class TestRnnLayers extends BaseDL4JTest {
        }
    }

    @Test
    public void testMismatchedInputLabelLength() {

        for (int i = 0; i < 2; i++) {

            NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
                    .list()
                    .layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());

            switch (i) {
                case 0:
                    lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build());
                    break;
                case 1:
                    lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build());
                    break;
                default:
                    throw new RuntimeException();
            }

            MultiLayerConfiguration conf = lb.build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();

            INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
            INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);

            try {
                net.fit(in, l);
            } catch (Throwable t) {
                String msg = t.getMessage();
                assertTrue(msg, msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
            }
        }
    }
}

@@ -249,7 +249,6 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
    }

    @Test
    @Ignore("AB 2019/05/31 - Failing on CI and locally - see issues 7820 and 7657")
    public void testCorrectness1() {
        DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
        Nd4j.getRandom().setSeed(123);

@@ -270,30 +269,18 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
                .useAdaGrad(false).build();

        b.fit(data);
        System.out.println(b.getData());

        /*double[] expectedData = new double[]{15.5392794313924, 19.25226403656672, -5.194955746137196, -31.787679714614757, 48.8674725273665,
                24.92775755686273, -22.621939920239065, -29.790772278125395, 19.027362415188914, -16.013800175884274,
                -27.454680593309185, 1.2929960811295493, -40.45000061571038, 61.23261682914338, 5.62278768938746,
                -28.16665244970911, -20.05502814088798, 12.803274346870865, -24.877262522905497, 45.115883138175874,
                21.597495694710616, 18.63254779638783, -4.029728632528419, -0.4596087279592638, -42.35340705500429,
                -69.24727547461491, 40.94332685199673, -24.60866142208024, 17.689874972878723, -3.6779759693605314,
                -30.91803590368529, 10.645452930824145, 36.58583235020565, -64.74975614289316, -39.364099390585956,
                72.54886481127016, -35.30663155696714, 19.37116912936714, -7.790876543092118, 19.6586396288508,
                58.1332709511154, -18.49217368496203, -3.5050200971182424, 5.662891294031322, 39.69533295638775,
                -15.114610550011662, -32.42366951357609, 17.039297537056537, 42.25610885633673, -2.7013781552769904,
                -16.338582630617925, 41.734027526336874, 20.941332646863426, -3.2145240561108244, -45.36033539684912};*/
        double[] expectedData = {40.93810899235225, 50.90183660191448, -14.298857560948981, -86.2012232604988, 129.51281793466023,
                66.29136854264247, -61.650213611972326, -80.42836756633497, 50.28325210727952, -44.29008119040566,
                -74.82748570869279, 2.0170536250746807, -109.21462846594635, 162.3973196127918, 14.000621153511705,
                -76.30892822919527, -54.251704596942275, 33.99763310539589, -67.6307009607032, 119.50868525237786,
                57.17786598853867, 49.1489174572297, -11.25663463504983, -2.38899196609398, -114.27194947404686,
                -185.93832011474473, 108.9022579845252, -66.14099037301474, 47.13683038425694, -10.037893631405792,
                -83.88458799629637, 26.985651418254996, 96.68139337135332, -174.2832443285551, -106.0999118697521,
                193.02622700008175, -94.88003359113081, 51.39502524568139, -20.96021960048648, 52.32291574424741,
                154.33973608321477, -50.90644802585217, -10.345744416395354, 13.721222143380892, 105.2111073677489,
                -41.339268919407345, -87.73042354938127, 45.306865238870046, 112.53877133856602, -8.44454352074299,
                -44.660828600669056, 110.72662022978719, 55.74660833987147, -9.613556053471232, -122.19953914048916};
        double[] expectedData = new double[]{ 63.8206, 80.4013, -19.4424, -140.4326, 198.7239,
                106.1148, -96.6273, -124.3634, 78.4174, -83.6621,
                -121.8706, 3.0888, -172.8560, 255.1262, 20.7021,
                -120.7942, -78.1829, 56.6021, -112.3294, 185.4084,
                88.5330, 78.0497, -18.8673, -11.0155, -175.1564,
                -297.8463, 174.2511, -103.8793, 72.5455, -15.8498,
                -134.5235, 42.3300, 154.0391, -280.1010, -167.9765,
                306.9938, -150.9666, 83.4419, -36.0877, 83.9992,
                245.1813, -81.5018, -14.8430, 16.1557, 166.8651,
                -65.9247, -138.1783, 72.5444, 176.3088, -25.6732,
                -69.6843, 167.3360, 87.6238, -18.5874, -187.3806};

        INDArray expectedArray = Nd4j.createFromArray(expectedData).reshape(11, 5);
        for (int i = 0; i < expectedArray.rows(); ++i)

@@ -18,6 +18,7 @@ package org.deeplearning4j.util;

import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

@@ -30,7 +31,7 @@ public class TimeSeriesUtilsTest extends BaseDL4JTest {

    @Test
    public void testMovingAverage() {
        INDArray a = Nd4j.arange(0, 20);
        INDArray a = Nd4j.arange(0, 20).castTo(DataType.DOUBLE);
        INDArray result = Nd4j.create(new double[] {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f,
                12.5f, 13.5f, 14.5f, 15.5f, 16.5f, 17.5f});

@@ -24,6 +24,7 @@ import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

@@ -42,14 +43,20 @@ public class DataSetIteratorSplitter {
    protected DataSetIterator backedIterator;
    protected final long totalExamples;
    protected final double ratio;
    protected final double[] ratios;
    protected final long numTrain;
    protected final long numTest;
    protected final long numArbitrarySets;
    protected final int[] splits;

    protected AtomicLong counter = new AtomicLong(0);

    protected AtomicBoolean resetPending = new AtomicBoolean(false);
    protected DataSet firstTrain = null;

    protected int partNumber = 0;

    /**
     * The only constructor
     *
@@ -71,17 +78,94 @@ public class DataSetIteratorSplitter {
        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = ratio;
        this.ratios = null;
        this.numTrain = (long) (totalExamples * ratio);
        this.numTest = totalExamples - numTrain;
        this.numArbitrarySets = 2;
        this.splits = null;

        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalBatches, double[] ratios) {
        for (double ratio : ratios) {
            if (!(ratio > 0.0 && ratio < 1.0))
                throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 < X < 1.0");
        }

        if (totalBatches < 0)
            throw new ND4JIllegalStateException("totalExamples number should be a positive value");

        if (!baseIterator.resetSupported())
            throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");

        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = 0.0;
        this.ratios = ratios;
        this.numTrain = 0;
        this.numTest = 0;
        this.numArbitrarySets = ratios.length;

        this.splits = new int[this.ratios.length];
        for (int i = 0; i < this.splits.length; ++i) {
            this.splits[i] = (int) (totalExamples * ratios[i]);
        }

        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, int[] splits) {

        int totalBatches = 0;
        for (val v : splits)
            totalBatches += v;

        if (totalBatches < 0)
            throw new ND4JIllegalStateException("totalExamples number should be a positive value");

        if (!baseIterator.resetSupported())
            throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");

        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = 0.0;
        this.ratios = null;

        this.numTrain = 0;
        this.numTest = 0;
        this.splits = splits;
        this.numArbitrarySets = splits.length;

        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    public List<DataSetIterator> getIterators() {
        List<DataSetIterator> retVal = new ArrayList<>();
        int partN = 0;
        int bottom = 0;
        for (final int split : splits) {
            ScrollableDataSetIterator partIterator =
                    new ScrollableDataSetIterator(partN++, backedIterator, counter, resetPending, firstTrain,
                            new int[]{bottom, split});
            bottom += split;
            retVal.add(partIterator);
        }
        return retVal;
    }
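The `getIterators()` loop above reduces to simple prefix-sum arithmetic: part `i` is handed the half-open batch range `[bottom, bottom + splits[i])`, with `bottom` accumulating the sizes of all earlier parts. A minimal standalone sketch of just that arithmetic (the class and method names here are hypothetical, not part of the library):

```java
// Hypothetical sketch of the boundary arithmetic used by getIterators():
// part i scrolls over batch indices [bottom, bottom + splits[i]).
public class SplitBoundsSketch {
    static int[][] bounds(int[] splits) {
        int[][] out = new int[splits.length][2];
        int bottom = 0;
        for (int i = 0; i < splits.length; i++) {
            out[i][0] = bottom;             // first batch index visible to part i
            out[i][1] = bottom + splits[i]; // exclusive upper bound ("top")
            bottom += splits[i];
        }
        return out;
    }

    public static void main(String[] args) {
        // mirrors the train+test+validation split used in the tests above
        int[][] b = bounds(new int[]{800, 100, 100});
        if (b[0][1] != 800 || b[1][0] != 800 || b[1][1] != 900 || b[2][0] != 900 || b[2][1] != 1000)
            throw new AssertionError("unexpected bounds");
        System.out.println("bounds ok");
    }
}
```

Note that ranges never overlap and jointly cover exactly `totalExamples` batches, which is what the tests rely on when they assert 800/900/1000 cumulative counts.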

    /**
     * This method returns the train iterator instance
     *
     * @return
     */
    @Deprecated
    public DataSetIterator getTrainIterator() {
        return new DataSetIterator() {
            @Override

@@ -184,6 +268,7 @@ public class DataSetIteratorSplitter {
     *
     * @return
     */
    @Deprecated
    public DataSetIterator getTestIterator() {
        return new DataSetIterator() {
            @Override

@@ -21,9 +21,12 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

@@ -43,6 +46,9 @@ public class MultiDataSetIteratorSplitter {
    protected final double ratio;
    protected final long numTrain;
    protected final long numTest;
    protected final double[] ratios;
    protected final long numArbitrarySets;
    protected final int[] splits;

    protected AtomicLong counter = new AtomicLong(0);

@@ -71,15 +77,87 @@ public class MultiDataSetIteratorSplitter {
        this.ratio = ratio;
        this.numTrain = (long) (totalExamples * ratio);
        this.numTest = totalExamples - numTrain;
        this.ratios = null;
        this.numArbitrarySets = 0;
        this.splits = null;

        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double[] ratios) {
        for (double ratio : ratios) {
            if (!(ratio > 0.0 && ratio < 1.0))
                throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 < X < 1.0");
        }

        if (totalBatches < 0)
            throw new ND4JIllegalStateException("totalExamples number should be a positive value");

        if (!baseIterator.resetSupported())
            throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");

        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = 0.0;
        this.numTrain = (long) (totalExamples * ratio);
        this.numTest = totalExamples - numTrain;
        this.ratios = ratios;
        this.numArbitrarySets = ratios.length;

        this.splits = new int[this.ratios.length];
        for (int i = 0; i < this.splits.length; ++i) {
            this.splits[i] = (int) (totalExamples * ratios[i]);
        }

        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, int[] splits) {

        int totalBatches = 0;
        for (val v : splits)
            totalBatches += v;

        if (totalBatches < 0)
            throw new ND4JIllegalStateException("totalExamples number should be a positive value");

        if (!baseIterator.resetSupported())
            throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");

        this.backedIterator = baseIterator;
        this.totalExamples = totalBatches;
        this.ratio = 0.0;
        this.numTrain = (long) (totalExamples * ratio);
        this.numTest = totalExamples - numTrain;
        this.ratios = null;
        this.numArbitrarySets = splits.length;
        this.splits = splits;

        log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
    }

    public List<MultiDataSetIterator> getIterators() {
        List<MultiDataSetIterator> retVal = new ArrayList<>();
        int partN = 0;
        int bottom = 0;
        for (final int split : splits) {
            ScrollableMultiDataSetIterator partIterator =
                    new ScrollableMultiDataSetIterator(partN++, backedIterator, counter, firstTrain,
                            new int[]{bottom, split});
            bottom += split;
            retVal.add(partIterator);
        }
        return retVal;
    }

    /**
     * This method returns the train iterator instance
     *
     * @return
     */
    @Deprecated
    public MultiDataSetIterator getTrainIterator() {
        return new MultiDataSetIterator() {
            @Override

@@ -162,6 +240,7 @@ public class MultiDataSetIteratorSplitter {
     *
     * @return
     */
    @Deprecated
    public MultiDataSetIterator getTestIterator() {
        return new MultiDataSetIterator() {
            @Override

@@ -0,0 +1,158 @@
package org.deeplearning4j.datasets.iterator;

import lombok.val;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

public class ScrollableDataSetIterator implements DataSetIterator {
    private int thisPart = 0;
    private int top = 0;
    private int bottom = 0;
    protected DataSetIterator backedIterator;
    protected AtomicLong counter = new AtomicLong(0);

    protected AtomicBoolean resetPending = new AtomicBoolean(false);
    protected DataSet firstTrain = null;
    protected MultiDataSet firstMultiTrain = null;
    private double ratio;
    private long totalExamples;
    private long itemsPerPart;
    private long current;


    public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
                                     AtomicBoolean resetPending, DataSet firstTrain, double ratio,
                                     int totalExamples) {
        this.thisPart = num;
        this.backedIterator = backedIterator;
        this.counter = counter;
        this.resetPending = resetPending;
        this.firstTrain = firstTrain;
        this.ratio = ratio;
        this.totalExamples = totalExamples;
        this.itemsPerPart = (long) (totalExamples * ratio);
        this.current = 0;
    }

    public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
                                     AtomicBoolean resetPending, DataSet firstTrain,
                                     int[] itemsPerPart) {
        this.thisPart = num;
        this.bottom = itemsPerPart[0];
        this.top = bottom + itemsPerPart[1];
        this.itemsPerPart = top;

        this.backedIterator = backedIterator;
        this.counter = counter;
        this.firstTrain = firstTrain;
        this.current = 0;
    }

    @Override
    public DataSet next(int i) {
        throw new UnsupportedOperationException();
    }

    @Override
    public List<String> getLabels() {
        return backedIterator.getLabels();
    }

    @Override
    public int inputColumns() {
        return backedIterator.inputColumns();
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }

    @Override
    public int totalOutcomes() {
        return backedIterator.totalOutcomes();
    }

    @Override
    public boolean resetSupported() {
        return backedIterator.resetSupported();
    }

    @Override
    public boolean asyncSupported() {
        return backedIterator.asyncSupported();
    }

    @Override
    public void reset() {
        resetPending.set(true);
    }

    @Override
    public int batch() {
        return backedIterator.batch();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        backedIterator.setPreProcessor(dataSetPreProcessor);
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return backedIterator.getPreProcessor();
    }


    @Override
    public boolean hasNext() {
        if (resetPending.get()) {
            if (resetSupported()) {
                backedIterator.reset();
                counter.set(0);
                current = 0;
                resetPending.set(false);
            } else
                throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
        }

        if (current >= top)
            return false;
        if (!backedIterator.hasNext())
            return false;
        return counter.get() < itemsPerPart;
    }

    @Override
    public DataSet next() {
        counter.incrementAndGet();
        if ((current == 0) && (bottom != 0)) {
            // first call for a non-first part: skip everything below this part's window
            backedIterator.reset();
            long cnt = current;
            for (; cnt < bottom; ++cnt) {
                if (backedIterator.hasNext())
                    backedIterator.next();
            }
            current = cnt + 1;
        } else {
            current++;
        }
        val p = backedIterator.next();
        return p;
    }
}
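The scrolling idea in this class — `next()` lazily fast-forwards the backing iterator past the `bottom` prefix, and iteration stops once `top` is reached — can be illustrated with a plain `java.util.Iterator` wrapper. This is a hypothetical sketch, not part of the library:

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Hypothetical sketch of the "scrolling window" idea: expose only the
// elements with indices in [bottom, top) of a backing iterator.
class WindowIterator<T> implements Iterator<T> {
    private final Iterator<T> backed;
    private final int bottom, top;
    private int consumed = 0; // how many backing elements have been passed over

    WindowIterator(Iterator<T> backed, int bottom, int top) {
        this.backed = backed;
        this.bottom = bottom;
        this.top = top;
    }

    @Override
    public boolean hasNext() {
        // lazily skip the prefix that belongs to earlier parts
        while (consumed < bottom && backed.hasNext()) {
            backed.next();
            consumed++;
        }
        return consumed < top && backed.hasNext();
    }

    @Override
    public T next() {
        consumed++;
        return backed.next();
    }

    public static void main(String[] args) {
        List<Integer> data = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        WindowIterator<Integer> part = new WindowIterator<>(data.iterator(), 4, 7);
        StringBuilder sb = new StringBuilder();
        while (part.hasNext())
            sb.append(part.next()).append(' ');
        System.out.println(sb.toString().trim()); // prints "4 5 6"
    }
}
```

Unlike this sketch, the real class shares one backing iterator between all parts (hence the shared `AtomicLong counter` and the `reset()` + skip loop in `next()`), but the window semantics are the same.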
|
|
@@ -0,0 +1,121 @@
package org.deeplearning4j.datasets.iterator;

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

public class ScrollableMultiDataSetIterator implements MultiDataSetIterator {
    private int thisPart = 0;
    private int top = 0;
    private int bottom = 0;
    protected MultiDataSetIterator backedIterator;
    protected AtomicLong counter = new AtomicLong(0);

    protected AtomicBoolean resetPending = new AtomicBoolean(false);
    protected DataSet firstTrain = null;
    protected MultiDataSet firstMultiTrain = null;
    private double ratio;
    private long totalExamples;
    private long itemsPerPart;
    private long current;

    public ScrollableMultiDataSetIterator(int num, MultiDataSetIterator backedIterator, AtomicLong counter,
                                          MultiDataSet firstTrain, int[] itemsPerPart) {
        this.thisPart = num;
        this.bottom = itemsPerPart[0];
        this.top = bottom + itemsPerPart[1];
        this.itemsPerPart = top;

        this.counter = counter;
        //this.resetPending = resetPending;
        this.firstTrain = null;
        this.firstMultiTrain = firstTrain;
        //this.totalExamples = totalExamples;
        this.current = 0;
        this.backedIterator = backedIterator;
    }

    @Override
    public boolean resetSupported() {
        return backedIterator.resetSupported();
    }

    @Override
    public boolean asyncSupported() {
        return backedIterator.asyncSupported();
    }

    @Override
    public void reset() {
        resetPending.set(true);
    }

    @Override
    public void setPreProcessor(MultiDataSetPreProcessor dataSetPreProcessor) {
        backedIterator.setPreProcessor(dataSetPreProcessor);
    }

    @Override
    public MultiDataSetPreProcessor getPreProcessor() {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean hasNext() {
        if (resetPending.get()) {
            if (resetSupported()) {
                backedIterator.reset();
                counter.set(0);
                current = 0;
                resetPending.set(false);
            } else
                throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
        }

        if (current >= top)
            return false;
        if (!backedIterator.hasNext())
            return false;
        return counter.get() < itemsPerPart;
    }

    @Override
    public MultiDataSet next() {
        counter.incrementAndGet();
        if ((current == 0) && (bottom != 0)) {
            backedIterator.reset();
            long cnt = current;
            for (; cnt < bottom; ++cnt) {
                if (backedIterator.hasNext())
                    backedIterator.next();
            }
            current = cnt + 1;
        } else
            current++;
        return backedIterator.next();
    }

    @Override
    public MultiDataSet next(int i) {
        throw new UnsupportedOperationException();
    }
}
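The scrollable iterators above expose a `[bottom, top)` window over a backing iterator: the first `next()` call fast-forwards past `bottom` items, and iteration stops once `top` is reached or the backing iterator is exhausted. A minimal standalone sketch of that windowing logic (the class and names here are illustrative, not the DL4J API):

```java
import java.util.Iterator;
import java.util.List;

// Illustrative sketch: a window [bottom, top) over a backing iterator,
// mirroring the skip-then-count behavior of the scrollable iterators above.
public class WindowedIterator<T> implements Iterator<T> {
    private final Iterator<T> backed;
    private final int top;
    private int position = 0;

    public WindowedIterator(Iterator<T> backed, int bottom, int top) {
        this.backed = backed;
        this.top = top;
        // Fast-forward past the items before the window, as next() does above.
        while (position < bottom && backed.hasNext()) {
            backed.next();
            position++;
        }
    }

    @Override
    public boolean hasNext() {
        // Stop at the window's upper bound or when the backing iterator runs dry.
        return position < top && backed.hasNext();
    }

    @Override
    public T next() {
        position++;
        return backed.next();
    }

    public static void main(String[] args) {
        Iterator<Integer> it =
                new WindowedIterator<>(List.of(0, 1, 2, 3, 4, 5).iterator(), 2, 4);
        while (it.hasNext()) {
            System.out.println(it.next()); // prints 2 then 3
        }
    }
}
```

Several such windows over the same dataset give non-overlapping partitions, which is how the scrollable iterators split one backing iterator into parts.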
@@ -47,6 +47,8 @@ import static org.bytedeco.hdf5.global.hdf5.*;
 @Slf4j
 public class Hdf5Archive implements Closeable {
 
+    public static final int MAX_BUFFER_SIZE_BYTES = (int) Math.pow(2, 28); //256 MB
+
     /**
      * HDF5 library is not thread safe - possible to crash if multiple reads etc are performed concurrently
      * in multiple threads. This object is used for locking read etc activity using synchronized blocks
@@ -338,7 +340,7 @@ public class Hdf5Archive implements Closeable {
     private String readAttributeAsJson(Attribute attribute) throws UnsupportedKerasConfigurationException {
         synchronized (Hdf5Archive.LOCK_OBJECT) {
             VarLenType vl = attribute.getVarLenType();
-            int bufferSizeMult = 1;
+            int currBufferLength = 2048;
             String s;
             /* TODO: find a less hacky way to do this.
              * Reading variable length strings (from attributes) is a giant
@@ -349,8 +351,8 @@ public class Hdf5Archive implements Closeable {
              * buffer and repeat.
              */
             while (true) {
-                byte[] attrBuffer = new byte[bufferSizeMult * 2000];
-                BytePointer attrPointer = new BytePointer(attrBuffer);
+                byte[] attrBuffer = new byte[currBufferLength];
+                BytePointer attrPointer = new BytePointer(currBufferLength);
                 attribute.read(vl, attrPointer);
                 attrPointer.get(attrBuffer);
                 s = new String(attrBuffer);
@@ -362,9 +364,11 @@ public class Hdf5Archive implements Closeable {
                 } catch (IOException e) {
                     //OK - we don't know how long the buffer needs to be, so we'll try again with larger buffer
                 }
-                bufferSizeMult *= 2;
-                if (bufferSizeMult > 1024) {
-                    throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute");
+                if (currBufferLength == MAX_BUFFER_SIZE_BYTES) {
+                    throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute: size exceeds " + currBufferLength + " bytes");
+                } else {
+                    currBufferLength = (int) Math.min(MAX_BUFFER_SIZE_BYTES, currBufferLength * 4L);
                 }
             }
             vl.deallocate();
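The rewritten retry loop in `readAttributeAsJson` grows the read buffer geometrically (x4 per failed attempt) until the attribute fits or a hard 256 MB cap is hit, instead of the old linear x2 multiplier with an arbitrary iteration limit. The pattern in isolation (a sketch; the `Reader` callback below stands in for `attribute.read`, which is the part that actually fails when the buffer is too small):

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;

// Illustrative sketch of the grow-and-retry read pattern used above:
// start with a small buffer, quadruple on failure, bail out at a hard cap.
public class GrowingBufferReader {
    static final int MAX_BUFFER_SIZE_BYTES = 1 << 28; // 256 MB, as in Hdf5Archive

    interface Reader {
        // Fills buf; throws IOException if the payload does not fit.
        void read(byte[] buf) throws IOException;
    }

    static String readWithRetry(Reader reader) throws IOException {
        int currBufferLength = 2048;
        while (true) {
            byte[] buf = new byte[currBufferLength];
            try {
                reader.read(buf);
                // trim() strips the trailing NUL padding (chars <= U+0020).
                return new String(buf, StandardCharsets.UTF_8).trim();
            } catch (IOException e) {
                if (currBufferLength == MAX_BUFFER_SIZE_BYTES) {
                    throw new IOException("Attribute exceeds " + currBufferLength + " bytes", e);
                }
                // Quadruple, clamped to the cap; the long arithmetic avoids int overflow.
                currBufferLength = (int) Math.min(MAX_BUFFER_SIZE_BYTES, currBufferLength * 4L);
            }
        }
    }

    public static void main(String[] args) throws IOException {
        byte[] payload = new byte[5000]; // larger than the initial 2048-byte buffer
        java.util.Arrays.fill(payload, (byte) 'x');
        String s = readWithRetry(buf -> {
            if (buf.length < payload.length) throw new IOException("buffer too small");
            System.arraycopy(payload, 0, buf, 0, payload.length);
        });
        System.out.println(s.length()); // prints 5000
    }
}
```

Clamping with `Math.min(cap, n * 4L)` rather than checking a multiplier count is what makes the failure message able to report a concrete byte size.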
@@ -21,6 +21,7 @@ import lombok.NoArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
 import org.deeplearning4j.clustering.cluster.Cluster;
 import org.deeplearning4j.clustering.cluster.ClusterSet;
 import org.deeplearning4j.clustering.cluster.ClusterUtils;
@@ -62,12 +63,13 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
     private ClusterSet clusterSet;
     private List<Point> initialPoints;
     private transient ExecutorService exec;
+    private boolean useKmeansPlusPlus;
 
 
-    protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy) {
+    protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
         this.clusteringStrategy = clusteringStrategy;
         this.exec = MultiThreadUtils.newExecutorService();
+        this.useKmeansPlusPlus = useKmeansPlusPlus;
     }
 
     /**
@@ -75,8 +77,8 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
      * @param clusteringStrategy
      * @return
      */
-    public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy) {
-        return new BaseClusteringAlgorithm(clusteringStrategy);
+    public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
+        return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus);
     }
 
     /**
@@ -86,7 +88,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
      */
     public ClusterSet applyTo(List<Point> points) {
         resetState(points);
-        initClusters();
+        initClusters(useKmeansPlusPlus);
         iterations();
         return clusterSet;
     }
@@ -130,7 +132,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
      * Initialize the
      * cluster centers at random
      */
-    protected void initClusters() {
+    protected void initClusters(boolean kMeansPlusPlus) {
         log.info("Generating initial clusters");
         List<Point> points = new ArrayList<>(initialPoints);
 
@@ -152,7 +154,10 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
         //Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
         while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
             dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
-            double r = random.nextFloat() * dxs.maxNumber().doubleValue();
+            double summed = Nd4j.sum(dxs).getDouble(0);
+            double r = kMeansPlusPlus ? random.nextDouble() * summed
+                    : random.nextFloat() * dxs.maxNumber().doubleValue();
 
             for (int i = 0; i < dxs.length(); i++) {
                 double distance = dxs.getDouble(i);
                 Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
@@ -170,6 +175,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
                 new IterationInfo(currentIteration, initialClusterSetInfo));
     }
 
 
     protected void applyClusteringStrategy() {
         if (!isStrategyApplicableNow())
             return;
@@ -79,8 +79,8 @@ public class ClusterUtils {
         int nClusters = clusterSet.getClusterCount();
         for (int i = 0; i < nClusters; i++) {
             final Cluster cluster = clusterSet.getClusters().get(i);
-            tasks.add(new Runnable() {
-                public void run() {
+            //tasks.add(new Runnable() {
+            //    public void run() {
                     try {
                         final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                         refreshClusterCenter(cluster, clusterInfo);
@@ -88,10 +88,10 @@ public class ClusterUtils {
                     } catch (Throwable t) {
                         log.warn("Error refreshing cluster centers", t);
                     }
-                }
-            });
+            //    }
+            //});
         }
-        MultiThreadUtils.parallelTasks(tasks, executorService);
+        //MultiThreadUtils.parallelTasks(tasks, executorService);
     }
 
     public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
@@ -146,28 +146,29 @@ public class ClusterUtils {
         List<Runnable> tasks = new ArrayList<>();
         for (int i = 0; i < pointsCount; i++) {
             final int i2 = i;
-            tasks.add(new Runnable() {
-                public void run() {
+            //tasks.add(new Runnable() {
+            //    public void run() {
                     try {
                         Point point = points.get(i2);
                         double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
                                 : Math.pow(newCluster.getDistanceToCenter(point), 2);
-                        dxs.putScalar(i2, clusterSet.isInverse() ? dist : dist);
+                        dxs.putScalar(i2, /*clusterSet.isInverse() ? dist :*/ dist);
                     } catch (Throwable t) {
                         log.warn("Error computing squared distance from nearest cluster", t);
                     }
-                }
-            });
+            //    }
+            //});
         }
 
-        MultiThreadUtils.parallelTasks(tasks, executorService);
+        //MultiThreadUtils.parallelTasks(tasks, executorService);
 
         for (int i = 0; i < pointsCount; i++) {
             double previousMinDistance = previousDxs.getDouble(i);
             if (clusterSet.isInverse()) {
-                if (dxs.getDouble(i) < previousMinDistance)
+                if (dxs.getDouble(i) < previousMinDistance) {
                     dxs.putScalar(i, previousMinDistance);
+                }
             } else if (dxs.getDouble(i) > previousMinDistance)
                 dxs.putScalar(i, previousMinDistance);
         }
@@ -175,6 +176,23 @@ public class ClusterUtils {
         return dxs;
     }
 
+    public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet,
+                                                                           final List<Point> points, INDArray previousDxs) {
+        final int pointsCount = points.size();
+        final INDArray dxs = Nd4j.create(pointsCount);
+        final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
+
+        double sum = 0;
+        for (int i = 0; i < pointsCount; i++) {
+            Point point = points.get(i);
+            double dist = Math.pow(newCluster.getDistanceToCenter(point), 2);
+            sum += dist;
+            dxs.putScalar(i, sum);
+        }
+
+        return dxs;
+    }
+
     /**
      *
      * @param clusterSet
@@ -194,27 +212,27 @@ public class ClusterUtils {
         List<Runnable> tasks = new ArrayList<>();
         for (int i = 0; i < clusterCount; i++) {
             final Cluster cluster = clusterSet.getClusters().get(i);
-            tasks.add(new Runnable() {
-                public void run() {
+            //tasks.add(new Runnable() {
+            //    public void run() {
                     try {
                         info.getClustersInfos().put(cluster.getId(),
                                 computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
                     } catch (Throwable t) {
                         log.warn("Error computing cluster set info", t);
                     }
-                }
-            });
+            //    }
+            //});
         }
 
-        MultiThreadUtils.parallelTasks(tasks, executorService);
+        //MultiThreadUtils.parallelTasks(tasks, executorService);
 
-        tasks = new ArrayList<>();
+        //tasks = new ArrayList<>();
         for (int i = 0; i < clusterCount; i++) {
             final int clusterIdx = i;
             final Cluster fromCluster = clusterSet.getClusters().get(i);
-            tasks.add(new Runnable() {
-                public void run() {
+            //tasks.add(new Runnable() {
+            //    public void run() {
                     try {
                         for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
                             Cluster toCluster = clusterSet.getClusters().get(k);
@@ -230,12 +248,12 @@ public class ClusterUtils {
                     } catch (Throwable t) {
                         log.warn("Error computing distances", t);
                     }
-                }
-            });
+            //    }
+            //});
         }
 
-        MultiThreadUtils.parallelTasks(tasks, executorService);
+        //MultiThreadUtils.parallelTasks(tasks, executorService);
 
         return info;
     }
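Together, `computeWeightedProbaDistancesFromNearestCluster` (which builds a running cumulative sum of squared distances) and the `random.nextDouble() * summed` draw in `initClusters` implement the k-means++ seeding rule: the next center is sampled with probability proportional to its squared distance from the nearest existing center. A self-contained sketch of that draw (names are illustrative, not the DL4J API):

```java
import java.util.Random;

// Illustrative sketch of the k-means++ seeding draw added above:
// sample the next center index with probability proportional to the
// squared distance to the nearest existing center.
public class KMeansPlusPlusDraw {
    static int sampleIndex(double[] squaredDistances, Random random) {
        // Cumulative sums, mirroring computeWeightedProbaDistancesFromNearestCluster.
        double[] cumulative = new double[squaredDistances.length];
        double sum = 0;
        for (int i = 0; i < squaredDistances.length; i++) {
            sum += squaredDistances[i];
            cumulative[i] = sum;
        }
        // r = random.nextDouble() * summed; take the first bucket covering r.
        double r = random.nextDouble() * sum;
        for (int i = 0; i < cumulative.length; i++) {
            if (cumulative[i] >= r) {
                return i;
            }
        }
        return cumulative.length - 1;
    }

    public static void main(String[] args) {
        // The far point (index 2) holds 90% of the total squared distance,
        // so it should dominate the draws.
        double[] d2 = {0.0, 1.0, 9.0};
        int[] counts = new int[3];
        Random rng = new Random(7);
        for (int i = 0; i < 10000; i++) {
            counts[sampleIndex(d2, rng)]++;
        }
        System.out.println(counts[2] > counts[0] + counts[1]);
    }
}
```

Sampling against the cumulative array rather than `maxNumber()` is what distinguishes the new k-means++ branch from the old uniform-ish draw: points far from all current centers get proportionally more probability mass, which is the standard k-means++ guarantee.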
@@ -37,8 +37,8 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
      *
      * @param clusteringStrategy
      */
-    protected KMeansClustering(ClusteringStrategy clusteringStrategy) {
-        super(clusteringStrategy);
+    protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) {
+        super(clusteringStrategy, useKMeansPlusPlus);
     }
 
     /**
@@ -50,11 +50,11 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
      * @return
      */
     public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction,
-                                         boolean inverse) {
+                                         boolean inverse, boolean useKMeansPlusPlus) {
         ClusteringStrategy clusteringStrategy =
                 FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse);
         clusteringStrategy.endWhenIterationCountEquals(maxIterationCount);
-        return new KMeansClustering(clusteringStrategy);
+        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
     }
 
     /**
@@ -66,10 +66,10 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
      * @return
      */
     public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
-                                         boolean inverse, boolean allowEmptyClusters) {
+                                         boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
         ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse)
                 .endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
-        return new KMeansClustering(clusteringStrategy);
+        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
     }
 
 
@@ -81,8 +81,8 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
      * @param distanceFunction the distance function to use for grouping
      * @return
      */
-    public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction) {
-        return setup(clusterCount, maxIterationCount, distanceFunction, false);
+    public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) {
+        return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus);
     }
 
     /**
@@ -94,17 +94,17 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
      * @return
      */
     public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
-                                         boolean allowEmptyClusters) {
+                                         boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
         ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
         clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
-        return new KMeansClustering(clusteringStrategy);
+        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
     }
 
     public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
-                                         boolean allowEmptyClusters) {
+                                         boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
         ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
         clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
-        return new KMeansClustering(clusteringStrategy);
+        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
     }
 
 }
@@ -16,6 +16,7 @@
 
 package org.deeplearning4j.clustering.kmeans;
 
+import lombok.val;
 import org.apache.commons.lang3.time.StopWatch;
 import org.deeplearning4j.clustering.BaseDL4JTest;
 import org.deeplearning4j.clustering.algorithm.Distance;
@@ -28,22 +29,25 @@ import org.nd4j.linalg.factory.Nd4j;
 
 import java.util.List;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
+import static org.junit.Assert.*;
 
 /**
  * Created by agibsonccc on 7/2/17.
  */
 public class KMeansTest extends BaseDL4JTest {
 
+    private boolean[] useKMeansPlusPlus = {true, false};
+
     @Test
     public void testKMeans() {
         Nd4j.getRandom().setSeed(7);
-        KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN);
-        List<Point> points = Point.toPoints(Nd4j.randn(5, 5));
-        ClusterSet clusterSet = kMeansClustering.applyTo(points);
-        PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
-        System.out.println(pointClassification);
+        for (boolean mode : useKMeansPlusPlus) {
+            KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN, mode);
+            List<Point> points = Point.toPoints(Nd4j.randn(5, 5));
+            ClusterSet clusterSet = kMeansClustering.applyTo(points);
+            PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
+            System.out.println(pointClassification);
+        }
     }
 
     @Test
@@ -51,20 +55,22 @@ public class KMeansTest extends BaseDL4JTest {
 
         Nd4j.getRandom().setSeed(7);
         int numClusters = 5;
-        KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, true);
-        List<Point> points = Point.toPoints(Nd4j.rand(5, 300));
-        ClusterSet clusterSet = kMeansClustering.applyTo(points);
-        PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
-
-        KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN);
-        ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points);
-        PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0));
-        System.out.println("Cosine " + pointClassification);
-        System.out.println("Euclidean " + pointClassificationEuclidean);
-
-        assertEquals(pointClassification.getCluster().getPoints().get(0),
-                pointClassificationEuclidean.getCluster().getPoints().get(0));
+        for (boolean mode : useKMeansPlusPlus) {
+            KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode);
+            List<Point> points = Point.toPoints(Nd4j.rand(5, 300));
+            ClusterSet clusterSet = kMeansClustering.applyTo(points);
+            PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
+
+            KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode);
+            ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points);
+            PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0));
+            System.out.println("Cosine " + pointClassification);
+            System.out.println("Euclidean " + pointClassificationEuclidean);
+
+            assertEquals(pointClassification.getCluster().getPoints().get(0),
+                    pointClassificationEuclidean.getCluster().getPoints().get(0));
+        }
     }
 
     @Ignore
@@ -73,22 +79,24 @@ public class KMeansTest extends BaseDL4JTest {
         Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
         Nd4j.getRandom().setSeed(7);
         int numClusters = 20;
-        StopWatch watch = new StopWatch();
-        watch.start();
-        KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, true);
-        List<Point> points = Point.toPoints(Nd4j.linspace(0, 5000 * 300, 5000 * 300).reshape(5000, 300));
-
-        ClusterSet clusterSet = kMeansClustering.applyTo(points);
-        watch.stop();
-        System.out.println("Elapsed for clustering : " + watch);
-
-        watch.reset();
-        watch.start();
-        for (Point p : points) {
-            PointClassification pointClassification = clusterSet.classifyPoint(p);
-        }
-        watch.stop();
-        System.out.println("Elapsed for search: " + watch);
+        for (boolean mode : useKMeansPlusPlus) {
+            StopWatch watch = new StopWatch();
+            watch.start();
+            KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode);
+            List<Point> points = Point.toPoints(Nd4j.linspace(0, 5000 * 300, 5000 * 300).reshape(5000, 300));
+
+            ClusterSet clusterSet = kMeansClustering.applyTo(points);
+            watch.stop();
+            System.out.println("Elapsed for clustering : " + watch);
+
+            watch.reset();
+            watch.start();
+            for (Point p : points) {
+                PointClassification pointClassification = clusterSet.classifyPoint(p);
+            }
+            watch.stop();
+            System.out.println("Elapsed for search: " + watch);
+        }
     }
 
     @Test
@@ -97,41 +105,43 @@ public class KMeansTest extends BaseDL4JTest {
         Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
         Nd4j.getRandom().setSeed(7);
         int numClusters = 20;
-        StopWatch watch = new StopWatch();
-        watch.start();
-        KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false);
-
-        List<Point> points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
-
-        ClusterSet clusterSet = kMeansClustering.applyTo(points);
-        watch.stop();
-        System.out.println("Elapsed for clustering : " + watch);
-
-        watch.reset();
-        watch.start();
-        for (Point p : points) {
-            PointClassification pointClassification = clusterSet.classifyPoint(p);
-        }
-        watch.stop();
-        System.out.println("Elapsed for search: " + watch);
-
-        watch.reset();
-        watch.start();
-        kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false);
-
-        points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
-
-        clusterSet = kMeansClustering.applyTo(points);
-        watch.stop();
-        System.out.println("Elapsed for clustering : " + watch);
-
-        watch.reset();
-        watch.start();
-        for (Point p : points) {
-            PointClassification pointClassification = clusterSet.classifyPoint(p);
-        }
-        watch.stop();
-        System.out.println("Elapsed for search: " + watch);
+        for (boolean mode : useKMeansPlusPlus) {
+            StopWatch watch = new StopWatch();
+            watch.start();
+            KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false, mode);
+
+            List<Point> points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
+
+            ClusterSet clusterSet = kMeansClustering.applyTo(points);
+            watch.stop();
+            System.out.println("Elapsed for clustering : " + watch);
+
+            watch.reset();
+            watch.start();
+            for (Point p : points) {
+                PointClassification pointClassification = clusterSet.classifyPoint(p);
+            }
+            watch.stop();
+            System.out.println("Elapsed for search: " + watch);
+
+            watch.reset();
+            watch.start();
+            kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false, mode);
+
+            points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
+
+            clusterSet = kMeansClustering.applyTo(points);
+            watch.stop();
+            System.out.println("Elapsed for clustering : " + watch);
+
+            watch.reset();
+            watch.start();
+            for (Point p : points) {
+                PointClassification pointClassification = clusterSet.classifyPoint(p);
+            }
+            watch.stop();
+            System.out.println("Elapsed for search: " + watch);
+        }
     }
 
     @Test
@@ -141,45 +151,47 @@ public class KMeansTest extends BaseDL4JTest {
         Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
         Nd4j.getRandom().setSeed(7);
         int numClusters = 3;
-        KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, true);
-        double[] data = new double[]{
-                15, 16,
-                16, 18.5,
-                17, 20.2,
-                16.4, 17.12,
-                17.23, 18.12,
-                43, 43,
-                44.43, 45.212,
-                45.8, 54.23,
-                46.313, 43.123,
-                50.21, 46.3,
-                99, 99.22,
-                100.32, 98.123,
-                100.32, 97.423,
-                102, 93.23,
-                102.23, 94.23
-        };
-        List<Point> points = Point.toPoints(Nd4j.createFromArray(data).reshape(15, 2));
-
-        ClusterSet clusterSet = kMeansClustering.applyTo(points);
-
-        INDArray row0 = Nd4j.createFromArray(new double[]{16.6575, 18.4850});
-        INDArray row1 = Nd4j.createFromArray(new double[]{32.6050, 31.1500});
-        INDArray row2 = Nd4j.createFromArray(new double[]{75.9348, 74.1990});
-
-        /*List<Cluster> clusters = clusterSet.getClusters();
-        assertEquals(row0, clusters.get(0).getCenter().getArray());
-        assertEquals(row1, clusters.get(1).getCenter().getArray());
-        assertEquals(row2, clusters.get(2).getCenter().getArray());*/
-
-        PointClassification pointClassification = null;
-        for (Point p : points) {
-            pointClassification = clusterSet.classifyPoint(p);
-            System.out.println("Point: " + p.getArray() + " " + " assigned to cluster: " + pointClassification.getCluster().getCenter().getArray());
-            List<Cluster> clusters = clusterSet.getClusters();
-            for (int i = 0; i < clusters.size(); ++i)
-                System.out.println("Choice: " + clusters.get(i).getCenter().getArray());
-        }
+        for (boolean mode : useKMeansPlusPlus) {
+            KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode);
+            double[] data = new double[]{
+                    15, 16,
+                    16, 18.5,
+                    17, 20.2,
+                    16.4, 17.12,
+                    17.23, 18.12,
+                    43, 43,
+                    44.43, 45.212,
+                    45.8, 54.23,
+                    46.313, 43.123,
+                    50.21, 46.3,
+                    99, 99.22,
+                    100.32, 98.123,
+                    100.32, 97.423,
+                    102, 93.23,
+                    102.23, 94.23
+            };
+            List<Point> points = Point.toPoints(Nd4j.createFromArray(data).reshape(15, 2));
+
+            ClusterSet clusterSet = kMeansClustering.applyTo(points);
+
+            INDArray row0 = Nd4j.createFromArray(new double[]{16.6575, 18.4850});
+            INDArray row1 = Nd4j.createFromArray(new double[]{32.6050, 31.1500});
+            INDArray row2 = Nd4j.createFromArray(new double[]{75.9348, 74.1990});
+
+            /*List<Cluster> clusters = clusterSet.getClusters();
+            assertEquals(row0, clusters.get(0).getCenter().getArray());
+            assertEquals(row1, clusters.get(1).getCenter().getArray());
+            assertEquals(row2, clusters.get(2).getCenter().getArray());*/
+
+            PointClassification pointClassification = null;
+            for (Point p : points) {
+                pointClassification = clusterSet.classifyPoint(p);
+                System.out.println("Point: " + p.getArray() + " " + " assigned to cluster: " + pointClassification.getCluster().getCenter().getArray());
+                List<Cluster> clusters = clusterSet.getClusters();
+                for (int i = 0; i < clusters.size(); ++i)
+                    System.out.println("Choice: " + clusters.get(i).getCenter().getArray());
+            }
+        }
         /*assertEquals(Nd4j.createFromArray(new double[]{75.9348, 74.1990}),
                 pointClassification.getCluster().getCenter().getArray());*/
@@ -233,4 +245,39 @@ public class KMeansTest extends BaseDL4JTest {
         System.out.println();
     }
+
+    @Test
+    public void testInitClusters() {
+        Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
+        Nd4j.getRandom().setSeed(7);
+        {
+            KMeansClustering kMeansClustering = KMeansClustering.setup(5, 1, Distance.EUCLIDEAN, true);
+
+            double[][] dataArray = {{1000000.0, 2.8E7, 5.5E7, 8.2E7}, {2.8E7, 5.5E7, 8.2E7, 1.09E8}, {5.5E7, 8.2E7, 1.09E8, 1.36E8},
+                    {8.2E7, 1.09E8, 1.36E8, 1.63E8}, {1.09E8, 1.36E8, 1.63E8, 1.9E8}, {1.36E8, 1.63E8, 1.9E8, 2.17E8},
+                    {1.63E8, 1.9E8, 2.17E8, 2.44E8}, {1.9E8, 2.17E8, 2.44E8, 2.71E8}, {2.17E8, 2.44E8, 2.71E8, 2.98E8},
+                    {2.44E8, 2.71E8, 2.98E8, 3.25E8}, {2.71E8, 2.98E8, 3.25E8, 3.52E8}, {2.98E8, 3.25E8, 3.52E8, 3.79E8},
+                    {3.25E8, 3.52E8, 3.79E8, 4.06E8}, {3.52E8, 3.79E8, 4.06E8, 4.33E8}, {3.79E8, 4.06E8, 4.33E8, 4.6E8},
+                    {4.06E8, 4.33E8, 4.6E8, 4.87E8}, {4.33E8, 4.6E8, 4.87E8, 5.14E8}, {4.6E8, 4.87E8, 5.14E8, 5.41E8},
+                    {4.87E8, 5.14E8, 5.41E8, 5.68E8}, {5.14E8, 5.41E8, 5.68E8, 5.95E8}, {5.41E8, 5.68E8, 5.95E8, 6.22E8},
+                    {5.68E8, 5.95E8, 6.22E8, 6.49E8}, {5.95E8, 6.22E8, 6.49E8, 6.76E8}, {6.22E8, 6.49E8, 6.76E8, 7.03E8},
+                    {6.49E8, 6.76E8, 7.03E8, 7.3E8}, {6.76E8, 7.03E8, 7.3E8, 7.57E8}, {7.03E8, 7.3E8, 7.57E8, 7.84E8}};
+            INDArray data = Nd4j.createFromArray(dataArray);
+            List<Point> points = Point.toPoints(data);
+
+            ClusterSet clusterSet = kMeansClustering.applyTo(points);
+
+            double[] centroid1 = {2.44e8, 2.71e8, 2.98e8, 3.25e8};
+            double[] centroid2 = {5.14e8, 5.41e8, 5.68e8, 5.95e8};
+            double[] centroid3 = {1.63e8, 1.9e8, 2.17e8, 2.44e8};
+            double[] centroid4 = {6.76e8, 7.03e8, 7.3e8, 7.57e8};
+            double[] centroid5 = {4.06e8, 4.33e8, 4.6e8, 4.87e8};
+
+            assertArrayEquals(centroid1, clusterSet.getClusters().get(0).getCenter().getArray().toDoubleVector(), 1e-4);
+            assertArrayEquals(centroid2, clusterSet.getClusters().get(1).getCenter().getArray().toDoubleVector(), 1e-4);
+            assertArrayEquals(centroid3, clusterSet.getClusters().get(2).getCenter().getArray().toDoubleVector(), 1e-4);
+            assertArrayEquals(centroid4, clusterSet.getClusters().get(3).getCenter().getArray().toDoubleVector(), 1e-4);
+            assertArrayEquals(centroid5, clusterSet.getClusters().get(4).getCenter().getArray().toDoubleVector(), 1e-4);
+        }
+    }
 }
@@ -23,6 +23,8 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
import org.junit.Rule;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
@@ -857,4 +859,34 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
}
}

@Test
public void testBackwardsCompatibleWord2Vec() {
File model_v3 = Resources.asFile("deeplearning4j-nlp/model_beta3.zip");
File model_v4 = Resources.asFile("deeplearning4j-nlp/model_beta4.zip");
Word2Vec word2Vec1 = WordVectorSerializer.readWord2VecModel(model_v3, true);
Word2Vec word2Vec2 = WordVectorSerializer.readWord2VecModel(model_v4, true);
try {
assertEquals(word2Vec1.toJson(), word2Vec2.toJson());
} catch (Exception e) {
fail(e.getMessage());
}
}

@Test
public void testBackwardsCompatibleSequenceVectors() {
File model_v3 = Resources.asFile("deeplearning4j-nlp/seqv_beta3.csv");
File model_v4 = Resources.asFile("deeplearning4j-nlp/seqv_beta4.csv");
try {
SequenceVectors vectors1 = WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), model_v3);
SequenceVectors vectors2 = WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), model_v4);

assertEquals(vectors1.vocab().numWords(), vectors2.vocab().numWords());
for (int i = 0; i < vectors1.vocab().numWords(); ++i) {
assertEquals(vectors1.vocab().words().toArray()[i], vectors2.vocab().words().toArray()[i]);
}
} catch (Exception e) {
fail(e.getMessage());
}
}

}
@@ -249,7 +249,7 @@ public class BertIterator implements MultiDataSetIterator {
} else {
throw new RuntimeException();
}
l[0] = Nd4j.create(Nd4j.defaultFloatingPointType(), mbPadded, numClasses);
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
for( int i=0; i<mb; i++ ){
l[0].putScalar(i, classLabels[i], 1.0);
}

@@ -277,9 +277,9 @@ public class BertIterator implements MultiDataSetIterator {
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
labelArr = Nd4j.create(Nd4j.defaultFloatingPointType(), mbPadded, vocabSize, outLength);
labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
labelArr = Nd4j.create(Nd4j.defaultFloatingPointType(), outLength, mbPadded, vocabSize);
labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
} else {
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
}
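The loop in the first hunk writes one-hot class labels into a `[minibatch, numClasses]` float array via `putScalar(i, classLabels[i], 1.0)`. A dependency-free sketch of the same one-hot encoding, with a plain 2D array standing in for the INDArray (names are illustrative):

```java
public class OneHotLabels {
    // Build a [minibatch, numClasses] one-hot matrix from integer class labels,
    // mirroring l[0].putScalar(i, classLabels[i], 1.0) in the iterator
    static float[][] oneHot(int[] classLabels, int numClasses) {
        float[][] out = new float[classLabels.length][numClasses];
        for (int i = 0; i < classLabels.length; i++) {
            out[i][classLabels[i]] = 1.0f; // one 1.0 per row, at the label index
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] l = oneHot(new int[]{2, 0}, 3);
        System.out.println(java.util.Arrays.deepToString(l)); // [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
    }
}
```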
@@ -201,7 +201,7 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
List<String> tokens = new ArrayList<>();
while (t.hasMoreTokens()) {
String token = t.nextToken();
if (!wordVectors.hasWord(token)) {
if (!wordVectors.outOfVocabularySupported() && !wordVectors.hasWord(token)) {
switch (unknownWordHandling) {
case RemoveWord:
continue;
@@ -1312,10 +1312,12 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
int rest = batchSequences.size() % batchSize;
int chunks = ((batchSequences.size() >= batchSize) ? batchSequences.size() / batchSize : 0) + ((rest > 0)? 1 : 0);
for (int j = 0; j < chunks; ++j) {
if (elementsLearningAlgorithm instanceof SkipGram)
((SkipGram)elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
else if (elementsLearningAlgorithm instanceof CBOW)
((CBOW)elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
if (trainElementsVectors) {
if (elementsLearningAlgorithm instanceof SkipGram)
((SkipGram) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
else if (elementsLearningAlgorithm instanceof CBOW)
((CBOW) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
}

if (trainSequenceVectors) {
if (sequenceLearningAlgorithm instanceof DBOW)
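The `chunks` expression above is a ceiling division: it splits `batchSequences.size()` items into batches of `batchSize`, adding one extra chunk when there is a remainder. A small standalone check of that identity (plain Java; names are illustrative):

```java
public class ChunkCount {
    // Same structure as the SequenceVectors expression:
    // full batches, plus one extra chunk if a remainder exists
    static int chunks(int size, int batchSize) {
        int rest = size % batchSize;
        return ((size >= batchSize) ? size / batchSize : 0) + ((rest > 0) ? 1 : 0);
    }

    public static void main(String[] args) {
        System.out.println(chunks(10, 4)); // 2 full batches + 1 remainder chunk = 3
        System.out.println(chunks(8, 4));  // exact fit = 2
        System.out.println(chunks(3, 4));  // smaller than one batch = 1
    }
}
```

For positive `size` this is equivalent to `(size + batchSize - 1) / batchSize`, the usual integer ceiling-division idiom.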
@@ -32,7 +32,7 @@ import java.io.Serializable;
 *
 * @author Adam Gibson
 */
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", defaultImpl = VocabWord.class)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
setterVisibility = JsonAutoDetect.Visibility.NONE)
public class VocabWord extends SequenceElement implements Serializable {
@@ -224,6 +224,7 @@ public class TestBertIterator extends BaseDL4JTest {

@Test(timeout = 20000L)
public void testMinibatchPadding() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.api;

import org.deeplearning4j.nn.conf.GradientNormalization;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;

@@ -73,4 +74,6 @@ public interface TrainingConfig {
 */
double getGradientNormalizationThreshold();

void setDataType(DataType dataType);

}
@@ -93,4 +93,9 @@ public abstract class GraphVertex implements Cloneable, Serializable {
 */
public abstract MemoryReport getMemoryReport(InputType... inputTypes);

public void setDataType(DataType dataType) {
//No-op for most layers
}

}
@@ -146,4 +146,9 @@ public class LayerVertex extends GraphVertex {
//TODO preprocessor memory
return layerConf.getLayer().getMemoryReport(it);
}

@Override
public void setDataType(DataType dataType){
layerConf.getLayer().setDataType(dataType);
}
}
@@ -223,6 +223,11 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable {
"Not supported: all layers with parameters should override this method");
}

@Override
public void setDataType(DataType dataType) {
//No-op for most layers
}

/**
 * This is a report of the estimated memory consumption for the given layer
 *
@@ -96,7 +96,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {

if (!map.containsKey(inputNum)) {
//Lazily define extra input variable as required
SDVariable var = sameDiff.var("var_" + inputNum, 1); //TODO is this shape safe?
SDVariable var = sameDiff.var("var_" + inputNum, dataType, -1); //TODO is this shape safe?
map.put(inputNum, var);
}
@@ -62,6 +62,7 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
protected IUpdater biasUpdater;
protected GradientNormalization gradientNormalization;
protected double gradientNormalizationThreshold = Double.NaN;
protected DataType dataType;

/**
 * Define the vertex

@@ -234,4 +235,9 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
public double getGradientNormalizationThreshold() {
return gradientNormalizationThreshold;
}

@Override
public void setDataType(DataType dataType) {
this.dataType = dataType;
}
}
@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.misc;
import lombok.AllArgsConstructor;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.regularization.Regularization;

@@ -63,4 +64,9 @@ public class DummyConfig implements TrainingConfig {
public double getGradientNormalizationThreshold() {
return 1.0;
}

@Override
public void setDataType(DataType dataType) {

}
}
@@ -512,6 +512,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
for(; i<topologicalOrder.length; i++ ){
String name = indices.getIdxToName().get(i);
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
n.setDataType(netDtype);
numParamsForVertex[i] = n.numParams(true);
numParams += numParamsForVertex[i];
}
@@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;

@@ -35,6 +36,7 @@ import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;

import java.util.Arrays;
import java.util.List;

/**

@@ -60,10 +62,16 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
assertInputSet(true);
if (input.rank() != 3)
throw new UnsupportedOperationException(
"Input is not rank 3. Got input with rank " + input.rank() + " " + layerId());
"Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " +
input.rank() + " with shape " + Arrays.toString(input.shape()) + " for layer " + layerId());
if (labels == null)
throw new IllegalStateException("Labels are not set (null)");

Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);

INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.BP_WORKING_MEM);
INDArray maskReshaped;
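The Preconditions checks added above validate that both input and labels are rank 3 and agree on the sequence-length dimension (dimension 2) before reshaping to 2D. A dependency-free sketch of that validation logic (plain Java; `checkSequenceShapes` is an illustrative helper, not DL4J API):

```java
public class SequenceShapeCheck {
    // Validate [minibatch, size, sequenceLength] arrays:
    // both must be rank 3 and share the same sequence length (dim 2)
    static void checkSequenceShapes(long[] inputShape, long[] labelShape) {
        if (inputShape.length != 3)
            throw new IllegalStateException("Input is not rank 3, got rank " + inputShape.length);
        if (labelShape.length != 3)
            throw new IllegalStateException("Labels are not rank 3, got rank " + labelShape.length);
        if (inputShape[2] != labelShape[2])
            throw new IllegalStateException("Sequence length mismatch: input=" + inputShape[2]
                    + " vs. labels=" + labelShape[2]);
    }

    public static void main(String[] args) {
        // Same sequence length (50), different feature sizes: OK
        checkSequenceShapes(new long[]{32, 128, 50}, new long[]{32, 10, 50});
        try {
            checkSequenceShapes(new long[]{32, 128, 50}, new long[]{32, 10, 49});
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // sequence-length mismatch reported
        }
    }
}
```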
@@ -23,6 +23,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

@@ -57,8 +58,13 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
}
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);

INDArray inputTemp = input;
this.input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);

Pair<Gradient, INDArray> gradAndEpsilonNext = super.backpropGradient(epsilon, workspaceMgr); //Also applies dropout
this.input = inputTemp;
INDArray epsilon2d = gradAndEpsilonNext.getSecond();
@@ -39,9 +39,7 @@ import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.*;

/**
 * Implementation of a SameDiff graph vertex.

@@ -96,12 +94,11 @@ public class SameDiffGraphVertex extends BaseGraphVertex {

@Override
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
if(sameDiff == null){
doInit();
}

try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
// sameDiff.clearExecutionCache();
if(sameDiff == null){
doInit();
}

config.validateInput(inputs);
for(int i=0; i<inputs.length; i++ ){
String name = config.getVertexParams().getInputs().get(i);

@@ -121,6 +118,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
}
Map<String,INDArray> out = sameDiff.exec(null, outputKey);
INDArray result = out.get(outputKey);

//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
}
}

@@ -131,27 +132,42 @@ public class SameDiffGraphVertex extends BaseGraphVertex {

INDArray[] dLdIns;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
// sameDiff.clearExecutionCache();
if(sameDiff == null){
doInit();
}

if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
List<String> inputs = config.getVertexParams().getInputs();
String[] inArr = inputs.toArray(new String[inputs.size()]);
sameDiff.createGradFunction(inArr);
}
config.validateInput(inputs);
//Set inputs
for(int i=0; i<inputs.length; i++ ){
String name = config.getVertexParams().getInputs().get(i);
final String maskName = name + "_mask";
sameDiff.associateArrayWithVariable(inputs[i].dup(), sameDiff.getVariable(name));
if(maskArrays != null && maskArrays[i] != null) {
sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName);
}else{
sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName);
Map<String,INDArray> phMap = new HashMap<>();
List<String> inputs = config.getVertexParams().getInputs();
int i=0;
for(String s : inputs){
phMap.put(s, this.inputs[i++]);
}
if(maskArrays != null){
for( int j=0; j<maskArrays.length; j++ ){
String name = inputs.get(j);
final String maskName = name + "_mask";
if(maskArrays[j] != null) {
sameDiff.associateArrayWithVariable(maskArrays[j].dup(), maskName);
}
}
}
fn.updateVariable(outputVar.getVarName(), epsilon.dup());
String epsName = fn.getGradPlaceholderName();
phMap.put(epsName, epsilon);

for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}

sameDiff.execBackwards(null);
sameDiff.execBackwards(phMap);
for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s);

@@ -159,10 +175,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
g.gradientForVariable().put(s, dl4jGrad);
}

dLdIns = new INDArray[inputs.length];
for(int i=0; i<inputs.length; i++ ){
String name = config.getVertexParams().getInputs().get(i);
dLdIns[i] = sameDiff.grad(name).getArr();
dLdIns = new INDArray[inputs.size()];
for(int j=0; j<inputs.size(); j++ ){
String name = inputs.get(j);
dLdIns[j] = sameDiff.grad(name).getArr();
}
}

@@ -171,6 +187,9 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
}

//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return new Pair<>(g, dLdIns);
}
@@ -35,6 +35,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.*;
@@ -78,25 +79,32 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false);
if(sameDiff == null){
doInit();
}

try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
if(sameDiff == null){
doInit();
}

org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input);
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));

Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
if(maskArray != null){
sameDiff.associateArrayWithVariable(maskArray, sameDiff.getVariable(MASK_KEY));
}else{
sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY));
phMap.put(MASK_KEY, maskArray);
}

for(String s : paramTable.keySet() ) {
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}

Map<String,INDArray> out = sameDiff.exec(null, outputKey);
Map<String,INDArray> out = sameDiff.exec(phMap, outputKey);
INDArray result = out.get(outputKey);

//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();

return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
}
}
@@ -110,24 +118,36 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {

INDArray dLdIn;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
// sameDiff.clearExecutionCache();
if(sameDiff == null){
doInit();
}
if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
sameDiff.createGradFunction(INPUT_KEY);
}

org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input);

sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
if(maskArray != null){
sameDiff.associateArrayWithVariable(maskArray, sameDiff.getVariable(MASK_KEY));
}else{
sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY));
}
fn.updateVariable(outputVar.getVarName(), epsilon.dup());

for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}

sameDiff.execBackwards(Collections.<String, INDArray>emptyMap());
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
phMap.put(fn.getGradPlaceholderName(), epsilon);
if(maskArray != null){
phMap.put(MASK_KEY, maskArray);
}

List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName());
for(String s : paramTable.keySet()){
requiredGrads.add(sameDiff.grad(s).getVarName());
}

sameDiff.execBackwards(phMap, requiredGrads);
for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s);
@@ -138,6 +158,11 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
}

//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();

System.out.println(dLdIn);
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
}
@@ -225,8 +250,9 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
sameDiff = SameDiff.create();
Map<String, INDArray> p = paramTable();

val inputShape = input.shape().clone();
SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape);
long[] inputShape = input.shape().clone();
inputShape[0] = -1;
SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape);
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
Map<String, SDVariable> params = new LinkedHashMap<>();
for (String s : paramShapes.keySet()) {
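The change above swaps a fixed-shape variable for a placeholder whose leading (minibatch) dimension is set to -1, i.e. "any size", so the graph can accept batches of a different size than the one seen at init time. A tiny sketch of that shape transformation (plain Java; `toPlaceholderShape` is an illustrative helper, not SameDiff API):

```java
import java.util.Arrays;

public class PlaceholderShape {
    // Copy a concrete activation shape and wildcard the minibatch dimension,
    // as doInit() does before calling sameDiff.placeHolder(...)
    static long[] toPlaceholderShape(long[] concreteShape) {
        long[] shape = concreteShape.clone();
        shape[0] = -1; // -1 conventionally means "dimension of any size"
        return shape;
    }

    public static void main(String[] args) {
        long[] batchOf32 = {32, 768};
        System.out.println(Arrays.toString(toPlaceholderShape(batchOf32))); // [-1, 768]
    }
}
```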
@@ -235,7 +261,8 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
params.put(s, v);
}

SDVariable mask = sameDiff.constant(MASK_KEY, SameDiffGraphVertex.createMask(dataType, inputShape));
long[] maskShape = ArrayUtil.nTimes((long)inputShape.length, -1);
SDVariable mask = sameDiff.placeHolder(MASK_KEY, dataType, maskShape);

SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, params, mask);
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
@@ -87,35 +87,43 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
private INDArray activateHelper(boolean activations, LayerWorkspaceMgr workspaceMgr){
assertInputSet(false);

//Check where the output occors. If it's a simple loss layer (no params) this could
//Check where the output occurs. If it's a simple loss layer (no params) this could
// just be the input!
if(activations && INPUT_KEY.equals(layerConf().activationsVertexName())){
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
}

if(sameDiff == null){
doInit();
}

//TODO optimize
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
if(layerConf().labelsRequired() && labels != null) {
sameDiff.associateArrayWithVariable(labels.dup(), sameDiff.getVariable(LABELS_KEY));
if(sameDiff == null){
doInit();
}

for(String s : paramTable.keySet() ) {
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}

INDArray score = sameDiff.execAndEndResult();
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
if(!activations && layerConf().labelsRequired() && labels != null) {
phMap.put(LABELS_KEY, labels);
}

String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();

INDArray out = sameDiff.execSingle(phMap, s);

//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();

if(activations) {
INDArray result = sameDiff.getArrForVarName(layerConf().activationsVertexName());
Preconditions.checkNotNull(result, "Activations (result) array for variable \"%s\" was " +
Preconditions.checkNotNull(out, "Activations (result) array for variable \"%s\" was " +
"null - error during execution or this variable (as defined by method activationsVertexName()) " +
"does not exist", layerConf().activationsVertexName());
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
return workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
} else {
return score;
return out;
}
}
}
@@ -127,23 +135,26 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
Preconditions.checkState(!layerConf().labelsRequired() || labels != null, "Cannot execute backprop: Labels are not set. " +
"If labels are not required for this SameDiff output layer, override SameDiffOutputLayer.labelsRequired()" +
" to return false instead");

if(sameDiff == null){
//Usually doInit will be called in forward pass; not necessarily the case in output layers
// (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
doInit();
}

Gradient g = new DefaultGradient();

INDArray dLdIn;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
INDArray castInput = input.castTo(Nd4j.defaultFloatingPointType());
if(sameDiff == null){
//Usually doInit will be called in forward pass; not necessarily the case in output layers
// (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
doInit();
}
if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
sameDiff.createGradFunction(INPUT_KEY);
}

INDArray castInput = input.castTo(dataType);
if(castInput.isAttached())
castInput = castInput.dup();
sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY));
if(layerConf().labelsRequired()) {
INDArray castLabels = labels.castTo(Nd4j.defaultFloatingPointType());
INDArray castLabels = labels.castTo(dataType);
if(castLabels.isAttached())
castLabels = castLabels.dup();
sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY));
@@ -154,7 +165,17 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}

sameDiff.execBackwards(Collections.<String, INDArray>emptyMap());
List<String> gradVarNames = new ArrayList<>();
for(String s : paramTable.keySet()){
gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName());
}
gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName());

Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
phMap.put(LABELS_KEY, labels);

sameDiff.execBackwards(phMap, gradVarNames);
for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s);

@@ -165,6 +186,10 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
}

//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();

return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
}
@@ -252,18 +277,20 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff = SameDiff.create();
Map<String, INDArray> p = paramTable();

val inputShape = input.shape().clone();
SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape);
long[] inputShape = input.shape().clone();
inputShape[0] = -1;
SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape);
SDVariable labelVar = null;
if(layerConf().labelsRequired()){
long[] labelShape = labels == null ? new long[]{1} : labels.shape().clone();
labelVar = sameDiff.var(LABELS_KEY, dataType, labelShape);
long[] labelShape = labels == null ? new long[]{-1, -1} : labels.shape().clone();
labelShape[0] = -1;
labelVar = sameDiff.placeHolder(LABELS_KEY, dataType, labelShape);
}
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
Map<String, SDVariable> params = new LinkedHashMap<>();
for (String s : paramShapes.keySet()) {
val ps = paramShapes.get(s);
SDVariable v = sameDiff.var(s, ps);
SDVariable v = sameDiff.var(s, dataType, ps);
params.put(s, v);
}
SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, labelVar, params);
@@ -660,6 +660,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
val nParamsPerLayer = new long[nLayers];
for (int i = 0; i < nLayers; i++) {
NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i);
conf.getLayer().setDataType(netDtype);
nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf);
paramLength += nParamsPerLayer[i];
}
@@ -152,7 +152,7 @@ public class HardwareMetric implements Serializable {
return builder.logicalProcessorCount(processor.getLogicalProcessorCount())
.physicalProcessorCount(processor.getPhysicalProcessorCount())
.name(name)
.averagedCpuLoad((long) processor.getSystemCpuLoad() * 100)
.averagedCpuLoad((long)(processor.getSystemCpuLoad() * 100))
.ioWaitTime(iowait).gpuMetrics(gpuMetric)
.hostName(networkParams.getHostName()).diskInfo(diskInfoMap)
.currentMemoryUse(globalMemory.getTotal() - globalMemory.getAvailable())
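The one-character fix above addresses a cast-precedence bug: in `(long) processor.getSystemCpuLoad() * 100`, the cast binds tighter than the multiplication, so the fractional load (a value between 0 and 1) is truncated to 0 before being scaled, and the reported percentage was always 0. Parenthesizing multiplies first. A minimal standalone demonstration (plain Java; the load value is illustrative):

```java
public class CastPrecedence {
    public static void main(String[] args) {
        double systemCpuLoad = 0.5; // fraction of CPU in use

        // Buggy: cast applies to systemCpuLoad alone, truncating 0.5 to 0L first
        long buggy = (long) systemCpuLoad * 100;

        // Fixed: multiply first, then truncate the product
        long fixed = (long) (systemCpuLoad * 100);

        System.out.println(buggy); // 0
        System.out.println(fixed); // 50
    }
}
```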
|
@ -48,8 +48,6 @@ if(WIN32)
|
|||
SET(CMAKE_NINJA_FORCE_RESPONSE_FILE 1 CACHE INTERNAL "")
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
if ("${LIBND4J_ALL_OPS}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true")
|
||||
else()
|
||||
|
@@ -234,21 +232,21 @@ if(CUDA_BLAS)
         endif()
     endif()
 
-    if (NOT BUILD_TESTS)
-        file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
-        file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/*.cpp ../include/execution/*.h)
-        file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
-        file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
-        file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
-        file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
-        file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
-        file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp)
-        file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
-        file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
-        file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/cuda/*.cu ../include/helpers/*.h)
-        file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
-        file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
+    file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
+    file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
+    file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
+    file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
+    file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/impl/*.cpp ../include/memory/cuda/*.cu ../include/memory/*.h)
+    file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.cu ../include/graph/*.h)
+    file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
+    file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cuda/*.cu ../include/ops/declarable/helpers/impl/*.cpp)
+    file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
+    file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
+    file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
+    file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
+    file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
+
+    if (NOT BUILD_TESTS)
         CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
             ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
             ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
@@ -258,26 +256,12 @@ if(CUDA_BLAS)
     else()
         set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
 
-        file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
-        file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
-        file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
-        file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
-        file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/impl/*.cpp ../include/memory/cuda/*.cu ../include/memory/*.h)
-        file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.cu ../include/graph/*.h)
-        file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
-        file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cuda/*.cu)
-        file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
-        file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
-        file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
-        file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
-        file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
-
         CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
             ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
             ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
             cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
             Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
-            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES})
+            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES})
     endif()
@@ -308,7 +292,7 @@ elseif(CPU_BLAS)
     file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
     file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
     file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
-    file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp)
+    file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp ../include/ops/declarable/helpers/impl/*.cpp)
    file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
    file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
    file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h)
@@ -372,8 +372,8 @@ namespace nd4j {
         /**
          * if _bufferD==nullptr return _buffer, else return _bufferD
          */
-        FORCEINLINE void* specialBuffer();
-        FORCEINLINE void* getSpecialBuffer() const;
+        void* specialBuffer();
+        void* getSpecialBuffer() const;
 
         /**
          * returns device buffer if compilation is for cuda case, otherwise returns host buffer
@@ -429,16 +429,16 @@ namespace nd4j {
         /**
          * permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array
          */
-        NDArray* permute(const std::initializer_list<int>& dimensions) const;
-        NDArray* permute(const std::vector<int>& dimensions) const;
-        NDArray* permute(const int* dimensions, const int rank) const;
+        NDArray permute(const std::initializer_list<int>& dimensions) const;
+        NDArray permute(const std::vector<int>& dimensions) const;
+        NDArray permute(const int* dimensions, const int rank) const;
 
         void permute(const int* dimensions, const int rank, NDArray& target) const;
         void permute(const std::vector<int>& dimensions, NDArray& target) const;
 
-        NDArray* permute(const std::initializer_list<Nd4jLong>& dimensions) const;
-        NDArray* permute(const std::vector<Nd4jLong>& dimensions) const;
-        NDArray* permute(const Nd4jLong* dimensions, const int rank) const;
+        NDArray permute(const std::initializer_list<Nd4jLong>& dimensions) const;
+        NDArray permute(const std::vector<Nd4jLong>& dimensions) const;
+        NDArray permute(const Nd4jLong* dimensions, const int rank) const;
 
         void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const;
         void permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const;
@@ -508,7 +508,7 @@ namespace nd4j {
         /**
         * returns new copy of this array, optionally in different order
         */
-        NDArray *dup(const char newOrder = 'a');
+        NDArray *dup(const char newOrder = 'a') const;
 
         /**
         * returns sum of all elements of array
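The `const` qualifier added to `dup()` above matters because `dup()` does not modify the source array, and without the qualifier it cannot be called through a `const` reference or pointer. A minimal sketch with a hypothetical class:

```cpp
// Hypothetical class illustrating why a copying method should be const:
// only const member functions are callable on const objects.
struct Arr {
    int x;
    // Declared const: duplicating does not mutate *this.
    Arr* dup() const { return new Arr{x}; }
};
```

Without the `const`, code like `void f(const Arr& a) { auto* copy = a.dup(); }` fails to compile even though `dup()` never writes to the object.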
@@ -687,7 +687,7 @@ namespace nd4j {
         void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;
 
 
-#if defined(__CUDABLAS__) && defined(BUILD_TESTS)
+#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS)
         template <typename Lambda>
         FORCEINLINE void applyLambda(Lambda func, NDArray* target = nullptr);
 
@@ -790,8 +790,7 @@ namespace nd4j {
         /**
         * apply transpose operation to the copy of this array, that is this array remains unaffected
         */
-        NDArray* transpose() const;
-        NDArray  transp() const;
+        NDArray transpose() const;
 
         /**
         * perform transpose operation and store result in target, this array remains unaffected
@@ -915,7 +914,7 @@ namespace nd4j {
         *
         * if permute have been applied before or there are weird strides, then new buffer is allocated for new array
         */
-        NDArray* reshape(const char order, const std::vector<Nd4jLong>& shape) const;
+        NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const;
 
         /**
         * calculate strides and set given order
@@ -2093,15 +2092,6 @@ Nd4jLong* NDArray::shapeInfo() {
     return _shapeInfo;
 }
 
-////////////////////////////////////////////////////////////////////////
-void* NDArray::specialBuffer() {
-
-    if (_buffer->special() == nullptr)
-        return getBuffer();
-    // FIXME: this should be fixed once CUDA backend added
-    return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
-}
-
 ////////////////////////////////////////////////////////////////////////
 Nd4jLong* NDArray::specialShapeInfo() {
     if (_shapeInfoD == nullptr)
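The deleted `specialBuffer()` body above uses a pattern worth noting: a view's device pointer is the base buffer advanced by `offset * sizeOfT()` BYTES, hence the cast to `int8_t*` before the addition (pointer arithmetic on `void*` is not legal, and arithmetic on a typed pointer would scale by the wrong stride). A minimal sketch of that arithmetic, with hypothetical names:

```cpp
#include <cstdint>

// Advance a type-erased base pointer by elementOffset elements of size
// sizeOfT bytes each. Casting to int8_t* makes '+' advance one byte per
// unit, so the byte offset is computed explicitly.
void* viewPointer(void* base, long elementOffset, long sizeOfT) {
    return static_cast<int8_t*>(base) + (elementOffset * sizeOfT);
}
```

For a `float` buffer, `viewPointer(buf, 2, sizeof(float))` lands on the same address as `buf + 2` with typed arithmetic.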
@@ -2110,14 +2100,6 @@ Nd4jLong* NDArray::specialShapeInfo() {
     return _shapeInfoD;
 }
 
-////////////////////////////////////////////////////////////////////////
-void* NDArray::getSpecialBuffer() const {
-    if (_buffer->special() == nullptr)
-        return getBuffer();
-    // FIXME: this should be fixed once CUDA backend added
-    return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
-}
-
 ////////////////////////////////////////////////////////////////////////
 Nd4jLong NDArray::getBufferOffset() const {
     return _offset;
@@ -2137,7 +2119,7 @@ Nd4jLong* NDArray::getSpecialShapeInfo() const{
 }
 
 
-#if defined(__CUDACC__) && defined(BUILD_TESTS)
+#if defined(__CUDACC__) //&& defined(BUILD_TESTS)
 // for CUDA we need stil stuff inline
 #include "cuda/NDArrayLambda.hpp"
 #endif
@@ -39,9 +39,9 @@ NDArray* NDArray::asT() const{
     auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
     auto l = this->lengthOf();
 
-    prepareSpecialUse({result}, {this});
+    NDArray::prepareSpecialUse({result}, {this});
     NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result->getBuffer(), result->getShapeInfo(), result->getSpecialBuffer(), result->getSpecialShapeInfo(), nullptr, nullptr, nullptr);
-    registerSpecialUse({result}, {this});
+    NDArray::registerSpecialUse({result}, {this});
 
     return result;
 }
@@ -583,117 +583,130 @@ void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCop
 void NDArray::assign(const double value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const float value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const float16 value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const bfloat16& value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const Nd4jLong value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const int value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const int16_t value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const uint8_t value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const uint16_t value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const uint32_t value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const uint64_t value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
    NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const int8_t value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
 void NDArray::assign(const bool value) {
     // just fire scalar
     auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
-    prepareSpecialUse({this}, {&temp});
+
+    NDArray::prepareSpecialUse({this}, {&temp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
-    registerSpecialUse({this}, {&temp});
+    NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -716,9 +729,9 @@ NDArray NDArray::varianceNumber(nd4j::variance::Ops op, bool biasCorrected) {
 
     NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext());
 
-    prepareSpecialUse({&res}, {this});
+    NDArray::prepareSpecialUse({&res}, {this});
     NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected);
-    registerSpecialUse({&res}, {this});
+    NDArray::registerSpecialUse({&res}, {this});
 
     return res;
 }
@@ -918,9 +931,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::FloatOps op, void *extraParams) cons
     auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()));
     NDArray result(shape, true, this->getContext());
 
-    prepareSpecialUse({&result}, {this});
+    NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
-    registerSpecialUse({&result}, {this});
+    NDArray::registerSpecialUse({&result}, {this});
 
     return result;
 }
@@ -932,9 +945,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::SameOps op, void *extraParams) const
 
     NDArray result(dataType(), getContext());
 
-    prepareSpecialUse({&result}, {this});
+    NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
-    registerSpecialUse({&result}, {this});
+    NDArray::registerSpecialUse({&result}, {this});
 
     return result;
 }
@@ -947,9 +960,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::BoolOps op, void *extraParams) const
     auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL);
     NDArray result(shape, true, this->getContext());
 
-    prepareSpecialUse({&result}, {this});
+    NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
-    registerSpecialUse({&result}, {this});
+    NDArray::registerSpecialUse({&result}, {this});
 
     return result;
 }
@@ -962,9 +975,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::LongOps op, void *extraParams) const
     auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64);
     NDArray result(shape, true, this->getContext());
 
-    prepareSpecialUse({&result}, {this});
+    NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
-    registerSpecialUse({&result}, {this});
+    NDArray::registerSpecialUse({&result}, {this});
 
     return result;
 }
@@ -976,9 +989,9 @@ void NDArray::reduceNumber(nd4j::reduce::FloatOps op, NDArray& target, void *ext
     if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType()))
         throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!");
 
-    prepareSpecialUse({&target}, {this});
+    NDArray::prepareSpecialUse({&target}, {this});
     NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
-    registerSpecialUse({&target}, {this});
+    NDArray::registerSpecialUse({&target}, {this});
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -989,9 +1002,9 @@ void NDArray::reduceNumber(nd4j::reduce::SameOps op, NDArray& target, void *extr
     if(!target.isScalar() || target.dataType() != dataType())
         throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!");
 
-    prepareSpecialUse({&target}, {this});
+    NDArray::prepareSpecialUse({&target}, {this});
     NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
-    registerSpecialUse({&target}, {this});
+    NDArray::registerSpecialUse({&target}, {this});
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -1002,9 +1015,9 @@ void NDArray::reduceNumber(nd4j::reduce::BoolOps op, NDArray& target, void *extr
     if(!target.isScalar() || target.dataType() != DataType::BOOL)
         throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!");
 
-    prepareSpecialUse({&target}, {this});
+    NDArray::prepareSpecialUse({&target}, {this});
     NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
-    registerSpecialUse({&target}, {this});
+    NDArray::registerSpecialUse({&target}, {this});
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -1015,9 +1028,9 @@ void NDArray::reduceNumber(nd4j::reduce::LongOps op, NDArray& target, void *extr
     if(!target.isScalar() || target.dataType() != DataType::INT64)
         throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!");
 
-    prepareSpecialUse({&target}, {this});
+    NDArray::prepareSpecialUse({&target}, {this});
     NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
-    registerSpecialUse({&target}, {this});
+    NDArray::registerSpecialUse({&target}, {this});
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -1027,9 +1040,9 @@ NDArray NDArray::indexReduceNumber(nd4j::indexreduce::Ops op, ExtraArguments *ex
 
     auto res = NDArrayFactory::create<Nd4jLong>(0);
 
-    NDArray::prepareSpecialUse({&res}, {this});
+    NDArray::NDArray::prepareSpecialUse({&res}, {this});
     NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams == nullptr ? nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo());
-    NDArray::registerSpecialUse({&res}, {this});
+    NDArray::NDArray::registerSpecialUse({&res}, {this});
 
     return res;
 }
@@ -1240,17 +1253,10 @@ BUILD_SINGLE_TEMPLATE(template void* NDArray::templatedPointerShift, (const Nd4j
 
 //////////////////////////////////////////////////////////////////////////
 // method makes copy of this array and applies to the copy transpose operation, this array remains unaffected
-NDArray* NDArray::transpose() const {
-    auto newArr = new NDArray(getBuffer(), getSpecialBuffer(), getShapeInfo(), getContext(), false, false);
-    newArr->transposei();
-
-    return newArr;
-}
-
-////////////////////////////////////////////////////////////////////////
-NDArray NDArray::transp() const {
-    NDArray newArr(getBuffer(), getShapeInfo(), getContext(), false);
+NDArray NDArray::transpose() const {
+    NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
     newArr.transposei();
 
     return newArr;
 }
@@ -1360,10 +1366,10 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
 
 //////////////////////////////////////////////////////////////////////////
 // create new array with corresponding order and shape, new array will point to the same _buffer as this array
-NDArray* NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const {
+NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const {
 
-    auto newArr = new NDArray(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext());
-    newArr->reshapei(order, shape);
+    NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
+    newArr.reshapei(order, shape);
 
     return newArr;
 }
@@ -1420,43 +1426,43 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
 }

 //////////////////////////////////////////////////////////////////////////
-NDArray* NDArray::permute(const int* dimensions, const int rank) const {
+NDArray NDArray::permute(const int* dimensions, const int rank) const {

     // evaluate shapeInfo for output (permuted) array ret
     auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace());
-    auto ret = new NDArray(_buffer, ShapeDescriptor(shapeInfoPermuted), getContext(), getBufferOffset());
-    ret->_isView = true;
+    NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), getBufferOffset());
+    ret._isView = true;
     return ret;
 }

 /////////////////////////////////////////////////////////////////////////
-NDArray* NDArray::permute(const Nd4jLong* dimensions, const int rank) const {
+NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const {
     int tempDims[MAX_RANK];
     shape::convertT<Nd4jLong, int>(const_cast<Nd4jLong *>(dimensions), tempDims, rank);
     return permute(tempDims, rank);
 }

 //////////////////////////////////////////////////////////////////////////
-NDArray* NDArray::permute(const std::vector<int>& dimensions) const {
+NDArray NDArray::permute(const std::vector<int>& dimensions) const {
     auto data = dimensions.data();
     auto size = dimensions.size();
     return permute(data, size);
 }

 //////////////////////////////////////////////////////////////////////////
-NDArray* NDArray::permute(const std::vector<Nd4jLong>& dimensions) const {
+NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const {
     return permute(dimensions.data(), dimensions.size());
 }


 //////////////////////////////////////////////////////////////////////////
-NDArray* NDArray::permute(const std::initializer_list<int>& dimensions) const {
+NDArray NDArray::permute(const std::initializer_list<int>& dimensions) const {
     std::vector<int> vec(dimensions);
     return permute(vec);
 }

 //////////////////////////////////////////////////////////////////////////
-NDArray* NDArray::permute(const std::initializer_list<Nd4jLong>& dimensions) const {
+NDArray NDArray::permute(const std::initializer_list<Nd4jLong>& dimensions) const {
     std::vector<Nd4jLong> vec(dimensions);
     return permute(vec);
 }
@@ -1528,10 +1534,9 @@ bool NDArray::isUnitary() {
         throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !");

     auto tr = this->transpose();
-    auto trMul = MmulHelper::mmul(this, tr, nullptr, 1.f, 0.f);
+    auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f);

     bool result = trMul->isIdentityMatrix();
-    delete tr;
     delete trMul;

     return result;
@@ -1777,11 +1782,11 @@ NDArray NDArray::operator*(const T& scalar) const {

     auto tmp = NDArrayFactory::create(dataType(), scalar, getContext());
     NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT<T>()), false, getContext());

     NDArray::prepareSpecialUse({&result}, {this, &tmp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
     NDArray::registerSpecialUse({&result}, {this, &tmp});

     return result;
 }
 template NDArray NDArray::operator*(const double& scalar) const;

@@ -1811,6 +1816,7 @@ NDArray NDArray::operator/(const T& scalar) const {
     NDArray::prepareSpecialUse({&result}, {this, &tmp});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
     NDArray::registerSpecialUse({&result}, {this, &tmp});

     return result;
 }
 template NDArray NDArray::operator/(const double& scalar) const;
@@ -2050,14 +2056,14 @@ void NDArray::operator+=(const NDArray& other) {
         throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());

     if (!this->isScalar() && other.isScalar()) {
-        prepareSpecialUse({this}, {this, &other});
+        NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({this}, {this, &other});
+        NDArray::registerSpecialUse({this}, {this, &other});
     }
     else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
-        prepareSpecialUse({this}, {this, &other});
+        NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({this}, {this, &other});
+        NDArray::registerSpecialUse({this}, {this, &other});
     }
     else{
         Nd4jLong *bShape = nullptr;

@@ -2084,14 +2090,14 @@ void NDArray::operator-=(const NDArray& other) {
         throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());

     if (!this->isScalar() && other.isScalar()) {
-        prepareSpecialUse({this}, {this, &other});
+        NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({this}, {this, &other});
+        NDArray::registerSpecialUse({this}, {this, &other});
     }
     else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
-        prepareSpecialUse({this}, {this, &other});
+        NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({this}, {this, &other});
+        NDArray::registerSpecialUse({this}, {this, &other});
     }
     else{
         Nd4jLong *bShape = nullptr;

@@ -2117,14 +2123,14 @@ void NDArray::operator*=(const NDArray& other) {
         throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());

     if (!this->isScalar() && other.isScalar()) {
-        prepareSpecialUse({this}, {this, &other});
+        NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({this}, {this, &other});
+        NDArray::registerSpecialUse({this}, {this, &other});
     }
     else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
-        prepareSpecialUse({this}, {this, &other});
+        NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({this}, {this, &other});
+        NDArray::registerSpecialUse({this}, {this, &other});
     }
     else{
         Nd4jLong *bShape = nullptr;

@@ -2154,14 +2160,14 @@ void NDArray::operator/=(const NDArray& other) {
     }

     if (!this->isScalar() && other.isScalar()) {
-        prepareSpecialUse({this}, {this, &other});
+        NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({this}, {this, &other});
+        NDArray::registerSpecialUse({this}, {this, &other});
     }
     else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
-        prepareSpecialUse({this}, {this, &other});
+        NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({this}, {this, &other});
+        NDArray::registerSpecialUse({this}, {this, &other});
     }
     else{
         Nd4jLong *bShape = nullptr;
@@ -2264,9 +2270,9 @@ NDArray NDArray::operator-(const NDArray& other) const {
     if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
         NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());

-        prepareSpecialUse({&result}, {this, &other});
+        NDArray::prepareSpecialUse({&result}, {this, &other});
         NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({&result}, {this, &other});
+        NDArray::registerSpecialUse({&result}, {this, &other});

         return result;
     }

@@ -2285,9 +2291,9 @@ NDArray NDArray::operator*(const NDArray& other) const {
     if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
         NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, this->getContext());

-        prepareSpecialUse({&result}, {this, &other});
+        NDArray::prepareSpecialUse({&result}, {this, &other});
         NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({&result}, {this, &other});
+        NDArray::registerSpecialUse({&result}, {this, &other});

         return result;
     }

@@ -2308,9 +2314,9 @@ NDArray NDArray::operator/(const NDArray& other) const {
     if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
         NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());

-        prepareSpecialUse({&result}, {this, &other});
+        NDArray::prepareSpecialUse({&result}, {this, &other});
         NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
-        registerSpecialUse({&result}, {this, &other});
+        NDArray::registerSpecialUse({&result}, {this, &other});

         return result;
     }

@@ -2326,9 +2332,9 @@ NDArray NDArray::operator-() const {

     NDArray result(getShapeInfo(), false, getContext());

-    prepareSpecialUse({&result}, {this});
+    NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execTransformSame(getContext(), nd4j::transform::Neg, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
-    registerSpecialUse({&result}, {this});
+    NDArray::registerSpecialUse({&result}, {this});

     return result;
 }
@@ -2631,7 +2637,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector<int>& di
     if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
         NDArray::prepareSpecialUse({result}, {this, other});
         NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
-        registerSpecialUse({result}, {this, other});
+        NDArray::registerSpecialUse({result}, {this, other});
         return;
     }


@@ -2688,7 +2694,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
     if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
         NDArray::prepareSpecialUse({result}, {this, other});
         NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
-        registerSpecialUse({result}, {this, other});
+        NDArray::registerSpecialUse({result}, {this, other});
         return;
     }
@@ -2896,7 +2902,7 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
     Nd4jLong *shapeInfoNew;
     ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);

-    bool canReshape = shape::reshapeC(this->rankOf(), this->_shapeInfo, shape.size(), shape.data(), shapeInfoNew);
+    bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);

     // we can do this only if there was no permute applied, or there are no weird strides
     if (canReshape) {

@@ -2948,11 +2954,9 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray* othe
     if (target->dataType() != this->dataType() && target->dataType() != other->dataType())
         throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !");

-    prepareSpecialUse({target}, {this, other});
-
+    NDArray::prepareSpecialUse({target}, {this, other});
     NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
-
-    registerSpecialUse({target}, {this, other});
+    NDArray::registerSpecialUse({target}, {this, other});

     if (extraParams != nullptr)
         synchronize("NDArray::applyPairwiseTransform");

@@ -2969,9 +2973,9 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *
     if (dataType() != other->dataType())
         throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !");

-    prepareSpecialUse({target}, {this, other});
+    NDArray::prepareSpecialUse({target}, {this, other});
     NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
-    registerSpecialUse({target}, {this, other});
+    NDArray::registerSpecialUse({target}, {this, other});
 }

 //////////////////////////////////////////////////////////////////////////
@@ -3070,22 +3074,23 @@ void NDArray::assign(const NDArray& other) {
     if (other.isScalar()) {

         if(this->isScalar()) {
-            preparePrimaryUse({this}, {&other});
+            NDArray::preparePrimaryUse({this}, {&other});
             BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES);
-            registerPrimaryUse({this}, {&other});
+            NDArray::registerPrimaryUse({this}, {&other});
+            this->syncToDevice();
         }
         else {
             if (dataType() != other.dataType()) {
                 auto tmp = other.cast(dataType());
-                prepareSpecialUse({this}, {tmp});
+                NDArray::prepareSpecialUse({this}, {tmp});
                 NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr);
-                registerSpecialUse({this}, {});
+                NDArray::registerSpecialUse({this}, {});
                 delete tmp;
             }
             else {
-                prepareSpecialUse({this}, {&other});
+                NDArray::prepareSpecialUse({this}, {&other});
                 NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
-                registerSpecialUse({this}, {&other});
+                NDArray::registerSpecialUse({this}, {&other});
             }
         }
     }
@@ -3101,16 +3106,16 @@ void NDArray::assign(const NDArray& other) {
         if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
             copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
         else {
-            prepareSpecialUse({this}, {&other});
+            NDArray::prepareSpecialUse({this}, {&other});
             NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr);
-            registerSpecialUse({this}, {&other});
+            NDArray::registerSpecialUse({this}, {&other});
         }
     }
 }

 ////////////////////////////////////////////////////////////////////////
 // This method returns new copy of this NDArray, optionally in different order
-NDArray* NDArray::dup(const char newOrder) {
+NDArray* NDArray::dup(const char newOrder) const {

     if (isEmpty())
         return NDArrayFactory::empty_(dataType(), getContext());
@@ -3170,7 +3175,7 @@ std::string NDArray::e(const Nd4jLong i) const {
     if (!isS())
         throw std::runtime_error("Can't get std::string out of non-string array");

-    preparePrimaryUse({}, {this});
+    NDArray::preparePrimaryUse({}, {this});

     // getting "virtual" offset. it's not real though,since it doesn't take lengths into account
     auto offset = getOffset(i);

@@ -3208,8 +3213,8 @@ T NDArray::e(const Nd4jLong i) const {

     const auto rp = getOffset(i);

-    preparePrimaryUse({}, {this});
-    registerPrimaryUse({}, {this});
+    NDArray::preparePrimaryUse({}, {this});
+    NDArray::registerPrimaryUse({}, {this});
     BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), rp), LIBND4J_TYPES);
 }

@@ -3226,8 +3231,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j) const {
     const Nd4jLong coords[2] = {i, j};
     const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());

-    preparePrimaryUse({}, {this});
-    registerPrimaryUse({}, {this});
+    NDArray::preparePrimaryUse({}, {this});
+    NDArray::registerPrimaryUse({}, {this});

     BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);

@@ -3246,8 +3251,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
     const Nd4jLong coords[3] = {i, j, k};
     const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());

-    preparePrimaryUse({}, {this});
-    registerPrimaryUse({}, {this});
+    NDArray::preparePrimaryUse({}, {this});
+    NDArray::registerPrimaryUse({}, {this});

     BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);

@@ -3266,8 +3271,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLon
     const Nd4jLong coords[4] = {i, j, k, l};
     const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());

-    preparePrimaryUse({}, {this});
-    registerPrimaryUse({}, {this});
+    NDArray::preparePrimaryUse({}, {this});
+    NDArray::registerPrimaryUse({}, {this});

     BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
@@ -3300,9 +3305,9 @@ void NDArray::applyTransform(nd4j::transform::FloatOps op, NDArray *target, Extr
     if (!target->isR())
         throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types");

-    prepareSpecialUse({target}, {this});
+    NDArray::prepareSpecialUse({target}, {this});
     NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
-    registerSpecialUse({target}, {this});
+    NDArray::registerSpecialUse({target}, {this});
 }

 ////////////////////////////////////////////////////////////////////////

@@ -3314,9 +3319,9 @@ void NDArray::applyTransform(nd4j::transform::AnyOps op, NDArray *target, ExtraA
     if (target == nullptr)
         target = this;

-    prepareSpecialUse({target}, {this});
+    NDArray::prepareSpecialUse({target}, {this});
     NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
-    registerSpecialUse({target}, {this});
+    NDArray::registerSpecialUse({target}, {this});
 }

 ////////////////////////////////////////////////////////////////////////

@@ -3331,9 +3336,9 @@ void NDArray::applyTransform(nd4j::transform::SameOps op, NDArray *target, Extra
     if (target->dataType() != dataType())
         throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array");

-    prepareSpecialUse({target}, {this});
+    NDArray::prepareSpecialUse({target}, {this});
     NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
-    registerSpecialUse({target}, {this});
+    NDArray::registerSpecialUse({target}, {this});
 }

 ////////////////////////////////////////////////////////////////////////

@@ -3347,9 +3352,9 @@ void NDArray::applyTransform(nd4j::transform::StrictOps op, NDArray *target, Ext
     if (!this->isR() || !target->isR() || (this->dataType() != target->dataType()))
         throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !");

-    registerSpecialUse({target}, {this});
+    NDArray::prepareSpecialUse({target}, {this});
     NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
-    prepareSpecialUse({target}, {this});
+    NDArray::registerSpecialUse({target}, {this});
 }

 ////////////////////////////////////////////////////////////////////////

@@ -3363,9 +3368,9 @@ void NDArray::applyTransform(nd4j::transform::BoolOps op, NDArray *target, Extra
     if (!target->isB())
         throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types");

-    prepareSpecialUse({target}, {this});
+    NDArray::prepareSpecialUse({target}, {this});
     NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
-    registerSpecialUse({target}, {this});
+    NDArray::registerSpecialUse({target}, {this});
 }

 ////////////////////////////////////////////////////////////////////////
@@ -3375,9 +3380,9 @@ NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) cons

     NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext());

-    registerSpecialUse({&result}, {this});
+    NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execTransformFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
-    prepareSpecialUse({&result}, {this});
+    NDArray::registerSpecialUse({&result}, {this});

     return result;
 }

@@ -3389,9 +3394,9 @@ NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const

     NDArray result(getShapeInfo(), false, getContext());

-    prepareSpecialUse({&result}, {this});
+    NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execTransformSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
-    registerSpecialUse({&result}, {this});
+    NDArray::registerSpecialUse({&result}, {this});

     return result;
 }

@@ -3403,9 +3408,9 @@ NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) con

     NDArray result(getShapeInfo(), false, getContext());

-    prepareSpecialUse({&result}, {this});
+    NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execTransformStrict(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
-    registerSpecialUse({&result}, {this});
+    NDArray::registerSpecialUse({&result}, {this});

     return result;
 }

@@ -3417,9 +3422,9 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const

     NDArray result(ordering(), getShapeAsVector(), nd4j::DataType::BOOL, getContext());

-    prepareSpecialUse({&result}, {this});
+    NDArray::prepareSpecialUse({&result}, {this});
     NativeOpExecutioner::execTransformBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
-    registerSpecialUse({&result}, {this});
+    NDArray::registerSpecialUse({&result}, {this});

     return result;
 }
@@ -3435,9 +3440,9 @@ void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArra
     if(target->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar->getShapeInfo()) && !(target->dataType() == dataType() || target->dataType() == scalar->dataType()))
         throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!");

-    prepareSpecialUse({target}, {this, scalar});
+    NDArray::prepareSpecialUse({target}, {this, scalar});
     NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
-    registerSpecialUse({target}, {this, scalar});
+    NDArray::registerSpecialUse({target}, {this, scalar});
 }

 ////////////////////////////////////////////////////////////////////////

@@ -3471,10 +3476,9 @@ void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, ND
         throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!");
     }

-    prepareSpecialUse({target}, {this, scalar});
+    NDArray::prepareSpecialUse({target}, {this, scalar});
     NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
-
-    registerSpecialUse({target}, {this, scalar});
+    NDArray::registerSpecialUse({target}, {this, scalar});
 }

 ////////////////////////////////////////////////////////////////////////
@ -3557,7 +3561,7 @@ NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, cons
|
|||
|
||||
NDArray::prepareSpecialUse({result}, {this, other});
|
||||
NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo());
|
||||
registerSpecialUse({result}, {this, other});
|
||||
NDArray::registerSpecialUse({result}, {this, other});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -3635,9 +3639,9 @@ NDArray* NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray *other, c
|
|||
|
||||
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||
|
||||
prepareSpecialUse({result}, {this, other});
|
||||
NDArray::prepareSpecialUse({result}, {this, other});
|
||||
NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||
registerSpecialUse({result}, {this, other});
|
||||
NDArray::registerSpecialUse({result}, {this, other});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
@ -3780,9 +3784,9 @@ void NDArray::p(const Nd4jLong i, const T value) {
|
|||
auto rp = getOffset(i);
|
||||
const void *pV = reinterpret_cast<const void*>(const_cast<T *>(&value));
|
||||
|
||||
preparePrimaryUse({this}, {}, true);
|
||||
NDArray::preparePrimaryUse({this}, {}, true);
|
||||
BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->getBuffer(), rp, pV), LIBND4J_TYPES);
|
||||
registerPrimaryUse({this}, {});
|
||||
NDArray::registerPrimaryUse({this}, {});
|
||||
}
|
||||
|
||||
template void NDArray::p(const Nd4jLong i, const double value);
|
||||
|
@ -3811,9 +3815,9 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) {
|
|||
Nd4jLong coords[2] = {i, j};
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||
|
||||
preparePrimaryUse({this}, {}, true);
|
||||
NDArray::preparePrimaryUse({this}, {}, true);
|
||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
|
||||
registerPrimaryUse({this}, {});
|
||||
NDArray::registerPrimaryUse({this}, {});
|
||||
}
|
||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value);
|
||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value);
|
||||
|
@ -3837,13 +3841,13 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T va
|
|||
if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2])
|
||||
throw std::invalid_argument("NDArray:pe(i,j,k, value): one of input indexes is out of array length or rank!=3 !");
|
||||
|
||||
preparePrimaryUse({this}, {}, true);
|
||||
NDArray::preparePrimaryUse({this}, {}, true);
|
||||
|
||||
void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
|
||||
Nd4jLong coords[3] = {i, j, k};
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
|
||||
registerPrimaryUse({this}, {});
|
||||
NDArray::registerPrimaryUse({this}, {});
|
||||
}
|
||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value);
|
||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value);
|
||||
|
@ -3870,9 +3874,9 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4j
|
|||
Nd4jLong coords[4] = {i, j, k, l};
|
||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||
|
||||
preparePrimaryUse({this}, {}, true);
|
||||
NDArray::preparePrimaryUse({this}, {}, true);
|
||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
|
||||
registerPrimaryUse({this}, {});
|
||||
NDArray::registerPrimaryUse({this}, {});
|
||||
}
|
||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value);
|
||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value);
|
||||
|
@ -3896,10 +3900,10 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
|
|||
if (i >= _length)
|
||||
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
|
||||
|
||||
preparePrimaryUse({this}, {&scalar}, true);
|
||||
NDArray::preparePrimaryUse({this}, {&scalar}, true);
|
||||
auto rp = getOffset(i);
|
||||
BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (getBuffer(), rp, scalar.dataType(), scalar.getBuffer()), LIBND4J_TYPES);
|
||||
registerPrimaryUse({this}, {&scalar});
|
||||
NDArray::registerPrimaryUse({this}, {&scalar});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -4195,7 +4199,7 @@ ResultSet* NDArray::allTensorsAlongDimension(const std::vector<int> &dimensions)
|
|||
|
||||
|
||||
auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_shapeInfo, const_cast<int*>(dimensions.data()), dimensions.size());
|
||||
auto numTads = lengthOf() / shape::length(pack.primaryShapeInfo());
|
||||
auto numTads = pack.numberOfTads();
|
||||
|
||||
for (int idx = 0; idx < numTads; idx++ ) {
|
||||
auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset());
|
||||
|
|
|
@@ -1578,6 +1578,20 @@ public:
                  void *dx, Nd4jLong *dxShapeInfo,
                  bool descending);

+    void sortByKey(Nd4jPointer *extraPointers,
+                   void *x, Nd4jLong *xShapeInfo,
+                   void *dx, Nd4jLong *dxShapeInfo,
+                   void *y, Nd4jLong *yShapeInfo,
+                   void *dy, Nd4jLong *dyShapeInfo,
+                   bool descending);
+
+    void sortByValue(Nd4jPointer *extraPointers,
+                     void *x, Nd4jLong *xShapeInfo,
+                     void *dx, Nd4jLong *dxShapeInfo,
+                     void *y, Nd4jLong *yShapeInfo,
+                     void *dy, Nd4jLong *dyShapeInfo,
+                     bool descending);
+
     void sortTad(Nd4jPointer *extraPointers,
                  void *x, Nd4jLong *xShapeInfo,
                  void *dx, Nd4jLong *dxShapeInfo,

@@ -1587,6 +1601,24 @@ public:
                  Nd4jLong *tadOffsets,
                  bool descending);

+    void sortTadByKey(Nd4jPointer *extraPointers,
+                      void *x, Nd4jLong *xShapeInfo,
+                      void *dx, Nd4jLong *dxShapeInfo,
+                      void *y, Nd4jLong *yShapeInfo,
+                      void *dy, Nd4jLong *dyShapeInfo,
+                      int *dimension,
+                      int dimensionLength,
+                      bool descending);
+
+    void sortTadByValue(Nd4jPointer *extraPointers,
+                        void *x, Nd4jLong *xShapeInfo,
+                        void *dx, Nd4jLong *dxShapeInfo,
+                        void *y, Nd4jLong *yShapeInfo,
+                        void *dy, Nd4jLong *dyShapeInfo,
+                        int *dimension,
+                        int dimensionLength,
+                        bool descending);
+

     // special sort impl for sorting out COO indices and values
     void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank);
@@ -208,6 +208,23 @@ void* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
     return nullptr;
 }

+////////////////////////////////////////////////////////////////////////
+void* NDArray::specialBuffer() {
+    if (_buffer->special() == nullptr)
+        return getBuffer();
+    // FIXME: this should be fixed once CUDA backend added
+    return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
+}
+
+////////////////////////////////////////////////////////////////////////
+void* NDArray::getSpecialBuffer() const {
+    if (_buffer->special() == nullptr)
+        return getBuffer();
+    // FIXME: this should be fixed once CUDA backend added
+    return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
+}
+

 //////////////////////////////////////////////////////////////////////////
 // change an array by repeating it the number of times given by reps.
 NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
@@ -27,6 +27,52 @@

 namespace nd4j {

+////////////////////////////////////////////////////////////////////////
+template <>
+NDArray NDArrayFactory::create<bool>(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context) {
+
+    if ((int) shape.size() > MAX_RANK)
+        throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
+
+    ShapeDescriptor descriptor(nd4j::DataType::BOOL, order, shape);
+
+    if (descriptor.arrLength() != data.size()) {
+        nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
+        throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
+    }
+
+    bool* hostBuffer = nullptr;
+    ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool);
+    std::copy(data.begin(), data.end(), hostBuffer);
+
+    std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(hostBuffer, data.size() * sizeof(bool), nd4j::DataType::BOOL, true, context->getWorkspace());
+
+    NDArray result(buffer, descriptor, context);
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////
+template <typename T>
+NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context) {
+
+    if ((int) shape.size() > MAX_RANK)
+        throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
+
+    ShapeDescriptor descriptor(DataTypeUtils::fromT<T>(), order, shape);
+
+    if (descriptor.arrLength() != data.size()) {
+        nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
+        throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
+    }
+
+    std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(data.data(), DataTypeUtils::fromT<T>(), descriptor.arrLength() * sizeof(T), context->getWorkspace());
+
+    NDArray result(buffer, descriptor, context);
+
+    return result;
+}
+
 NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
     std::string s(str);

@@ -227,10 +273,13 @@ template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd
 template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<float16> &data, nd4j::LaunchContext * context);
 template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bfloat16> &data, nd4j::LaunchContext * context);
 template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int> &data, nd4j::LaunchContext * context);
+template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<unsigned int> &data, nd4j::LaunchContext * context);
+template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<unsigned long> &data, nd4j::LaunchContext * context);
 template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<Nd4jLong> &data, nd4j::LaunchContext * context);
 template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int8_t> &data, nd4j::LaunchContext * context);
 template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t> &data, nd4j::LaunchContext * context);
 template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int16_t> &data, nd4j::LaunchContext * context);
 template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint16_t> &data, nd4j::LaunchContext * context);
+template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context);

@@ -391,6 +440,7 @@ template NDArray NDArrayFactory::create(const std::vector<bfloat16> &values, nd4
 template NDArray NDArrayFactory::create(const std::vector<Nd4jLong> &values, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const std::vector<int> &values, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const std::vector<int16_t> &values, nd4j::LaunchContext * context);
+template NDArray NDArrayFactory::create(const std::vector<uint16_t> &values, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const std::vector<int8_t> &values, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const std::vector<uint8_t> &values, nd4j::LaunchContext * context);
 template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::LaunchContext * context);

@@ -452,53 +502,6 @@ template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::L
     return new NDArray(order, shape, dataType, context);
 }

-////////////////////////////////////////////////////////////////////////
-template <typename T>
-NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context) {
-
-    if ((int) shape.size() > MAX_RANK)
-        throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
-
-    ShapeDescriptor descriptor(DataTypeUtils::fromT<T>(), order, shape);
-
-    if (descriptor.arrLength() != data.size()) {
-        nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
-        throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
-    }
-
-    std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(data.data(), DataTypeUtils::fromT<T>(), descriptor.arrLength() * sizeof(T), context->getWorkspace());
-
-    NDArray result(buffer, descriptor, context);
-
-    return result;
-}
-////////////////////////////////////////////////////////////////////////
-template <>
-NDArray NDArrayFactory::create<bool>(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context) {
-
-    if ((int) shape.size() > MAX_RANK)
-        throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
-
-    ShapeDescriptor descriptor(nd4j::DataType::BOOL, order, shape);
-
-    if (descriptor.arrLength() != data.size()) {
-        nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
-        throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
-    }
-
-    bool* hostBuffer = nullptr;
-    ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool);
-    std::copy(data.begin(), data.end(), hostBuffer);
-
-    std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(hostBuffer, data.size() * sizeof(bool), nd4j::DataType::BOOL, true, context->getWorkspace());
-
-    NDArray result(buffer, descriptor, context);
-
-    return result;
-}

 ////////////////////////////////////////////////////////////////////////
 template <typename T>
 NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context) {
@@ -2736,6 +2736,60 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
     return reinterpret_cast<Nd4jPointer>(shapeBuffer);
 }

+void NativeOps::sortByKey(Nd4jPointer *extraPointers,
+                          void *x, Nd4jLong *xShapeInfo,
+                          void *dx, Nd4jLong *dxShapeInfo,
+                          void *y, Nd4jLong *yShapeInfo,
+                          void *dy, Nd4jLong *dyShapeInfo,
+                          bool descending) {
+    auto xType = ArrayOptions::dataType(xShapeInfo);
+    auto yType = ArrayOptions::dataType(yShapeInfo);
+
+    BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+}
+
+void NativeOps::sortByValue(Nd4jPointer *extraPointers,
+                            void *x, Nd4jLong *xShapeInfo,
+                            void *dx, Nd4jLong *dxShapeInfo,
+                            void *y, Nd4jLong *yShapeInfo,
+                            void *dy, Nd4jLong *dyShapeInfo,
+                            bool descending) {
+
+    auto xType = ArrayOptions::dataType(xShapeInfo);
+    auto yType = ArrayOptions::dataType(yShapeInfo);
+
+    BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+}
+
+void NativeOps::sortTadByKey(Nd4jPointer *extraPointers,
+                             void *x, Nd4jLong *xShapeInfo,
+                             void *dx, Nd4jLong *dxShapeInfo,
+                             void *y, Nd4jLong *yShapeInfo,
+                             void *dy, Nd4jLong *dyShapeInfo,
+                             int *dimension,
+                             int dimensionLength,
+                             bool descending) {
+    auto xType = ArrayOptions::dataType(xShapeInfo);
+    auto yType = ArrayOptions::dataType(yShapeInfo);
+
+    BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+}
+
+void NativeOps::sortTadByValue(Nd4jPointer *extraPointers,
+                               void *x, Nd4jLong *xShapeInfo,
+                               void *dx, Nd4jLong *dxShapeInfo,
+                               void *y, Nd4jLong *yShapeInfo,
+                               void *dy, Nd4jLong *dyShapeInfo,
+                               int *dimension,
+                               int dimensionLength,
+                               bool descending) {
+    auto xType = ArrayOptions::dataType(xShapeInfo);
+    auto yType = ArrayOptions::dataType(yShapeInfo);
+
+    BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+}
+

 BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
 BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
 BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
@@ -192,8 +192,8 @@ void NDArray::setIdentity() {
     if (isS())
         throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!");

-    if (rankOf() != 2)
-        throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given.");
+    // if (rankOf() != 2)
+    //     throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given.");

     const int threadsPerBlock = MAX_NUM_THREADS / 4;
     const int blocksPerGrid = (lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
@@ -234,12 +234,15 @@ void NDArray::synchronize(const char* msg) const {
 void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {

     for (const auto& a : readList)
-        a->syncToDevice();
+        if(a != nullptr)
+            a->syncToDevice();

     for (const auto& a : writeList) {
-        a->getDataBuffer()->allocateSpecial();
-        if (synchronizeWritables)
-            a->syncToDevice();
+        if (a != nullptr) {
+            a->getDataBuffer()->allocateSpecial();
+            if (synchronizeWritables)
+                a->syncToDevice();
+        }
     }
 }
@@ -247,22 +250,27 @@ void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& wri
 void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {

     for (const auto& p : readList)
-        p->tickReadDevice();
+        if(p != nullptr)
+            p->tickReadDevice();

     for (const auto& p : writeList)
-        p->tickWriteDevice();
+        if (p != nullptr)
+            p->tickWriteDevice();
 }

 ////////////////////////////////////////////////////////////////////////
 void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {

     for (const auto& a : readList)
-        a->syncToHost();
+        if(a != nullptr)
+            a->syncToHost();

     for (const auto& a : writeList) {
-        a->getDataBuffer()->allocatePrimary();
-        if (synchronizeWritables)
-            a->syncToHost();
+        if (a != nullptr) {
+            a->getDataBuffer()->allocatePrimary();
+            if (synchronizeWritables)
+                a->syncToHost();
+        }
     }
 }
@@ -270,10 +278,12 @@ void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& wri
 void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {

     for (const auto& p : readList)
-        p->tickReadHost();
+        if(p != nullptr)
+            p->tickReadHost();

     for (const auto& p : writeList)
-        p->tickWriteHost();
+        if (p != nullptr)
+            p->tickWriteHost();
 }

 //////////////////////////////////////////////////////////////////////////
@@ -427,9 +437,26 @@ void NDArray::repeat(int dimension, NDArray& target) const {
     NDArray::registerSpecialUse({&target}, {this});
 }

+////////////////////////////////////////////////////////////////////////
+void* NDArray::specialBuffer() {
+
+    if (_buffer->special() == nullptr)
+        return getBuffer();
+    // FIXME: this should be fixed once CUDA backend added
+    return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
+}
+
+////////////////////////////////////////////////////////////////////////
+void* NDArray::getSpecialBuffer() const {
+    if (_buffer->special() == nullptr)
+        return getBuffer();
+    // FIXME: this should be fixed once CUDA backend added
+    return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
+}
+
 //////////////////////////////////////////////////////////////////////////
 template<typename T>
-void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const {\
+void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const {

     if(_length == 0)
         { printf("NDArray::printActualBuffer: array length is zero !\n"); return; }

@@ -477,7 +504,7 @@ template void NDArray::printCurrentBuffer<double>(const bool host, const char* m

 #if defined(__CUDACC__) && !defined(BUILD_TESTS)

-#include <cpu/NDArrayLambda.hpp>
+//#include <cpu/NDArrayLambda.hpp>

 #endif
@@ -2321,6 +2321,163 @@ void NativeOps::sort(Nd4jPointer *extraPointers,
 }


+void NativeOps::sortByKey(Nd4jPointer *extraPointers,
+                          void *x, Nd4jLong *xShapeInfo,
+                          void *dX, Nd4jLong *dXShapeInfo,
+                          void *y, Nd4jLong *yShapeInfo,
+                          void *dy, Nd4jLong *dyShapeInfo,
+                          bool descending) {
+
+    auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
+
+    auto xLength = shape::length(xShapeInfo);
+    auto xEWS = shape::elementWiseStride(xShapeInfo);
+    auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
+    auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
+
+
+    // check if xLength is a power of 2, and use bitonic sort, if that's the case
+    if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
+        int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
+        int numBlocks = xLength / numThreads;
+        if (xLength % numThreads > 0 || numBlocks == 0)
+            numBlocks++;
+
+        dim3 launchDims(numBlocks, numThreads, 32768);
+
+        for (int k = 2; k <= xLength; k = 2*k) {
+            for (int j = k >> 1; j > 0; j = j >> 1) {
+                BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+            }
+        }
+    } else {
+        int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
+        int numBlocks = xLength / numThreads;
+        if (xLength % numThreads > 0 || numBlocks == 0)
+            numBlocks++;
+
+        numBlocks = nd4j::math::nd4j_min<int>(512, numBlocks);
+        dim3 launchDims(numBlocks, numThreads, 32768);
+
+        int max = 2, dg = 0;
+        while (max < xLength) {
+            max <<= 1;
+            dg++;
+        }
+        max <<= 1;
+
+        for (int window = 2; window < max; window<<=1) {
+            int n = window;
+            int rev = 0;
+            do{
+                int half = n >> 1;
+                BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+                n>>=1;
+                rev = 1;
+            } while(n > 1);
+        }
+    }
+}
+
+void NativeOps::sortByValue(Nd4jPointer *extraPointers,
+                            void *x, Nd4jLong *xShapeInfo,
+                            void *dX, Nd4jLong *dXShapeInfo,
+                            void *y, Nd4jLong *yShapeInfo,
+                            void *dy, Nd4jLong *dyShapeInfo,
+                            bool descending) {
+    auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
+
+    auto xLength = shape::length(xShapeInfo);
+    auto xEWS = shape::elementWiseStride(xShapeInfo);
+    auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
+    auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
+
+
+    // check if xLength is a power of 2, and use bitonic sort, if that's the case
+    if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
+        int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
+        int numBlocks = xLength / numThreads;
+        if (xLength % numThreads > 0 || numBlocks == 0)
+            numBlocks++;
+
+        dim3 launchDims(numBlocks, numThreads, 32768);
+
+        for (int k = 2; k <= xLength; k = 2*k) {
+            for (int j = k >> 1; j > 0; j = j >> 1) {
+                BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+            }
+        }
+    } else {
+        int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
+        int numBlocks = xLength / numThreads;
+        if (xLength % numThreads > 0 || numBlocks == 0)
+            numBlocks++;
+
+        numBlocks = nd4j::math::nd4j_min<int>(512, numBlocks);
+        dim3 launchDims(numBlocks, numThreads, 32768);
+
+        int max = 2, dg = 0;
+        while (max < xLength) {
+            max <<= 1;
+            dg++;
+        }
+        max <<= 1;
+
+        for (int window = 2; window < max; window<<=1) {
+            int n = window;
+            int rev = 0;
+            do{
+                int half = n >> 1;
+                BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
+                n>>=1;
+                rev = 1;
+            } while(n > 1);
+        }
+    }
+}
+
+
+
+void NativeOps::sortTadByKey(Nd4jPointer *extraPointers,
+                             void *x, Nd4jLong *xShapeInfo,
+                             void *dX, Nd4jLong *dXShapeInfo,
+                             void *y, Nd4jLong *yShapeInfo,
+                             void *dy, Nd4jLong *dyShapeInfo,
+                             int *dimension,
+                             int dimensionLength,
+                             bool descending) {
+    auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
+    auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
+    auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
+    dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
+    auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
+    auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
+    BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES);
+
+    nd4j::DebugHelper::checkErrorCode(stream, "sortTadKey(...) failed");
+}
+
+void NativeOps::sortTadByValue(Nd4jPointer *extraPointers,
+                               void *x, Nd4jLong *xShapeInfo,
+                               void *dX, Nd4jLong *dXShapeInfo,
+                               void *y, Nd4jLong *yShapeInfo,
+                               void *dy, Nd4jLong *dyShapeInfo,
+                               int *dimension,
+                               int dimensionLength,
+                               bool descending) {
+    auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
+    auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
+    auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
+    dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
+    auto xType = nd4j::ArrayOptions::dataType(yShapeInfo);
+    auto yType = nd4j::ArrayOptions::dataType(xShapeInfo);
+
+    BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES);
+
+    nd4j::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed");
+}
+

 void NativeOps::sortTad(Nd4jPointer *extraPointers,
                         void *x, Nd4jLong *xShapeInfo,
                         void *dX, Nd4jLong *dXShapeInfo,
@@ -2331,15 +2488,13 @@ void NativeOps::sortTad(Nd4jPointer *extraPointers,
                         bool descending) {
     // to be implemented
     auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
-
+    auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
     auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
-
-    dim3 launchDims(tadPack.numberOfTads(), 1024, 33768);
+    dim3 launchDims((int) tadPack.numberOfTads(), 512, 33768);
     auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
-    BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, (launchDims, stream, dX, dXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
+    BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);

-    nd4j::DebugHelper::checkErrorCode(stream, "sortTadFloat(...) failed");
+    nd4j::DebugHelper::checkErrorCode(stream, "sortTad(...) failed");
 }

 void NativeOps::sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
@@ -38,11 +38,11 @@ namespace nd4j {
            ConstantDataBuffer() = default;
            ~ConstantDataBuffer() = default;

-           Nd4jLong sizeOf();
-           Nd4jLong length();
+           Nd4jLong sizeOf() const;
+           Nd4jLong length() const;

-           Nd4jPointer primary();
-           Nd4jPointer special();
+           Nd4jPointer primary() const;
+           Nd4jPointer special() const;

            ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default;
            ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default;
@@ -261,6 +261,8 @@ DataBuffer& DataBuffer::operator=(const DataBuffer& other) {

    allocateBuffers();
    copyBufferFrom(other);

    return *this;
}

////////////////////////////////////////////////////////////////////////

@@ -285,6 +287,8 @@ DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
    other._primaryBuffer = other._specialBuffer = nullptr;
    other.setAllocFlags(false, false);
    other._lenInBytes = 0;

    return *this;
}

////////////////////////////////////////////////////////////////////////
@@ -28,7 +28,7 @@
#include <op_boilerplate.h>
#include <dll.h>
#include <Environment.h>
#include <ArrayOptions.h>
#include <templatemath.h>
#include <shape.h>
#include <helpers/logger.h>
@@ -62,7 +62,7 @@ namespace nd4j {
        template <typename T>
        FORCEINLINE static _CUDA_HD T nanOrZero();

        // returns the difference between 1.0 and the next representable value of the given floating-point type
        template <typename T>
        FORCEINLINE static T eps();
@@ -94,13 +94,13 @@ namespace nd4j {

//////////////////////////////////////////////////////////////////////////
/////            IMPLEMENTATION OF INLINE METHODS            /////
//////////////////////////////////////////////////////////////////////////

    FORCEINLINE nd4j::DataType DataTypeUtils::pickFloatingType(nd4j::DataType typeX) {
        // if the proposed dataType is already floating point - return it
        if (isR(typeX))
            return typeX;
        return Environment::getInstance()->defaultFloatDataType();
    }
@@ -213,13 +213,13 @@ FORCEINLINE _CUDA_HD uint32_t DataTypeUtils::min<uint32_t>() {
}

template<>
FORCEINLINE _CUDA_HD float DataTypeUtils::min<float>() {
    return 1.175494e-38;
}

template<>
FORCEINLINE _CUDA_HD float16 DataTypeUtils::min<float16>() {
    return (float16) 6.1035e-05;
}

template<>

@@ -228,8 +228,8 @@ FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::min<bfloat16>() {
}

template<>
FORCEINLINE _CUDA_HD double DataTypeUtils::min<double>() {
    return 2.2250738585072014e-308;
}

///////////////////////////////////////////////////////////////////
@@ -280,17 +280,17 @@ FORCEINLINE _CUDA_HD Nd4jULong DataTypeUtils::max<Nd4jULong>() {
}

template <>
FORCEINLINE _CUDA_HD float DataTypeUtils::max<float>() {
    return 3.402823e+38;
}

template <>
FORCEINLINE _CUDA_HD double DataTypeUtils::max<double>() {
    return 1.7976931348623157E308;
}

template <>
FORCEINLINE _CUDA_HD float16 DataTypeUtils::max<float16>() {
    return static_cast<float16>(65504.f);
}
@@ -335,6 +335,8 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
            return std::string("INT8");
        case INT16:
            return std::string("INT16");
+       case UINT16:
+           return std::string("UINT16");
        case INT32:
            return std::string("INT32");
        case INT64:
@@ -361,7 +363,7 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {

template <typename T>
FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo) {

    for (int e = 0; e < shape::shapeInfoLength(originalShapeInfo); e++) {
        if (originalShapeInfo[e] < static_cast<Nd4jLong>(DataTypeUtils::max<T>())) {
            newShapeInfo[e] = static_cast<T>(originalShapeInfo[e]);
@@ -373,9 +375,9 @@ FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo,
}

///////////////////////////////////////////////////////////////////
// returns the difference between 1.0 and the next representable value of the given floating-point type
template <typename T>
-FORCEINLINE T DataTypeUtils::eps() {
+FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
    if (std::is_same<T, double>::value)
        return std::numeric_limits<double>::epsilon();
    else if (std::is_same<T, float>::value)

@@ -406,7 +408,7 @@ FORCEINLINE T DataTypeUtils::eps() {
        case nd4j::DataType::FLOAT8:
        case nd4j::DataType::QINT8:
        case nd4j::DataType::BOOL: return (size_t) 1;

        case nd4j::DataType::BFLOAT16:
        case nd4j::DataType::HALF:
        case nd4j::DataType::INT16:
@@ -26,6 +26,7 @@
#include <vector>
#include <array/DataType.h>
#include <pointercast.h>
#include <stdlib.h>

namespace nd4j {
    class ND4J_EXPORT ExtraArguments {
@@ -35,21 +35,21 @@ namespace nd4j {
        TadPack() = default;
        ~TadPack() = default;

-       Nd4jLong* primaryShapeInfo();
-       Nd4jLong* primaryOffsets();
+       Nd4jLong* primaryShapeInfo() const;
+       Nd4jLong* primaryOffsets() const;

-       Nd4jLong* specialShapeInfo();
-       Nd4jLong* specialOffsets();
+       Nd4jLong* specialShapeInfo() const;
+       Nd4jLong* specialOffsets() const;

-       Nd4jLong numberOfTads();
-       int shapeInfoLength();
+       Nd4jLong numberOfTads() const;
+       int shapeInfoLength() const;

        /**
         * These methods return either primary or special pointers depending on the platform the binaries were compiled for
         * @return
         */
-       Nd4jLong *platformShapeInfo();
-       Nd4jLong *platformOffsets();
+       Nd4jLong *platformShapeInfo() const;
+       Nd4jLong *platformOffsets() const;
    };
}
@@ -28,19 +28,19 @@ namespace nd4j {
        _sizeOf = sizeOf;
    }

-   Nd4jPointer ConstantDataBuffer::primary() {
+   Nd4jPointer ConstantDataBuffer::primary() const {
        return _primaryBuffer;
    }

-   Nd4jPointer ConstantDataBuffer::special() {
+   Nd4jPointer ConstantDataBuffer::special() const {
        return _specialBuffer;
    }

-   Nd4jLong ConstantDataBuffer::sizeOf() {
+   Nd4jLong ConstantDataBuffer::sizeOf() const {
        return _sizeOf;
    }

-   Nd4jLong ConstantDataBuffer::length() {
+   Nd4jLong ConstantDataBuffer::length() const {
        return _length;
    }
@@ -54,7 +54,7 @@ namespace nd4j {
    NDArray* NDArrayList::readRaw(int idx) {
        if (_chunks.count(idx) < 1) {
            nd4j_printf("Non-existent chunk requested: [%i]\n", idx);
-           throw std::runtime_error("Bad index");
+           throw std::invalid_argument("Bad index");
        }

        return _chunks[idx];

@@ -120,7 +120,7 @@ namespace nd4j {
        // storing reference
        _chunks[idx] = array;

-       return ND4J_STATUS_OK;
+       return Status::OK();
    }

    std::vector<Nd4jLong>& NDArrayList::shape() {

@@ -152,8 +152,10 @@ namespace nd4j {
        std::vector<bool> bargs;
        int numElements = _elements.load();

-       for (int e = 0; e < numElements; e++)
+       for (int e = 0; e < numElements; e++) {
+           _chunks[e]->syncToDevice();
            inputs.emplace_back(_chunks[e]);
+       }

        iargs.push_back(_axis);
@@ -29,34 +29,34 @@ namespace nd4j {
        _numTads = numTads;
    }

-   Nd4jLong* TadPack::primaryShapeInfo() {
+   Nd4jLong* TadPack::primaryShapeInfo() const {
        return reinterpret_cast<Nd4jLong *>(_tadShape.primary());
    }

-   Nd4jLong* TadPack::primaryOffsets() {
+   Nd4jLong* TadPack::primaryOffsets() const {
        return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary());
    }

-   Nd4jLong* TadPack::specialShapeInfo() {
+   Nd4jLong* TadPack::specialShapeInfo() const {
        return reinterpret_cast<Nd4jLong *>(_tadShape.special());
    }

-   Nd4jLong* TadPack::specialOffsets() {
+   Nd4jLong* TadPack::specialOffsets() const {
        return reinterpret_cast<Nd4jLong *>(_tadOffsets.special());
    }

-   Nd4jLong TadPack::numberOfTads() {
+   Nd4jLong TadPack::numberOfTads() const {
        return _numTads;
    }

-   Nd4jLong* TadPack::platformShapeInfo() {
+   Nd4jLong* TadPack::platformShapeInfo() const {
        return nd4j::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo();
    }

-   Nd4jLong* TadPack::platformOffsets() {
+   Nd4jLong* TadPack::platformOffsets() const {
        return nd4j::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets();
    }

-   int TadPack::shapeInfoLength() {
+   int TadPack::shapeInfoLength() const {
        return (int) shape::shapeInfoLength(primaryShapeInfo());
    }
}
@@ -27,7 +27,7 @@ namespace nd4j {
    class AttentionHelper {

    public:
-       static nd4j::NDArray* multiHeadProject(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
+       static nd4j::NDArray multiHeadProject(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
        static void multiHeadProjectBp(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, const nd4j::NDArray* eps, nd4j::NDArray* dLdInput, nd4j::NDArray* dLdProjectionMatrix, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
    };
}
@@ -69,10 +69,10 @@ namespace nd4j {
    }

    void executeOnce() override {
-       auto xT = (_tA ? _x->transpose() : _x);
-       auto yT = (_tB ? _y->transpose() : _y);
+       auto xT = (_tA ? _x->transpose() : *_x);
+       auto yT = (_tB ? _y->transpose() : *_y);

-       MmulHelper::mmul(xT, yT, _z, _alpha, _beta);
+       MmulHelper::mmul(&xT, &yT, _z, _alpha, _beta);
    }

    std::string axis() override {
@@ -39,31 +39,31 @@ NDArray Householder<T>::evalHHmatrix(const NDArray& x) {

    T coeff;
    T normX = x.reduceNumber(reduce::Norm2).e<T>(0);

    if(normX*normX - x.e<T>(0) * x.e<T>(0) <= DataTypeUtils::min<T>() || x.lengthOf() == 1) {

        normX = x.e<T>(0);
        coeff = 0.f;
        w = 0.f;
    }
    else {

        if(x.e<T>(0) >= (T)0.f)
            normX = -normX; // choose opposite sign to lessen roundoff error

        T u0 = x.e<T>(0) - normX;
        coeff = -u0 / normX;
        w.assign(x / u0);
    }

    w.p(Nd4jLong(0), 1.f);
    wT.assign(&w);

    auto identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext());
    identity.setIdentity(); // identity matrix

    return identity - mmul(w, wT) * coeff;
}
@@ -79,7 +79,7 @@ void Householder<T>::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff,
        throw std::runtime_error("ops::helpers::Householder::evalHHmatrixData method: input tail vector must have length less than unity compared to input x vector!");

    normX = x.reduceNumber(reduce::Norm2, nullptr).e<T>(0);

    if(normX*normX - x.e<T>(0) * x.e<T>(0) <= DataTypeUtils::min<T>() || x.lengthOf() == 1) {

        normX = x.e<T>(0);

@@ -87,18 +87,18 @@ void Householder<T>::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff,
        tail = (T)0.f;
    }
    else {

        if(x.e<T>(0) >= (T)0.f)
            normX = -normX; // choose opposite sign to lessen roundoff error

        T u0 = x.e<T>(0) - normX;
        coeff = -u0 / normX;

        if(x.isRowVector())
            tail.assign(x({0,0, 1,-1}) / u0);
        else
            tail.assign(x({1,-1, 0,0,}) / u0);
    }
}
@@ -107,20 +107,20 @@ void Householder<T>::evalHHmatrixDataI(const NDArray& x, T& coeff, T& normX) {

    int rows = (int)x.lengthOf()-1;
    int num = 1;

    if(rows == 0) {
        rows = 1;
        num = 0;
    }

    auto tail = NDArrayFactory::create(x.ordering(), {rows, 1}, x.dataType(), x.getContext());
    evalHHmatrixData(x, tail, coeff, normX);

    if(x.isRowVector()) {
        auto temp = x({0,0, num, x.sizeAt(1)}, true);
        temp.assign(tail);
    }
    else {
        auto temp = x({num, x.sizeAt(0), 0,0}, true);
        temp.assign(tail);
    }
@@ -129,14 +129,14 @@ void Householder<T>::evalHHmatrixDataI(const NDArray& x, T& coeff, T& normX) {

//////////////////////////////////////////////////////////////////////////
template <typename T>
void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff) {

    // if(matrix.rankOf() != 2)
    //     throw "ops::helpers::Householder::mulLeft method: input array must be 2D matrix !";

-   if(matrix.sizeAt(0) == 1)
+   if(matrix.sizeAt(0) == 1) {
        matrix *= (T)1.f - coeff;
-
+   }
    else if(coeff != (T)0.f) {

        auto bottomPart = new NDArray(matrix({1,matrix.sizeAt(0), 0,0}, true));
        auto bottomPartCopy = *bottomPart;

@@ -145,26 +145,22 @@ void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff

        auto column = tail;
        auto row = tail.transpose();
-       auto resultingRow = mmul(*row, bottomPartCopy);
+       auto resultingRow = mmul(row, bottomPartCopy);
        auto fistRow = matrix({0,1, 0,0}, true);
        resultingRow += fistRow;
        fistRow -= resultingRow * coeff;
        *bottomPart -= mmul(column, resultingRow) * coeff;
-
-       delete row;
    }
    else {

        auto row = tail;
        auto column = tail.transpose();
        auto resultingRow = mmul(row, bottomPartCopy);
        auto fistRow = matrix({0,1, 0,0}, true);
        resultingRow += fistRow;
        fistRow -= resultingRow * coeff;
-       *bottomPart -= mmul(*column, resultingRow) * coeff;
-
-       delete column;
+       *bottomPart -= mmul(column, resultingRow) * coeff;
    }
    delete bottomPart;
    }
}
@@ -176,10 +172,10 @@ void Householder<T>::mulRight(NDArray& matrix, const NDArray& tail, const T coef

    // if(matrix.rankOf() != 2)
    //     throw "ops::helpers::Householder::mulRight method: input array must be 2D matrix !";

    if(matrix.sizeAt(1) == 1)
        matrix *= (T)1.f - coeff;

    else if(coeff != (T)0.f) {

        auto rightPart = new NDArray(matrix({0,0, 1,matrix.sizeAt(1)}, true));

@@ -191,30 +187,25 @@ void Householder<T>::mulRight(NDArray& matrix, const NDArray& tail, const T coef
        auto column = tail;
        auto row = tail.transpose();
        auto resultingCol = mmul(rightPartCopy, column);
        resultingCol += *fistCol;
        *fistCol -= resultingCol * coeff;
-       *rightPart -= mmul(resultingCol, *row) * coeff;
-
-       delete row;
+       *rightPart -= mmul(resultingCol, row) * coeff;
    }
    else {

        auto row = tail;
        auto column = tail.transpose();
-       auto resultingCol = mmul(rightPartCopy, *column);
+       auto resultingCol = mmul(rightPartCopy, column);
        resultingCol += *fistCol;
        *fistCol -= resultingCol * coeff;
        *rightPart -= mmul(resultingCol, row) * coeff;
-
-       delete column;
    }
    delete rightPart;
    delete fistCol;
    }
}

template class ND4J_EXPORT Householder<float>;
template class ND4J_EXPORT Householder<float16>;
template class ND4J_EXPORT Householder<bfloat16>;
@@ -157,8 +157,7 @@ bool JacobiSVD<T>::isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem) {

    if(_calcU) {
        auto temp2 = rotation.transpose();
-       mulRotationOnRight(p, q, _u, *temp2);
-       delete temp2;
+       mulRotationOnRight(p, q, _u, temp2);
    }
}

@@ -251,9 +250,7 @@ void JacobiSVD<T>::svd2x2(const NDArray& block, int p, int q, NDArray& left, NDA
    m.p<T>(1, 1, _z);

    auto temp = right.transpose();
-   left.assign(mmul(rotation, *temp));
-   delete temp;
+   left.assign(mmul(rotation, temp));
}

@@ -289,7 +286,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
    else if(_rows < _cols) {

        auto matrixT = matrix.transpose();
-       HHcolPivQR qr(*matrixT / scale);
+       HHcolPivQR qr(matrixT / scale);
        _m.assign(qr._qr({0,_rows, 0,_rows}));
        _m.fillAsTriangular<T>(0., 0, 0, 'l');
        _m.transposei();

@@ -305,8 +302,6 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {

        if(_calcU)
            _u.assign(qr._permut);
-
-       delete matrixT;
    }
    else {

@@ -352,8 +347,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {

        if(_calcU) {
            auto temp = rotLeft.transpose();
-           mulRotationOnRight(p, q, _u, *temp);
-           delete temp;
+           mulRotationOnRight(p, q, _u, temp);
        }

        mulRotationOnRight(p, q, _m, rotRight);
@@ -920,7 +920,7 @@ void SVD<T>::evalData(const NDArray& matrix) {
    auto temp1 = biDiag._HHbidiag.transpose();
    auto temp2 = _m({0,_diagSize, 0,0}, true);
    temp2.assign(temp1);
-   delete temp1;

    auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true);
    temp3.assign(0.);
@@ -184,9 +184,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou

    if(pC->ordering() != 'f') {
        auto temp = pA;
-       pA = pB ->permute({1,0});
-       pB = temp->permute({1,0});
-       pC = pC ->permute({1,0});
+       pA = new NDArray(pB ->permute({1,0}));
+       pB = new NDArray(temp->permute({1,0}));
+       pC = new NDArray(pC ->permute({1,0}));
        toDelete.push_back(pA);
        toDelete.push_back(pB);
        toDelete.push_back(pC);

@@ -251,7 +251,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
        blocksPerGrid.y = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.y); // rows
    }

-   BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+   //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES)
    }

    if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);

@@ -339,7 +340,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
        threadsPerBlock.x = 512;
        blocksPerGrid.x = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.x); // rows
    }
-   BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+   //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES)
    }

    if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);

@@ -396,7 +398,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c

    NDArray::prepareSpecialUse({Z}, {X, Y});

-   BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+   //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES)

    auto cudaResult = cudaStreamSynchronize(*stream);
    if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);

@@ -406,8 +409,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
    return Z;
}

-BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
-BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
-BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);

}
@@ -28,33 +28,27 @@

namespace nd4j {

-   nd4j::NDArray *
-   AttentionHelper::multiHeadProject(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix, nd4j::LaunchContext * context) {
+   nd4j::NDArray AttentionHelper::multiHeadProject(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix, nd4j::LaunchContext * context) {
        auto miniBatchSize = input->sizeAt(0);
        auto seqLength = input->sizeAt(2);
        auto numHeads = projectionMatrix->sizeAt(0);
        auto projectedSize = projectionMatrix->sizeAt(1);

        auto inputPerm = input->permute({1, 0, 2});
-       auto inputPrep = inputPerm->reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
+       auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
        auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});

-       NDArray* projected = new NDArray('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);
+       NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);
        nd4j::ops::matmul mmul;
-       mmul.execute({projectionPrep, inputPrep}, {projected}, {}, {}, {});
+       mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});

-       projected->reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
-       projected->permutei({2, 0, 1, 3});
-
-       delete inputPerm;
-       delete inputPrep;
-       delete projectionPrep;
+       projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
+       projected.permutei({2, 0, 1, 3});

        return projected;
    }

-   void
-   AttentionHelper::multiHeadProjectBp(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix,
+   void AttentionHelper::multiHeadProjectBp(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix,
                                             const nd4j::NDArray *eps, nd4j::NDArray *dLdInput,
                                             nd4j::NDArray *dLdProjectionMatrix, nd4j::LaunchContext * context) {
        auto miniBatchSize = input->sizeAt(0);

@@ -63,16 +57,16 @@ namespace nd4j {
        auto projectedSize = projectionMatrix->sizeAt(1);

        auto epsPerm = eps->permute({1, 2, 0, 3});
-       auto epsReshaped = epsPerm->reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength});
+       auto epsReshaped = epsPerm.reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength});

        auto inputPerm = input->permute({1, 0, 2});
-       auto inputPrep = inputPerm->reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
+       auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
        auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});

        nd4j::ops::matmul_bp mmulBp;
-       NDArray dLdProjectionPrep(projectionPrep->shapeInfo(), false, context);
-       NDArray dLdInputPrep(inputPrep->shapeInfo(), false, context);
-       mmulBp.execute({projectionPrep, inputPrep, epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});
+       NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context);
+       NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context);
+       mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});

        dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
        dLdProjectionMatrix->assign(dLdProjectionPrep);

@@ -80,12 +74,6 @@ namespace nd4j {
        dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength});
        dLdInputPrep.permutei({1, 0, 2});
        dLdInput->assign(dLdInputPrep);
-
-       delete inputPerm;
-       delete inputPrep;
-       delete epsPerm;
-       delete epsReshaped;
-       delete projectionPrep;
    }
}
@@ -29,13 +29,13 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&

    const int numInGradArrs = gradArrs.size();

    // fill input gradient arrays in accordance with the type of loss function
    switch(loss) {

        case MEAN:
            PRAGMA_OMP_PARALLEL_FOR_IF(numInGradArrs > 1)
            for(int i = 0; i < numInGradArrs; ++i)
                *gradArrs[i] = 1. / gradArrs[i]->lengthOf();
            break;

        case SUM:

@@ -43,9 +43,9 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
            for(int i = 0; i < numInGradArrs; ++i)
                *gradArrs[i] = 1.;
            break;

        default:
            throw std::invalid_argument("GradCheck::fillGradArrays: invalid type of loss function !");
    }
}

@@ -53,7 +53,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
                          const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) {

-   const int numInArrsFF     = argsHolderFF.getNumInArrs();                  // also numInArrsFF = number of output arrays in opBP
+   const int numInArrsFF     = argsHolderFF.getNumInArrs();                  // at the same time numInArrsFF = number of output arrays in opBP
    const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF;    // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
    const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
    const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();

@@ -61,10 +61,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
    // fill input gradient arrays in accordance with the type of loss function
    fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));

    // back prop pass
    ResultSet* outArrsBP = opBP.execute(argsHolderBP);    // number of output arrays in back prop = numInArrsFF;

    NDArray tmpScalar(nd4j::DataType::DOUBLE, inArrsFF[0]->getContext());    // scalar = 0

    for(int i = 0; i < numInArrsFF; ++i) {                // loop through input arrays

        if(!whatArrsToCheck.empty() && static_cast<bool>(whatArrsToCheck[i]) == false)

@@ -72,42 +73,42 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons

        const Nd4jLong idxStart = static_cast<Nd4jLong>(idxRange[0] * inArrsFF[i]->lengthOf());
        const Nd4jLong idxEnd   = static_cast<Nd4jLong>(idxRange[1] * inArrsFF[i]->lengthOf());

        for(Nd4jLong j = idxStart; j < idxEnd; ++j) {     // loop through all elements for current array

-           double& elem = inArrsFF[i]->t<double>(j);
-           const double orig = elem;
+           const double orig = inArrsFF[i]->e<double>(j);

            // add epsilon, feed forward
-           elem = orig + EPSILON;
+           inArrsFF[i]->p<double>(j, orig + EPSILON);
            ResultSet* outArrsFF = opFF.execute(argsHolderFF);
            int numOutArrs = outArrsFF->size();
            double scorePlus = 0.;

            for(int k = 0; k < numOutArrs; ++k) {         // loop through output arrays
                if(loss == SUM)
-                   NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), reduce::Sum, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
+                   outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
                else
-                   NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), reduce::Mean, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
+                   outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
                scorePlus += tmpScalar.e<double>(0);
            }
            delete outArrsFF;

            // subtract epsilon, feed forward
-           elem = orig - EPSILON;
+           inArrsFF[i]->p<double>(j, orig - EPSILON);
            outArrsFF = opFF.execute(argsHolderFF);
            double scoreMinus = 0.;

            for(int k = 0; k < numOutArrs; ++k) {         // loop through output arrays
                if(loss == SUM)
-                   NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), reduce::Sum, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
+                   outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
                else
-                   NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), reduce::Mean, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
+                   outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
                scoreMinus += tmpScalar.e<double>(0);
            }
            delete outArrsFF;

            // restore initial element value
-           elem = orig;
+           inArrsFF[i]->p<double>(j, orig);

            // calculate numerical gradient
            const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON);

@@ -116,7 +117,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
                throw std::runtime_error("");
            }

            // get analytical gradient
            const double analyticGrad = outArrsBP->at(i)->e<double>(j);
            if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) {
                printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j);

@@ -124,13 +125,13 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
            }

            // printf("num = %.5f, ana = %.5f\n", numericalGrad, analyticGrad);

            // calculate relative error
            double relError;
            if(numericalGrad == 0. && analyticGrad == 0.)
                relError = 0.;
            else
                relError = math::nd4j_abs<double>(analyticGrad - numericalGrad) / (math::nd4j_abs<double>(analyticGrad) + math::nd4j_abs<double>(numericalGrad));

            // verify result
            if(relError > MAXRELERR || std::isnan(relError)) {

@@ -144,7 +145,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
            }
        }
    }

    delete outArrsBP;
    return true;
}
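GradCheck perturbs each input element by ±EPSILON, reruns the forward op, and compares the central difference (scorePlus − scoreMinus) / (2·EPSILON) with the analytic gradient via a symmetric relative error. A minimal standalone sketch of that loop (std::function-based, all names illustrative, not the libnd4j API):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Central-difference gradient check: numericalGrad = (f(x+eps) - f(x-eps)) / (2*eps),
// compared against the analytic gradient via a symmetric relative error.
bool checkGrad(const std::function<double(const std::vector<double>&)>& f,
               const std::function<std::vector<double>(const std::vector<double>&)>& grad,
               std::vector<double> x,
               double eps = 1e-5, double maxRelErr = 1e-5) {
    const std::vector<double> analytic = grad(x);
    for (std::size_t j = 0; j < x.size(); ++j) {
        const double orig = x[j];
        x[j] = orig + eps;                       // add epsilon, feed forward
        const double scorePlus = f(x);
        x[j] = orig - eps;                       // subtract epsilon, feed forward
        const double scoreMinus = f(x);
        x[j] = orig;                             // restore initial element value
        const double numerical = (scorePlus - scoreMinus) / (2 * eps);
        const double denom = std::abs(analytic[j]) + std::abs(numerical);
        const double relError = denom == 0. ? 0. : std::abs(analytic[j] - numerical) / denom;
        if (relError > maxRelErr || std::isnan(relError))
            return false;
    }
    return true;
}
```

For a quadratic loss the central difference is exact up to rounding, so a correct gradient passes and a deliberately scaled one fails.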
@@ -39,26 +39,23 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* A, const nd4j::N
nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector<int>& axes_0, const std::vector<int>& axes_1) {

    std::vector<int> permutAt, permutBt;
    std::vector<Nd4jLong> shapeAt, shapeBt;

    auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);

-   NDArray* aPR = a->permute(permutAt);
-   NDArray* bPR = b->permute(permutBt);
+   NDArray aPR = a->permute(permutAt);
+   NDArray bPR = b->permute(permutBt);

    // check whether reshape is necessary
-   if(!aPR->isSameShape(shapeAt))
-       aPR->reshapei( shapeAt);
-   if(!bPR->isSameShape(shapeBt))
-       bPR->reshapei( shapeBt);
+   if(!aPR.isSameShape(shapeAt))
+       aPR.reshapei( shapeAt);
+   if(!bPR.isSameShape(shapeBt))
+       bPR.reshapei( shapeBt);

-   NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
+   NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);

    c->reshapei(outShape);

-   delete aPR;
-   delete bPR;
-
    return c;
}

@@ -74,65 +71,67 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,

    // check whether permutation is required
    if(!permutForC.empty())
-       cP = c->permute(permutForC);
+       cP = new NDArray(c->permute(permutForC));

    auto aPR = a->permute(permutAt);
    auto bPR = b->permute(permutBt);

    // check whether reshape is necessary
-   if(!aPR->isSameShape(shapeAt))
-       aPR->reshapei(shapeAt);
-   if(!bPR->isSameShape(shapeBt))
-       bPR->reshapei(shapeBt);
+   if(!aPR.isSameShape(shapeAt))
+       aPR.reshapei(shapeAt);
+   if(!bPR.isSameShape(shapeBt))
+       bPR.reshapei(shapeBt);

-   if(!cP->isSameShape({aPR->sizeAt(0), bPR->sizeAt(1)}))
-       cPR = cP->reshape(cP->ordering(), {aPR->sizeAt(0), bPR->sizeAt(1)});
+   if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
+       cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));

-   mmul(aPR, bPR, cPR, 1.0, 0.0);
+   mmul(&aPR, &bPR, cPR, 1.0, 0.0);

    if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() )    // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
        cP->assign(cPR);

    if(cPR != c)
        delete cPR;
    if(cP != c)
        delete cP;
-   delete aPR;
-   delete bPR;
}


#ifndef __JAVACPP_HACK__
//////////////////////////////////////////////////////////////////////////
void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB, const std::vector<std::vector<Nd4jLong>>& modifC) {

    NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
    std::string whatToDoWithA, whatToDoWithB, whatToDoWithC;    // "" - nothing; "p" - permutation; "r" - reshaping; "pr" - permutation+reshaping; "rp" - reshaping/permutation, and so on; if another string is produced - throw exception
    for(const auto& arr : modifA)
        whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r";    // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
    for(const auto& arr : modifB)
        whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";
    for(const auto& arr : modifC)
        whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithC + "p" : whatToDoWithC + "r";

    // first step for a array
    if(!whatToDoWithA.empty())
-       aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]);
+       aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0]));
    // first step for b array
    if(!whatToDoWithB.empty())
-       bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]);
+       bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0]));
    // rest steps for a array
    for(int i = 1; i < whatToDoWithA.size(); ++i)
        if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
    // rest steps for b array
    for(int i = 1; i < whatToDoWithB.size(); ++i)
        if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]);

    // now work with c array
    std::vector<NDArray*> cArrs = {c};
    if(!whatToDoWithC.empty()) {
        cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
        for(int i = 0; i < cArrs.size()-1; ++i)
-           cArrs[i+1] = (whatToDoWithC[i] == 'p') ? cArrs[i]->permute(modifC[i]) : cArrs[i]->reshape(c->ordering(), modifC[i]);    // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
+           cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i]));    // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
    }

    mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);

    // check whether new buffer allocation was happened for c array

@@ -152,27 +151,30 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,

//////////////////////////////////////////////////////////////////////////
NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB) {

    NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
    std::string whatToDoWithA, whatToDoWithB;    // "" - nothing; "p" - permutation only; "r" - reshaping only; "pr" - permutation+reshaping; "rp" - reshaping/permutation; another string - throw exception
    for(const auto& arr : modifA)
        whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r";    // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
    for(const auto& arr : modifB)
        whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";

    // first step for a array
    if(!whatToDoWithA.empty())
-       aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]);
+       aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0]));
    // first step for b array
    if(!whatToDoWithB.empty())
-       bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]);
+       bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0]));
    // rest steps for a array
    for(int i = 1; i < whatToDoWithA.size(); ++i)
        if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
    // rest steps for b array
    for(int i = 1; i < whatToDoWithB.size(); ++i)
        if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]);

    NDArray* result = mmul(aPR, bPR, nullptr, 1.0, 0.0);

    if(aPR != a)
        delete aPR;
    if(bPR != b)

@@ -281,9 +283,9 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
        nd4j_printf("NDArrayFactory::matmul static method: input shape of output array is wrong, actual is %s and expected is %s ! \n", ShapeUtils::shapeAsString(z).c_str(), ShapeUtils::shapeAsString(outShape).c_str());
        throw std::invalid_argument("");
    }

    NDArray* xT(const_cast<NDArray*>(x)), *yT(const_cast<NDArray*>(y)), *zT(z);

    if((transX && xRank > 1) || (transY && yRank > 1)) {
        const int rank = xRank >= yRank ? xRank : yRank;
        std::vector<int> permut(rank);

@@ -291,25 +293,25 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
            permut[i] = i;
        permut[rank-2] = rank - 1;
        permut[rank-1] = rank - 2;

        if(transX)
-           xT = x->permute(permut);
+           xT = new NDArray(x->permute(permut));
        if(transY)
-           yT = y->permute(permut);
+           yT = new NDArray(y->permute(permut));
    }

    if(xRank <= 2 && yRank <= 2) {    // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector (2Dx1D), matrix-matrix (2Dx2D) product cases

        if(xRank == 1 && yRank == 2) {    // reduce vector-matrix to matrix-matrix case
-           xT = x->reshape(x->ordering(), {1, x->lengthOf()});    // please note x is not transposed in this case (since xRank=1)
-           zT = z->reshape(z->ordering(), {1, z->lengthOf()});
+           xT = new NDArray(x->reshape(x->ordering(), {1, x->lengthOf()}));    // please note x is not transposed in this case (since xRank=1)
+           zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
        }

        mmul(xT, yT, zT, 1., 0.);
    }
    else {    // rest cases - batched mmul

        const int batchRank = xRank - 2;
        std::vector<int> dimsToExclude(batchRank);
        for(int i = 0; i < batchRank; ++i)

@@ -340,4 +342,4 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,

}

#endif
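tensorDot reduces an arbitrary contraction to a single 2D mmul: permute the contracted axes into place, reshape both operands to matrices, multiply, and reshape back. A sketch of the reshape-to-matmul step for the common case of contracting the last axis of a row-major tensor (illustrative helper, not the libnd4j API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Contract the last axis of a row-major tensor A (leading axes collapsed into
// m rows of k elements) with the first axis of matrix B (k x n): exactly the
// "reshape to 2D, run one mmul, reshape back" trick used by tensorDot.
std::vector<double> tensorDotLastAxis(const std::vector<double>& a, std::size_t m, std::size_t k,
                                      const std::vector<double>& b, std::size_t n) {
    // a is (m x k) after flattening leading axes; b is (k x n); result is (m x n).
    std::vector<double> c(m * n, 0.0);
    for (std::size_t i = 0; i < m; ++i)
        for (std::size_t j = 0; j < n; ++j)
            for (std::size_t t = 0; t < k; ++t)
                c[i * n + j] += a[i * k + t] * b[t * n + j];
    return c;
}
```

A tensor of shape (2, 2, 2) contracted with a (2, 1) matrix is handled by treating it as a (4, 2) matrix; reinterpreting the flat result as (2, 2, 1) is the final reshape.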
@@ -473,19 +473,9 @@ bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool
    // FIXME: get rid of memcpy here
    memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank));
    for (int i = 0; i < minRank; ++i)
-       if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i])
+       if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0)
            tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];

-   // nullify zero axis
-   for (int e = 0; e < maxRank; e++)
-       if (maxShapeInfo[e+1] == 0)
-           tmpShapeInfo[e+1] = 0;
-
-   int delta = maxRank - minRank;
-   for (int e = minRank - 1; e >= 0; e--)
-       if (minShapeInfo[e + 1] == 0)
-           tmpShapeInfo[e + 1 + delta] = 0;
-
    ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));

    if (shape::isEmpty(max) || shape::isEmpty(min)) {
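The tightened condition above makes a zero (empty) extent win the broadcast, so emptiness propagates through shape evaluation instead of needing the separate nullify loops. A small sketch of that rule on plain shape vectors (assumes right-aligned broadcasting; illustrative only, not the libnd4j API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Right-aligned broadcast of two shapes where a 0 (empty) extent always wins,
// mirroring the tightened condition in the hunk above.
std::vector<long> broadcastShape(const std::vector<long>& maxS, const std::vector<long>& minS) {
    std::vector<long> out(maxS);                      // start from the higher-rank shape
    const std::size_t delta = maxS.size() - minS.size();
    for (std::size_t i = 0; i < minS.size(); ++i) {
        long& m = out[delta + i];
        const long n = minS[i];
        if ((m != 0 && m < n) || n == 0)              // larger extent wins, except 0 always wins
            m = n;
    }
    return out;
}
```

The third assertion below is the empty-array case: broadcasting anything against a 0-length axis yields an empty result.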
@@ -40,7 +40,7 @@ namespace nd4j {
#ifdef __CUDACC__
        __host__
#endif
-       void Logger::printv(const char *format, std::vector<int>& vec) {
+       void Logger::printv(const char *format, const std::vector<int>& vec) {
            printf("%s: {", format);
            for(int e = 0; e < vec.size(); e++) {
                auto v = vec[e];

@@ -55,7 +55,7 @@ namespace nd4j {
#ifdef __CUDACC__
        __host__
#endif
-       void Logger::printv(const char *format, std::vector<Nd4jLong>& vec) {
+       void Logger::printv(const char *format, const std::vector<Nd4jLong>& vec) {
            printf("%s: {", format);
            for(int e = 0; e < vec.size(); e++) {
                auto v = vec[e];

@@ -55,8 +55,8 @@ namespace nd4j {

        static void _CUDA_H info(const char *format, ...);

-       static void _CUDA_H printv(const char *format, std::vector<int>& vec);
-       static void _CUDA_H printv(const char *format, std::vector<Nd4jLong>& vec);
+       static void _CUDA_H printv(const char *format, const std::vector<int>& vec);
+       static void _CUDA_H printv(const char *format, const std::vector<Nd4jLong>& vec);
    };

}
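The change is purely const-correctness: printv never mutates the vector, so callers can now pass temporaries and const references. A testable sketch of the same "{a, b, c}" formatting that returns a string instead of printing (hypothetical `formatv` variant, not the actual Logger API):

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Const-correct formatting of a labelled vector in Logger::printv's
// "label: {a, b, c}" style; returns the string so it can be checked.
std::string formatv(const char* label, const std::vector<long long>& vec) {
    std::string out = std::string(label) + ": {";
    for (std::size_t e = 0; e < vec.size(); ++e) {
        out += std::to_string(vec[e]);
        if (e + 1 < vec.size())
            out += ", ";
    }
    out += "}";
    return out;
}
```

Because the parameter is a const reference, `formatv("shape", {2, 3, 4})` with a braced temporary compiles; the old non-const signature would have rejected it.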
@@ -1023,23 +1023,6 @@ namespace shape {
     */
    ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);

-   /**
-    * insert dimension at shape[axis] position
-    * 1) for example: for given rank = 3, shape = {2,4,5}, axis = 1, dimension = 10 result is -> shape = {2,10,4,5}
-    * 2) for example: for given rank = 3, shape = {2,4,5}, axis = 3, dimension = 10 result is -> shape = {2,4,5,10}
-    * so be careful and provide shape buffer with enough (at least rank+1) length
-    * axis should be within [0, rank] range
-    */
-   ND4J_EXPORT _CUDA_HD void insertDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis, const Nd4jLong dimension);
-
-   /**
-    * erase dimension at shape[axis] position
-    * 1) for example: for given rank = 3, shape = {2,4,5}, axis = 1, result is -> shape = {2,5}
-    * 2) for example: for given rank = 3, shape = {2,4,5}, axis = 2, result is -> shape = {2,4}
-    * axis should be within [0, rank-1] range
-    */
-   ND4J_EXPORT _CUDA_HD void eraseDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis);

@@ -4932,21 +4915,6 @@ INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffs
    }
}

-//////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD void insertDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis, const Nd4jLong dimension) {
-
-    for (int i = rank; i > axis; --i)
-        shape[i] = shape[i - 1];
-
-    shape[axis] = dimension;
-}
-
-//////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD void eraseDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis) {
-
-    for (int i = axis; i < rank - 1; ++i)
-        shape[i] = shape[i + 1];
-}

}
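The insertDimension/eraseDimension helpers just shift the tail of the shape buffer; the examples in their doc comments pin the semantics down. A vector-based sketch reproducing those examples (illustrative, not the shape.h API):

```cpp
#include <cassert>
#include <vector>

// Shift-based insert/erase of one extent in a shape buffer, following the doc
// comments: insert accepts axis in [0, rank], erase accepts axis in [0, rank-1].
void insertDimension(int rank, std::vector<long>& shape, int axis, long dimension) {
    shape.resize(rank + 1);             // the buffer must hold rank+1 extents
    for (int i = rank; i > axis; --i)   // shift tail one slot to the right
        shape[i] = shape[i - 1];
    shape[axis] = dimension;
}

void eraseDimension(int rank, std::vector<long>& shape, int axis) {
    for (int i = axis; i < rank - 1; ++i)   // shift tail one slot to the left
        shape[i] = shape[i + 1];
    shape.resize(rank - 1);
}
```

The assertions below are exactly the doc-comment examples: {2,4,5} with axis 1 / dimension 10 becomes {2,10,4,5}, with axis 3 becomes {2,4,5,10}, and erasing axis 1 yields {2,5}.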
@@ -244,8 +244,9 @@ namespace functions {
            auto xi = x + threadOffset;
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

-           for (Nd4jLong i = 0; i < ulen; i++)
+           for (Nd4jLong i = 0; i < ulen; i++) {
                local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
+           }

            PRAGMA_OMP_CRITICAL
            startingVal = OpType::update(startingVal, local, extraParams);

@@ -122,7 +122,7 @@ namespace functions {

        tadLength = shape::length(tadOnlyShapeInfo);
        tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
-       numTads = shape::length(xShapeInfo) / tadLength;
+       numTads = shape::length(yShapeInfo) / tadLength;
        xEWS = shape::elementWiseStride(xShapeInfo);
        zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
    }
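The braces added above wrap the per-thread accumulation loop of a two-phase reduction: each worker folds its slice into a local accumulator, then the locals are merged under PRAGMA_OMP_CRITICAL. A sequential mock of that pattern (`numWorkers` stands in for the OMP threads; illustrative only):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Two-phase reduction: per-worker local accumulation over a contiguous slice,
// followed by a combine step (the "critical section" in the OpenMP version).
double chunkedSum(const std::vector<double>& x, std::size_t numWorkers) {
    const std::size_t chunk = (x.size() + numWorkers - 1) / numWorkers;
    double startingVal = 0.0;
    for (std::size_t t = 0; t < numWorkers; ++t) {
        double local = 0.0;                               // per-worker accumulator
        const std::size_t begin = t * chunk;
        const std::size_t end = (begin + chunk < x.size()) ? begin + chunk : x.size();
        for (std::size_t i = begin; i < end; ++i)
            local += x[i];
        startingVal += local;                             // the combine step
    }
    return startingVal;
}
```

The result is independent of the worker count, which is the property the critical section preserves in the threaded version.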
@@ -21,12 +21,165 @@

#include <ops/specials_cuda.h>

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
    auto x = static_cast<X*>(vx);
    auto y = static_cast<Y*>(vy);

    int tid = threadIdx.x + blockDim.x * blockIdx.x;
    int half = window>>1;

    __shared__ Nd4jLong xLength;
    if (threadIdx.x == 0) {
        xLength = shape::length(xShapeInfo);
    }
    __syncthreads();

    //for (int i = 0; i < length; i+= window)
    /*
        if window == 4;
        iterations will be: 0; 4; 8; 12; 16; 20
        if gridDim = 3;
        on first iteration we'll have: 0; 4; 8;
        on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
    */
    int firstPosition;
    int firstStep;
    int secondPosition;
    int secondStep;

    int WARP_SIZE = 32;
    int numWarps = (gridDim.x * blockDim.x) / 32;
    int warpId = tid / WARP_SIZE;
    int warpIdx = tid % WARP_SIZE;

    if (half >= 128) {
        firstPosition = blockIdx.x * window;
        firstStep = gridDim.x * window;

        secondPosition = threadIdx.x;
        secondStep = blockDim.x;
    } else if (half >= 32) {
        firstPosition = warpId * window;
        firstStep = numWarps * window;

        secondPosition = warpIdx;
        secondStep = WARP_SIZE;
    } else {
        firstPosition = tid * window;
        firstStep = blockDim.x * gridDim.x * window;

        secondPosition = 0;
        secondStep = 1;
    }

    for (int i = firstPosition; i < length; i += firstStep) {
        for (int j = secondPosition; j < half; j += secondStep) {
            int it = (reverse) ? i + j + half : i + window - j - 1;
            int ij = i+j;
            if (it < length && ij < length ) {
                int posIT = shape::getIndexOffset(it, yShapeInfo, xLength);
                int posIJ = shape::getIndexOffset(ij, yShapeInfo, xLength);

                Y v0 = y[posIJ];
                Y v1 = y[posIT];

                if(!descending == (v0 > v1)) {
                    y[posIJ] = v1;
                    y[posIT] = v0;

                    X xtemp = x[posIJ];
                    x[posIJ] = x[posIT];
                    x[posIT] = xtemp;
                }
            }
        }
    }
}

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
    auto x = static_cast<X*>(vx);
    auto y = static_cast<Y*>(vy);

    int tid = threadIdx.x + blockDim.x * blockIdx.x;
    int half = window>>1;

    __shared__ Nd4jLong xLength;
    if (threadIdx.x == 0) {
        xLength = shape::length(xShapeInfo);
    }
    __syncthreads();

    //for (int i = 0; i < length; i+= window)
    /*
        if window == 4;
        iterations will be: 0; 4; 8; 12; 16; 20
        if gridDim = 3;
        on first iteration we'll have: 0; 4; 8;
        on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
    */
    int firstPosition;
    int firstStep;
    int secondPosition;
    int secondStep;

    int WARP_SIZE = 32;
    int numWarps = (gridDim.x * blockDim.x) / 32;
    int warpId = tid / WARP_SIZE;
    int warpIdx = tid % WARP_SIZE;

    if (half >= 128) {
        firstPosition = blockIdx.x * window;
        firstStep = gridDim.x * window;

        secondPosition = threadIdx.x;
        secondStep = blockDim.x;
    } else if (half >= 32) {
        firstPosition = warpId * window;
        firstStep = numWarps * window;

        secondPosition = warpIdx;
        secondStep = WARP_SIZE;
    } else {
        firstPosition = tid * window;
        firstStep = blockDim.x * gridDim.x * window;

        secondPosition = 0;
        secondStep = 1;
    }

    for (int i = firstPosition; i < length; i += firstStep) {
        for (int j = secondPosition; j < half; j += secondStep) {
            int it = (reverse) ? i + j + half : i + window - j - 1;
            int ij = i+j;
            if (it < length && ij < length ) {
                int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
                int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);

                X v0 = x[posIJ];
                X v1 = x[posIT];

                if(!descending == (v0 > v1)) {
                    x[posIJ] = v1;
                    x[posIT] = v0;

                    Y ytemp = y[posIJ];
                    y[posIJ] = y[posIT];
                    y[posIT] = ytemp;
                }
            }
        }
    }
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
-__device__
-void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
+__global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
    auto x = static_cast<T*>(vx);

    int tid = threadIdx.x + blockDim.x * blockIdx.x;

@@ -85,8 +238,8 @@ void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int
            int it = (reverse) ? i + j + half : i + window - j - 1;
            int ij = i+j;
            if (it < length && ij < length ) {
-               int posIT = getDevicePosition(xShapeInfo,it, xLength);
-               int posIJ = getDevicePosition(xShapeInfo, ij, xLength);
+               int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
+               int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);

                shmem[threadIdx.x] = x[posIJ];
                shmem[threadIdx.x + blockDim.x] = x[posIT];

@@ -100,18 +253,22 @@ void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int
    }
}

-//////////////////////////////////////////////////////////////////////////
-template<typename T>
-__global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
-
-    bitonicArbitraryStepKernel<T>(vx, xShapeInfo, window, length, reverse, descending);
-}
-
//////////////////////////////////////////////////////////////////////////
template<typename T>
__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {

    execBitonicArbitraryStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, window, length, reverse, descending);
    nd4j::DebugHelper::checkErrorCode(stream, "bitonicArbitrary(...) failed");
}

template <typename X, typename Y>
__host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
    bitonicArbitraryStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}

template <typename X, typename Y>
__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
    bitonicArbitraryStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}

BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
@@ -21,9 +21,119 @@

#include <ops/specials_cuda.h>

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicSortStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {

    auto x = static_cast<X*>(vx);
    auto y = static_cast<Y*>(vy);

    unsigned int i, ixj; /* Sorting partners: i and ixj */
    i = threadIdx.x + blockDim.x * blockIdx.x;

    __shared__ Nd4jLong xLength;
    if (threadIdx.x == 0)
        xLength = shape::length(xShapeInfo);

    __syncthreads();

    if (i >= length)
        return;

    ixj = i^j;

    /* The threads with the lowest ids sort the array. */
    if ((ixj)>i) {
        int posI = shape::getIndexOffset(i, yShapeInfo, xLength);
        int posIXJ = shape::getIndexOffset(ixj, yShapeInfo, xLength);

        if ((i&k)==0) {
            /* Sort ascending */
            if (!descending == (y[posI]>y[posIXJ])) {
                /* exchange(i,ixj); */
                X temp = x[posI];
                x[posI] = x[posIXJ];
                x[posIXJ] = temp;

                Y ytemp = y[posI];
                y[posI] = y[posIXJ];
                y[posIXJ] = ytemp;
            }
        } else if ((i&k)!=0) {
            /* Sort descending */
            if (!descending == (y[posI]<y[posIXJ])) {
                /* exchange(i,ixj); */
                X temp = x[posI];
                x[posI] = x[posIXJ];
                x[posIXJ] = temp;

                Y ytemp = y[posI];
                y[posI] = y[posIXJ];
                y[posIXJ] = ytemp;
            }
        }
    }
}

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicSortStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {

    auto x = static_cast<X*>(vx);
    auto y = static_cast<Y*>(vy);

    unsigned int i, ixj; /* Sorting partners: i and ixj */
    i = threadIdx.x + blockDim.x * blockIdx.x;

    __shared__ Nd4jLong xLength;
    if (threadIdx.x == 0)
        xLength = shape::length(xShapeInfo);

    __syncthreads();

    if (i >= length)
        return;

    ixj = i^j;

    /* The threads with the lowest ids sort the array. */
    if ((ixj)>i) {
        int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
        int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);

        if ((i&k)==0) {
            /* Sort ascending */
            if (!descending == (x[posI]>x[posIXJ])) {
                /* exchange(i,ixj); */
                X temp = x[posI];
                x[posI] = x[posIXJ];
                x[posIXJ] = temp;

                Y ytemp = y[posI];
                y[posI] = y[posIXJ];
                y[posIXJ] = ytemp;
            }
        } else if ((i&k)!=0) {
            /* Sort descending */
            if (!descending == (x[posI]<x[posIXJ])) {
                /* exchange(i,ixj); */
                X temp = x[posI];
                x[posI] = x[posIXJ];
                x[posIXJ] = temp;

                Y ytemp = y[posI];
                y[posI] = y[posIXJ];
                y[posIXJ] = ytemp;
            }
        }
    }
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
__device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
__global__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {

    auto x = static_cast<T*>(vx);

@@ -44,8 +154,8 @@ __device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int

    /* The threads with the lowest ids sort the array. */
    if ((ixj)>i) {
        int posI = getDevicePosition(xShapeInfo, i, xLength);
        int posIXJ = getDevicePosition(xShapeInfo, ixj, xLength);
        int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
        int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);

        if ((i&k)==0) {
            /* Sort ascending */

@@ -69,16 +179,23 @@ __device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int

//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ void execBitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {

    bitonicSortStepKernel<T>(vx, xShapeInfo, j, k, length, descending);
__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
    bitonicSortStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, j, k, length, descending);
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {

    execBitonicSortStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, j, k, length, descending);
    nd4j::DebugHelper::checkErrorCode(stream, "bitonicSortStep(...) failed");
template <typename X, typename Y>
__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
    bitonicSortStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
}

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
    bitonicSortStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
}

BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);

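The bitonic step kernels above each perform one compare-exchange pass of a bitonic sorting network: thread `i` pairs with `i^j`, and bit `k` of `i` picks the direction of its sub-sequence. A minimal CPU sketch of the same pass may make the `!descending == (...)` predicate easier to follow; the function names `bitonicStepHost`/`bitonicSortHost` are illustrative, not libnd4j symbols.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// CPU reference for the compare-exchange pass the kernels above run on the
// GPU: index i pairs with i^j, and bit k of i selects the direction of its
// sub-sequence. Illustrative only, not part of libnd4j.
template <typename T>
void bitonicStepHost(std::vector<T>& x, int j, int k, bool descending) {
    const int length = static_cast<int>(x.size());
    for (int i = 0; i < length; i++) {
        const int ixj = i ^ j;               // compare-exchange partner
        if (ixj > i && ixj < length) {
            if ((i & k) == 0) {              // ascending sub-sequence
                if (!descending == (x[i] > x[ixj]))
                    std::swap(x[i], x[ixj]);
            } else {                         // descending sub-sequence
                if (!descending == (x[i] < x[ixj]))
                    std::swap(x[i], x[ixj]);
            }
        }
    }
}

// A full power-of-two bitonic sort chains these passes, which is what the
// host wrappers above launch step by step.
template <typename T>
void bitonicSortHost(std::vector<T>& x, bool descending) {
    const int n = static_cast<int>(x.size());
    for (int k = 2; k <= n; k <<= 1)
        for (int j = k >> 1; j > 0; j >>= 1)
            bitonicStepHost(x, j, k, descending);
}
```

The key/value kernel variants apply exactly this exchange to two arrays at once, comparing one and swapping both in lockstep.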
@@ -16,18 +16,89 @@

//
// @author raver119@gmail.com
// @author Yurii Shyrma, created on 28.11.2018
//

#include <ops/specials_cuda.h>

//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void execOesTadKernelKey(void *vx, Nd4jLong *xShapeInfo,
                                    void *vy, Nd4jLong *yShapeInfo,
                                    int *dimension, int dimensionLength,
                                    Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                    bool descending) {

    auto x = static_cast<X*>(vx);
    auto y = static_cast<Y*>(vy);

    __shared__ int xLength;
    __shared__ int xTadLength;
    __shared__ int numTads;
    if (threadIdx.x == 0) {
        xLength = shape::length(xShapeInfo);
        xTadLength = shape::length(tadShapeInfo);
        numTads = xLength / xTadLength;
    }
    __syncthreads();

    for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
        auto dx = x + tadOffsets[r];
        auto dy = y + tadOffsets[r];

        // this is general loop, we go uncached
        int iterations = xTadLength;

        for (int i = 0; i < iterations; i++) {

            if (i % 2 == 0) {
                for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                    auto top = 2 * tid + 1;
                    if (top < xTadLength) {
                        auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
                        auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);

                        if (!descending == (dx[t0] > dx[t1])) {
                            X dt0 = dx[t0];
                            dx[t0] = dx[t1];
                            dx[t1] = dt0;

                            Y dy0 = dy[t0];
                            dy[t0] = dy[t1];
                            dy[t1] = dy0;
                        }
                    }
                }
            } else {
                for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                    auto top = 2 * tid + 2;
                    if (top < xTadLength) {
                        auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
                        auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);

                        if (!descending == (dx[t0] > dx[t1])) {
                            X dt0 = dx[t0];
                            dx[t0] = dx[t1];
                            dx[t1] = dt0;

                            Y dy0 = dy[t0];
                            dy[t0] = dy[t1];
                            dy[t1] = dy0;
                        }
                    }
                }
            }
            __syncthreads();
        }
    }
}


//////////////////////////////////////////////////////////////////////////
template<typename T>
__device__
void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
                  int *dimension, int dimensionLength,
                  Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                  bool descending) {
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
                                 int *dimension, int dimensionLength,
                                 Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                 bool descending) {

    auto x = static_cast<T*>(vx);
    const int sharedSize = 32768;

@@ -56,7 +127,7 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
        int iterations = xTadLength;
        if (cached) {
            for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                auto t0 = getDevicePosition(tadShapeInfo, tid, xTadLength);
                auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
                shmem[tid] = dx[t0];
            }

@@ -70,8 +141,8 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
            for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                auto top = 2 * tid + 1;
                if (top < xTadLength) {
                    auto t0 = cached ? top - 1 : getDevicePosition(tadShapeInfo, top - 1, xTadLength);
                    auto t1 = cached ? top : getDevicePosition(tadShapeInfo, top, xTadLength);
                    auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
                    auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);

                    if (!descending == (dx[t0] > dx[t1])) {
                        T dt0 = dx[t0];

@@ -84,8 +155,8 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
            for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                auto top = 2 * tid + 2;
                if (top < xTadLength) {
                    auto t0 = cached ? top - 1 : getDevicePosition(tadShapeInfo, top - 1, xTadLength);
                    auto t1 = cached ? top : getDevicePosition(tadShapeInfo, top, xTadLength);
                    auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
                    auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);

                    if (!descending == (dx[t0] > dx[t1])) {
                        T dt0 = dx[t0];

@@ -102,32 +173,34 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
            if (cached) {
                dx = x + tadOffsets[r];
                for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                    auto t0 = getDevicePosition(tadShapeInfo, tid, xTadLength);
                    auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
                    dx[t0] = shmem[tid];
                }
            }
        }
    }

//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
                                 int *dimension, int dimensionLength,
                                 Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                 bool descending) {

    oesTadKernel<T>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream,
                            void *vx, Nd4jLong *xShapeInfo,
                            int *dimension, int dimensionLength,
                            Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                            bool descending) {

    execOesTadKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
    nd4j::DebugHelper::checkErrorCode(stream, "oesTad(...) failed");
}

template <typename X, typename Y>
__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream,
                               void *vx, Nd4jLong *xShapeInfo,
                               void *vy, Nd4jLong *yShapeInfo,
                               int *dimension, int dimensionLength,
                               Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                               bool descending) {

    execOesTadKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
}

BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT oesTadGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);

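The TAD kernels above run an odd-even transposition sort ("OES") independently on every tensor-along-dimension: passes alternate between even pairs (0,1), (2,3), ... and odd pairs (1,2), (3,4), ... A CPU sketch of the per-TAD logic, with the illustrative name `oesSortHost` (not a libnd4j symbol):

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

// CPU reference for the odd-even transposition sort applied per TAD above:
// n alternating even/odd compare-exchange phases suffice to sort n elements.
template <typename T>
void oesSortHost(std::vector<T>& x, bool descending) {
    const int n = static_cast<int>(x.size());
    for (int i = 0; i < n; i++) {
        // i % 2 == 0 matches the kernel's "top = 2 * tid + 1" (even) phase,
        // otherwise its "top = 2 * tid + 2" (odd) phase
        for (int top = (i % 2 == 0) ? 1 : 2; top < n; top += 2)
            if (!descending == (x[top - 1] > x[top]))
                std::swap(x[top - 1], x[top]);
    }
}
```

On the GPU each pass is parallel (one thread per pair) with a `__syncthreads()` between passes, which is why the kernel structure mirrors these two phases exactly.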
@@ -37,7 +37,7 @@ CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) {
    auto output = OUTPUT_VARIABLE(0);

    std::vector<int> sharedAxes = *block.getIArguments();

    const int inputRank = input->rankOf();
    const int alphaRank = alpha->rankOf();
    const int numSharedAxes = sharedAxes.size(); // can be zero as well

@@ -49,12 +49,12 @@
    //***** input validation *****//
    std::vector<Nd4jLong> expectedAlphaShape(&inputShape[1], &inputShape[inputRank]);

    REQUIRE_TRUE(inputRank > 1, 0, "PRELU OP: wrong rank of input array, expected rank should be > 1, but got %i instead !", inputRank);

    for(int i = 0; i < numSharedAxes; ++i) {
        if(sharedAxes[i] <= 0)
            sharedAxes[i] += inputRank - 1;
        REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, "PRELU OP: wrong axis value %i in sharedAxes at position %i, axis value must be within range [1, input_rank-1] !", sharedAxes[i], i);
        expectedAlphaShape[sharedAxes[i] - 1] = 1;
    }

@@ -65,14 +65,8 @@
    REQUIRE_TRUE(product == alphaLen, 0, "PRELU OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
    // ***** end of validation ***** //

    if(alphaShape != expectedAlphaShape)
        alpha = alpha->reshape(alpha->ordering(), expectedAlphaShape);
    helpers::prelu(block.launchContext(), *input, alphaShape != expectedAlphaShape ? alpha->reshape(alpha->ordering(), expectedAlphaShape) : *alpha, *output);

    helpers::prelu(block.launchContext(), *input, *alpha, *output);

    if(alphaShape != expectedAlphaShape)
        delete alpha;

    return Status::OK();
}

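The validation loop above builds `expectedAlphaShape` by dropping the batch axis of the input and collapsing each shared axis to 1. A standalone sketch of that derivation (the helper name `expectedAlphaShape` and its signature are illustrative, not the libnd4j code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of how the prelu op derives the expected alpha shape: alpha must
// match the input shape with the leading batch axis dropped and every shared
// axis collapsed to 1, so it broadcasts over those axes.
std::vector<int64_t> expectedAlphaShape(const std::vector<int64_t>& inputShape,
                                        std::vector<int> sharedAxes) {
    const int inputRank = static_cast<int>(inputShape.size());
    // start from the input shape without the batch dimension
    std::vector<int64_t> expected(inputShape.begin() + 1, inputShape.end());
    for (int axis : sharedAxes) {
        if (axis <= 0)
            axis += inputRank - 1;   // same normalization the op performs
        expected[axis - 1] = 1;      // a shared axis is broadcast over
    }
    return expected;
}
```

For an input of shape `[2, 3, 4, 5]` and `sharedAxes = {2}`, alpha is expected to have shape `[3, 1, 5]`.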
@@ -90,12 +84,12 @@ CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) {
    auto input = INPUT_VARIABLE(0);
    auto alpha = INPUT_VARIABLE(1);
    auto dLdO = INPUT_VARIABLE(2);

    auto dLdI = OUTPUT_VARIABLE(0);
    auto dLdA = OUTPUT_VARIABLE(1);

    std::vector<int> sharedAxes = *block.getIArguments();

    const int inputRank = input->rankOf();
    const int alphaRank = alpha->rankOf();
    const int numSharedAxes = sharedAxes.size(); // can be zero as well

@@ -105,19 +99,19 @@
    const std::vector<Nd4jLong> alphaShape = alpha->getShapeAsVector();

    //***** input validation *****//

    // temporary limitation imposed by Yurii
    REQUIRE_TRUE(inputRank <= MAX_RANK/2, 0, "rank of input array should be <= MAX_RANK/2, but got %i instead!", inputRank);
    REQUIRE_TRUE(input->lengthOf() / alpha->lengthOf() <= MAX_RANK*2, 0, "the length of input array should be no more than MAX_RANK*2 times the alpha array length, but got %lld and %lld correspondingly!", input->lengthOf(), alpha->lengthOf());

    std::vector<Nd4jLong> expectedAlphaShape(&inputShape[1], &inputShape[inputRank]);

    REQUIRE_TRUE(inputRank > 1, 0, "PRELU_BP OP: wrong rank of input array, expected rank should be > 1, but got %i instead !", inputRank);

    for(int i = 0; i < numSharedAxes; ++i) {
        if(sharedAxes[i] <= 0)
            sharedAxes[i] += inputRank - 1;
        REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, "PRELU_BP OP: wrong axis value %i in sharedAxes at position %i, axis value must be within range [1, input_rank-1] !", sharedAxes[i], i);
        expectedAlphaShape[sharedAxes[i] - 1] = 1;
    }

@@ -127,19 +121,20 @@
    REQUIRE_TRUE(product == alphaLen, 0, "PRELU_BP OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
    // ***** end of validation ***** //

    if(alphaShape != expectedAlphaShape) {
        alpha = alpha->reshape(alpha->ordering(), expectedAlphaShape);
        dLdA = dLdA->reshape(dLdA->ordering(), expectedAlphaShape);
        alpha = new NDArray(alpha->reshape(alpha->ordering(), expectedAlphaShape));
        dLdA = new NDArray(dLdA->reshape(dLdA->ordering(), expectedAlphaShape));
    }

    helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA);

    if(alphaShape != expectedAlphaShape) {
        delete alpha;
        delete dLdA;
    }

    return Status::OK();
}

@@ -29,7 +29,6 @@ namespace nd4j {
        auto x = INPUT_VARIABLE(0);
        auto y = INPUT_VARIABLE(1);

        nd4j_printf("Comparing [%f] to [%f]\n", x->e<float>(0), y->e<float>(0));
        if (x->e<float>(0) < y->e<float>(0))
            return ND4J_STATUS_TRUE;
        else

@@ -31,7 +31,7 @@ namespace nd4j {
        auto condition = INPUT_VARIABLE(0);
        auto z = OUTPUT_VARIABLE(0);
        if (z->isEmpty())
            return ND4J_STATUS_OK;
            return Status::OK();

        if (block.width() == 3) {
            auto x = INPUT_VARIABLE(1);

@@ -44,12 +44,10 @@
            // FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y
            for (int e = 0; e < condition->lengthOf(); e++) {
                if (y->isR()) {
                    auto r = !condition->e<bool>(e) ? y->e<double>(e)
                                                    : x->e<double>(e);
                    auto r = !condition->e<bool>(e) ? y->e<double>(e) : x->e<double>(e);
                    z->p(e, r);
                } else {
                    auto r = !condition->e<bool>(e) ? y->e<Nd4jLong>(e)
                                                    : x->e<Nd4jLong>(e);
                    auto r = !condition->e<bool>(e) ? y->e<Nd4jLong>(e) : x->e<Nd4jLong>(e);
                    z->p(e, r);
                }
            }

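The loop above implements an element-wise select when the Where op gets three inputs. A minimal sketch of that semantics (the name `whereSelect` is illustrative, not a libnd4j symbol):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the element-wise select the three-input Where form performs:
// z[e] takes x[e] where the condition holds and y[e] otherwise.
template <typename T>
std::vector<T> whereSelect(const std::vector<bool>& condition,
                           const std::vector<T>& x,
                           const std::vector<T>& y) {
    std::vector<T> z(condition.size());
    for (std::size_t e = 0; e < condition.size(); e++)
        z[e] = !condition[e] ? y[e] : x[e];   // mirrors the loop in the op
    return z;
}
```

The real op additionally branches on the data type (`y->isR()`) so that floating-point and integer inputs round-trip through the matching element getter.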
@@ -86,7 +84,7 @@

                helpers::_where(block.launchContext(), *condition, *output, block.workspace());
            }
            return ND4J_STATUS_OK;
            return Status::OK();
        }

        DECLARE_SHAPE_FN(Where) {

@@ -120,7 +120,7 @@
            }
        }

        return ND4J_STATUS_OK;
        return Status::OK();
    }

    DECLARE_SHAPE_FN(where_np) {

@@ -81,11 +81,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
    auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
    auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]

    ConvolutionUtils::conv2d(block, inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);

    delete inputReshaped;
    delete outputReshaped;
    delete weightsReshaped;
    ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);

    return Status::OK();
}

@@ -217,13 +213,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
    auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
    auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]

    ConvolutionUtils::conv2dBP(block, inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);

    delete inputReshaped;
    delete gradIReshaped;
    delete gradOReshaped;
    delete weightsReshaped;
    delete gradWReshaped;
    ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);

    return Status::OK();
}

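The conv1d hunks above rely on a reshape trick: a 1D convolution is executed as a 2D convolution whose kernel height is 1. The weights mapping `[kW, iC, oC] -> [1, kW, iC, oC]` is stated in the op's comments; the NWC input mapping below and the helper itself are illustrative assumptions, not libnd4j code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the conv1d-as-conv2d shape mapping: in an NWC layout the input
// [bS, iW, iC] is viewed as [bS, 1, iW, iC] and the weights [kW, iC, oC]
// as [1, kW, iC, oC], so a height-1 conv2d reproduces the 1D convolution.
struct Conv1dAs2dShapes {
    std::vector<int64_t> input;    // [bS, 1, iW, iC]
    std::vector<int64_t> weights;  // [1, kW, iC, oC]
};

Conv1dAs2dShapes conv1dAsConv2dShapes(int64_t bS, int64_t iW, int64_t iC,
                                      int64_t kW, int64_t oC) {
    return { {bS, 1, iW, iC}, {1, kW, iC, oC} };
}
```

Because the change above passes the reshaped views by address instead of heap pointers, the `delete` calls for the temporaries disappear from both the forward and backward ops.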
@@ -34,7 +34,7 @@ using namespace mkldnn;
#endif

CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {

    auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

@@ -42,7 +42,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {

    REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());

    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width

@@ -151,10 +151,10 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {

    std::vector<int> permutForOutput;

    if(!isNCDHW)
        input = input->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
    else
    if (isNCDHW)
        permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
    else
        input = new NDArray(input->permute({0,4,1,2,3}));

    NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
    ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]

@@ -164,9 +164,9 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
    if(bias)
        output->applyBroadcast(broadcast::Add, {indIOioC}, bias);

    if(!isNCDHW)
        delete input;

    return Status::OK();
}

@@ -202,36 +202,36 @@ DECLARE_SHAPE_FN(conv3dnew) {
    const int rank = 5;
    REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo);
    REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);

    int indIOioC, indIiD, indWoC(4);
    if(!isNCDHW) {
        indIOioC = 4; indIiD = 1;
    }
    else {
        indIOioC = 1; indIiD = 2;
    }

    int bS = inputShapeInfo[1]; // batch size
    int iD = inputShapeInfo[indIiD+1]; // input depth
    int iH = inputShapeInfo[indIiD+2]; // input height
    int iW = inputShapeInfo[indIiD+3]; // input width
    int iC = inputShapeInfo[indIOioC+1]; // input channels
    int oC = weightsShapeInfo[indWoC+1]; // output channels

    std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
    if (biasShapeInfo)
        REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

    int oD, oH, oW; // output depth, height, width
    ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);

    Nd4jLong* outputShapeInfo = nullptr;
    ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong);

    outputShapeInfo[0] = rank;
    outputShapeInfo[1] = bS;
    if (isNCDHW) {
        outputShapeInfo[2] = oC;
        outputShapeInfo[3] = oD;
        outputShapeInfo[4] = oH;

@@ -242,7 +242,7 @@ DECLARE_SHAPE_FN(conv3dnew) {
        outputShapeInfo[4] = oW;
        outputShapeInfo[5] = oC;
    }

    ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo));

    return SHAPELIST(CONSTANT(outputShapeInfo));

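The shape function above computes output spatial sizes via `ConvolutionUtils::calcOutSizePool3D`. A per-dimension sketch of the usual arithmetic behind that call; the SAME branch follows the common TF-style convention, and the exact rounding libnd4j uses lives in `ConvolutionUtils`, so treat this helper (`convOutSize`) as an illustrative assumption.

```cpp
#include <cassert>

// Per-dimension convolution output size: with input size i, kernel k,
// stride s, padding p and dilation d, the effective (dilated) kernel spans
// (k - 1) * d + 1 elements.
int convOutSize(int i, int k, int s, int p, int d, bool isSameMode) {
    if (isSameMode)
        return (i + s - 1) / s;          // ceil(i / s); padding chosen to fit
    const int ek = (k - 1) * d + 1;      // effective (dilated) kernel size
    return (i + 2 * p - ek) / s + 1;
}
```

Applied once each for depth, height, and width, this yields the `oD, oH, oW` used to fill `outputShapeInfo`.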
@@ -251,12 +251,12 @@

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {

    auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
    auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]

@@ -291,12 +291,12 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
    std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if(bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    if(isSameMode) // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

#ifdef HAVE_MKLDNN
    if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB})) {
        std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();

@@ -447,35 +447,37 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
     std::vector<int> gradOaxesForDot;
 
     if(!isNDHWC) {
-        input = input->permute({0,4,1,2,3});                                // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        gradI = gradI->permute({0,4,1,2,3});                                // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
         gradOaxesForDot = {0,1,2,3};                                        // bS, oD, oH, oW
+        input = new NDArray(input->permute({0,4,1,2,3}));                   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+        gradI = new NDArray(gradI->permute({0,4,1,2,3}));                   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
     }
-    else
+    else {
         gradOaxesForDot = {0,2,3,4};                                        // bS, oD, oH, oW
+    }
 
     // ----- calculation of gradW and gradB ----- //
     NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
     ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);      // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
     MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4});     // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
 
-    //----- calculation of gradO -----//
     if(gradB) {
         if(gradB->rankOf() == 2)
-            gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()});
+            gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
         gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot);   // sum over bS oD oH oW
         if(gradB != OUTPUT_VARIABLE(2))
             delete gradB;
     }
 
     //----- calculation of gradI -----//
     MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7});   // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
     ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW);      // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
 
     if(!isNDHWC) {
         delete input;
         delete gradI;
     }
 
     return Status::OK();
 }
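A pattern that repeats through every hunk above: calls like `x = x->permute(...)` become `x = new NDArray(x->permute(...))` followed by a `delete` in the non-canonical-layout branch, because the updated NDArray API returns the permuted view by value instead of a heap pointer. A minimal sketch of that ownership pattern, using a toy `Array` stand-in (not the real NDArray class):

```cpp
#include <cassert>
#include <vector>

// Toy model of the ownership pattern in this diff: permute() now returns
// by value, so call sites copy the result into a fresh heap object and
// delete it once done. "Array" and "processNHWC" are illustrative names.
struct Array {
    std::vector<int> shape;
    Array permute(const std::vector<int>& order) const {   // returns by value
        Array r;
        for (int axis : order)
            r.shape.push_back(shape[axis]);
        return r;
    }
};

// Mirrors the diff: reassign the local pointer to a heap copy of the
// permuted view, work on it, then delete only the replacement copy.
void processNHWC(Array* input) {
    Array* original = input;
    input = new Array(input->permute({0, 3, 1, 2}));       // NHWC -> NCHW
    assert(input->shape[1] == original->shape[3]);         // channels now at axis 1
    if (input != original)
        delete input;                                      // same cleanup as the diff
}
```

The caller's array is untouched; only the temporary NCHW copy is allocated and freed, which is exactly why the ops above pair each `new NDArray(...)` with a `delete` guarded by the layout flag.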
@@ -520,15 +522,15 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
     if(!isNDHWC) {
         indIOioC = 4; indIiD = 1;
     }
     else {
         indIOioC = 1; indIiD = 2;
     }
 
     int bS = inputShapeInfo[1];             // batch size
     int iD = inputShapeInfo[indIiD+1];      // input depth
     int iH = inputShapeInfo[indIiD+2];      // input height
     int iW = inputShapeInfo[indIiD+3];      // input width
     int iC = inputShapeInfo[indIOioC+1];    // input channels
     int oC = weightsShapeInfo[indWoC+1];    // output channels
 
     int trueoD, trueoH, trueoW;             // true output depth/height/width
@@ -538,7 +540,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
     std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
     REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
     REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
     if(biasShapeInfo)
         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
 
     auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace());
@@ -547,7 +549,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
     if(biasShapeInfo) {
         auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace());
         return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo));
     }
 
     return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo));
 }
@@ -33,7 +33,7 @@ namespace nd4j {
 namespace ops {
 
 CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
 
     auto input   = INPUT_VARIABLE(0);                               // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     auto weights = INPUT_VARIABLE(1);                               // [kH, kW, oC, iC] always
     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
@@ -64,7 +64,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
 
     if(!isNCHW)
-        output = output->permute({0, 3, 1, 2});                     // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
+        output = new NDArray(output->permute({0, 3, 1, 2}));        // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
 
     if(isSameMode)                                                  // SAME
         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@@ -77,14 +77,14 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
     nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 1, 0, 4, 5});
     LaunchContext* ctx = block.launchContext();
     helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW);    // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW]
 
     //----- add biases if required -----//
     if(bias)
         output->applyBroadcast(broadcast::Add, {1}, bias);
 
     if(!isNCHW)
         delete output;
 
     return Status::OK();
 }
 
 DECLARE_TYPES(deconv2d) {
@@ -135,7 +135,7 @@ DECLARE_SHAPE_FN(deconv2d) {
 
     int oH, oW;                                         // output height, width
     ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
 
     Nd4jLong outputShape[4];
 
     outputShape[0] = bS;
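The shape function above derives the output spatial extent through `calcOutSizeDeconv2D`. The arithmetic behind it is the standard transposed-convolution inversion of the forward formula; a sketch under that assumption, with illustrative names (not the libnd4j signature):

```cpp
#include <cassert>

// Standard transposed-convolution (deconv) output-size arithmetic.
// SAME mode: output is simply input size times stride.
// VALID mode: invert the forward formula o = (i + 2p - ((k-1)*d + 1))/s + 1.
// calcOutSizeDeconv2D computes the same quantity per axis; names here are
// illustrative only.
static int deconvOutSize(int i, int k, int s, int p, int d, bool sameMode) {
    if (sameMode)
        return i * s;
    int kEff = (k - 1) * d + 1;          // effective kernel extent under dilation
    return s * (i - 1) + kEff - 2 * p;   // inverse of the forward VALID formula
}
```

As a sanity check, a forward VALID conv of a length-9 input with k=3, s=2 yields (9-3)/2+1 = 4, and `deconvOutSize(4, 3, 2, 0, 1, false)` recovers 9.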
@@ -211,8 +211,9 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
 
     // ----- prepare permutation arrays and axes for dot product ----- //
     std::vector<int> inputAxesForDot;
 
     if(!isNCHW) {
-        gradO = gradO->permute({0, 3, 1, 2});                   // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
+        gradO = new NDArray(gradO->permute({0, 3, 1, 2}));      // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
         inputAxesForDot = {0, 1, 2};                            // bS, iH, iW
     }
     else
@@ -228,7 +229,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
     // ----- calculation of gradB ----- //
     if(gradB) {
         if(gradB->rankOf() == 2)
-            gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()});
+            gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
         gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3});     // sum over bS, oH, oW
         if(gradB != OUTPUT_VARIABLE(2))
             delete gradB;
@@ -237,7 +238,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
     if(!isNCHW)
         delete gradO;
 
-    return ND4J_STATUS_OK;
+    return Status::OK();
 }
 
 DECLARE_SHAPE_FN(deconv2d_bp) {
@@ -27,32 +27,32 @@
 
 namespace nd4j {
 namespace ops {
 
 CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
 
     auto input   = INPUT_VARIABLE(0);                               // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
     auto weights = INPUT_VARIABLE(1);                               // [kD, kH, kW, oC, iC] always
     auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
 
     auto output  = OUTPUT_VARIABLE(0);                              // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
 
     REQUIRE_TRUE(input->rankOf()   == 5, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
     REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
 
-    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
-    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
-    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
+    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) depth
+    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) height
+    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2)); // filter(kernel) width
     int sD = INT_ARG(3);                                                         // strides depth
     int sH = INT_ARG(4);                                                         // strides height
     int sW = INT_ARG(5);                                                         // strides width
     int pD = INT_ARG(6);                                                         // paddings depth
     int pH = INT_ARG(7);                                                         // paddings height
     int pW = INT_ARG(8);                                                         // paddings width
     int dD = INT_ARG(9);                                                         // dilations depth
     int dH = INT_ARG(10);                                                        // dilations height
     int dW = INT_ARG(11);                                                        // dilations width
     int isSameMode = INT_ARG(12);                                                // 0-SAME, 1-VALID
     int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;         // INT_ARG(13): 1-NDHWC, 0-NCDHW
 
     int bS, iC, iD, iH, iW, oC, oD, oH, oW;             // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;     // corresponding indexes
@@ -64,7 +64,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
         REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
 
     if(!isNCDHW)
-        output = output->permute({0, 4, 1, 2, 3});                  // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
+        output = new NDArray(output->permute({0, 4, 1, 2, 3}));     // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
 
     if(isSameMode)                                                  // SAME
         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
@@ -76,14 +76,14 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
     // NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW]
     nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7});  // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
     ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW);                 // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
 
     //----- add biases if required -----//
     if(bias)
         output->applyBroadcast(broadcast::Add, {1}, bias);
 
     if(!isNCDHW)
         delete output;
 
     return Status::OK();
 }
@@ -123,17 +123,17 @@ DECLARE_SHAPE_FN(deconv3d) {
 
     int indIOioC, indIiD, indWoC(3);
     if(!isNCDHW) {
         indIOioC = 4; indIiD = 1;
     }
     else {
         indIOioC = 1; indIiD = 2;
     }
 
     const int bS = inputShapeInfo[1];           // batch size
     const int iD = inputShapeInfo[indIiD+1];    // input depth
     const int iH = inputShapeInfo[indIiD+2];    // input height
     const int iW = inputShapeInfo[indIiD+3];    // input width
     const int iC = inputShapeInfo[indIOioC+1];  // input channels
     const int oC = weightsShapeInfo[indWoC+1];  // output channels
 
     std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
@@ -143,7 +143,7 @@ DECLARE_SHAPE_FN(deconv3d) {
 
     int oD, oH, oW;                             // output depth, height, width
     ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
 
     Nd4jLong* outputShapeInfo = nullptr;
     ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong);
@@ -161,7 +161,7 @@ DECLARE_SHAPE_FN(deconv3d) {
         outputShapeInfo[4] = oW;
         outputShapeInfo[5] = oC;
     }
 
     ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo));
 
     return SHAPELIST(CONSTANT(outputShapeInfo));
@@ -225,8 +225,9 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
 
     // ----- prepare permutation arrays and axes for dot product ----- //
     std::vector<int> inputAxesForDot;
 
     if(!isNCDHW) {
-        gradO = gradO->permute({0, 4, 1, 2, 3});                // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
+        gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3}));   // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
         inputAxesForDot = {0, 1, 2, 3};                         // bS, iD, iH, iW
     }
     else
@@ -240,7 +241,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
     // ----- calculation of gradB ----- //
     if(gradB) {
         if(gradB->rankOf() == 2)
-            gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()});
+            gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
         gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4});      // sum over bS, oD, oH, oW
         if(gradB != OUTPUT_VARIABLE(2))
             delete gradB;
@@ -260,7 +261,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
                     ->setAllowedInputTypes(3, {ALL_FLOATS})
                     ->setAllowedOutputTypes({ALL_FLOATS});
 }
 
 DECLARE_SHAPE_FN(deconv3d_bp) {
 
     auto inputShapeInfo = inputShape->at(0);        // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
@@ -292,15 +293,15 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
     if(!isNCDHW) {
         indIOioC = 4; indIiD = 1;
     }
     else {
         indIOioC = 1; indIiD = 2;
     }
 
     const int bS = inputShapeInfo[1];           // batch size
     const int iD = inputShapeInfo[indIiD+1];    // input depth
     const int iH = inputShapeInfo[indIiD+2];    // input height
     const int iW = inputShapeInfo[indIiD+3];    // input width
     const int iC = inputShapeInfo[indIOioC+1];  // input channels
     const int oC = weightsShapeInfo[indWoC+1];  // output channels
 
     int trueoD, trueoH, trueoW;                 // true output depth, height, width
@@ -312,7 +313,7 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
     REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
     if(biasShapeInfo)
         REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
 
     auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace());
     auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace());
@@ -71,7 +71,7 @@ namespace ops {
         int pad_top = 0, pad_left = 0;
         int out_rows = 0, out_cols = 0;
 
-        helpers::_dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
+        helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
 
         REQUIRE_TRUE(out_rows > 0 && out_cols > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", out_rows, out_cols);
@@ -112,7 +112,7 @@ namespace ops {
             newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(block.dataType());
             return SHAPELIST(newShape);
         }
 
         int e = 1;
         for (int cnt = 0; cnt < 4; cnt++)
             rates[cnt] = INT_ARG(e++);
@@ -126,7 +126,7 @@ namespace ops {
         int pad_top = 0, pad_left = 0;
         int out_rows = 0, out_cols = 0;
 
-        helpers::_dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
+        helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
 
         std::array<Nd4jLong, 4> shape = {{batch_size, out_rows, out_cols, depth}};
         newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data());
@@ -59,21 +59,20 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
     const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
     const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
 
-    if (!isNCHW) {
-        input  = input->permute({0, 3, 1, 2});                  // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        output = output->permute({0, 3, 1, 2});                 // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+    if(!isNCHW) {
+        input  = new NDArray(input->permute({0, 3, 1, 2}));     // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+        output = new NDArray(output->permute({0, 3, 1, 2}));    // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
     }
 
     ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
 
     if (isSameMode)
         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
 
     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
     ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
-    //output->printBuffer("output op");
 
-    if (!isNCHW) {
+    if(!isNCHW) {
         delete input;
         delete output;
     }
@@ -92,7 +91,7 @@ DECLARE_SYN(avgpool, avgpool2d);
 }
 
 DECLARE_SHAPE_FN(avgpool2d) {
 
     auto inShape = inputShape->at(0);
     auto shapeOf = shape::shapeOf(inShape);
@@ -177,27 +176,28 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
     REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
     REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
 
     if(!isNCHW) {
-        input = input->permute({0, 3, 1, 2});                   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradI = gradI->permute({0, 3, 1, 2});                   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradO = gradO->permute({0, 3, 1, 2});                   // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+        input = new NDArray(input->permute({0, 3, 1, 2}));      // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+        gradI = new NDArray(gradI->permute({0, 3, 1, 2}));      // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+        gradO = new NDArray(gradO->permute({0, 3, 1, 2}));      // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
     }
 
     if(isSameMode)                                              // SAME
         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
 
     // NDArray<T> columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace());
     // NDArray<T>* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3});         // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW]
     // NDArray<T>* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1});
     // NDArray<T>* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW});
 
     // columns2d->addiColumnVector(gradOVector);
 
     // columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
 
     // *gradI /= kH*kW;
 
     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
     ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0);
 
     if(!isNCHW) {
@@ -205,16 +205,13 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
         delete gradI;
         delete gradO;
     }
-    // delete columns;
-    // delete columns2d;
-    // delete gradOVector;
 
     return Status::OK();
 }
 
 DECLARE_SHAPE_FN(avgpool2d_bp) {
 
     REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "AVGPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]);
     REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "AVGPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]);
@@ -30,10 +30,10 @@ namespace ops {
 
 //////////////////////////////////////////////////////////////////////////
 CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
 
     auto input  = INPUT_VARIABLE(0);    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
     auto output = OUTPUT_VARIABLE(0);   // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
 
     int kD = INT_ARG(0);    // filter(kernel) depth
     int kH = INT_ARG(1);    // filter(kernel) height
     int kW = INT_ARG(2);    // filter(kernel) width
@@ -50,7 +50,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
     int extraParam0 = INT_ARG(13);
     int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;    // 0-NCDHW, 1-NDHWC
 
     REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
     REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
 
     int bS, iC, iD, iH, iW, oC, oD, oH, oW;     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
@@ -61,21 +61,21 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
     REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
 
     if(!isNCDHW) {
-        input  = input->permute({0, 4, 1, 2, 3});                   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        output = output->permute({0, 4, 1, 2, 3});                  // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
+        input  = new NDArray(input->permute({0, 4, 1, 2, 3}));      // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+        output = new NDArray(output->permute({0, 4, 1, 2, 3}));     // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
     }
 
     if(isSameMode)                  // SAME
         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
 
     //T extraParams[] = {};
     ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
 
     if(!isNCDHW) {
         delete input;
         delete output;
     }
 
     return Status::OK();
 }
@@ -103,22 +103,22 @@ DECLARE_SHAPE_FN(avgpool3dnew) {
     int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;    // 0-NCDHW, 1-NDHWC
 
     REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
 
     auto inputShapeInfo = inputShape->at(0);
 
     int idxID, idxIC;
     if(isNCDHW) { idxID = 2; idxIC = 1; }
     else        { idxID = 1; idxIC = 4; }
 
     int bS = inputShapeInfo[1];         // batch size
     int iC = inputShapeInfo[idxIC+1];   // input channels
     int iD = inputShapeInfo[idxID+1];   // input depth
     int iH = inputShapeInfo[idxID+2];   // input height
     int iW = inputShapeInfo[idxID+3];   // input width
 
     int oD, oH, oW;                     // output depth, height, width
     ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
 
     Nd4jLong outputShape[5];
 
     outputShape[0] = bS;
@@ -146,7 +146,7 @@ DECLARE_SHAPE_FN(avgpool3dnew) {
 
 //////////////////////////////////////////////////////////////////////////
 CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
 
     auto input = INPUT_VARIABLE(0);     // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
     auto gradO = INPUT_VARIABLE(1);     // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
     auto gradI = OUTPUT_VARIABLE(0);    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
@@ -164,10 +164,10 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
     const int dH = INT_ARG(10);             // dilations height
     const int dW = INT_ARG(11);             // dilations width
     const int isSameMode = INT_ARG(12);     // 1-SAME, 0-VALID
     const int extraParam0 = INT_ARG(13);    // define what divisor to use while averaging
     const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;  // 0-NCDHW, 1-NDHWC
 
     REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf());
     REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
 
     int bS, iC, iD, iH, iW, oC, oD, oH, oW;     // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
@@ -180,22 +180,22 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
     REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
 
     if(!isNCDHW) {
-        input = input->permute({0, 4, 1, 2, 3});                // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        gradI = gradI->permute({0, 4, 1, 2, 3});                // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        gradO = gradO->permute({0, 4, 1, 2, 3});                // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
+        input = new NDArray(input->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+        gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+        gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3}));   // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
     }
 
     if(isSameMode)                  // SAME
         ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
 
     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
     ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
 
     if(!isNCDHW) {
         delete input;
         delete gradI;
         delete gradO;
     }
 
     return Status::OK();
 }
@@ -59,10 +59,10 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
     const int iH = isNCHW ? input->sizeAt(2) : input->sizeAt(1);
     const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2);
 
-    if (!isNCHW) {
-        input  = input->permute({0, 3, 1, 2});                  // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        output = output->permute({0, 3, 1, 2});                 // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+    if(!isNCHW) {
+        input  = new NDArray(input->permute({0, 3, 1, 2}));     // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+        output = new NDArray(output->permute({0, 3, 1, 2}));    // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
     }
 
     ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
@@ -71,8 +71,8 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
 
     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
     ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
 
-    if (!isNCHW) {
+    if(!isNCHW) {
         delete input;
         delete output;
     }
@@ -92,7 +92,7 @@ DECLARE_SYN(maxpool, maxpool2d);
 
 DECLARE_SHAPE_FN(maxpool2d) {
 
     //NDArray<T> *x = block.getVariables().at(0)->getNDArray();
     Nd4jLong* inShape = inputShape->at(0);
     Nd4jLong* shapeOf = shape::shapeOf(inShape);
@@ -120,7 +120,7 @@ DECLARE_SHAPE_FN(maxpool2d) {
     // calculate output Height/Width
     int oH, oW;
     ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
 
     // allocate memory for new shape
     Nd4jLong newShape[4];
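Both pooling shape functions lean on `calcOutSizePool2D`, which applies the usual forward convolution/pooling output-size rules: ceiling division by the stride for SAME, the floor formula for VALID. A sketch of that arithmetic, with illustrative names rather than the libnd4j signature:

```cpp
#include <cassert>

// Standard forward pooling/convolution output-size arithmetic, as used by
// calcOutSizePool2D per axis. SAME: ceil(i / s); VALID: floor of the
// padded, dilation-adjusted formula. Names here are illustrative only.
static int poolOutSize(int i, int k, int s, int p, int d, bool sameMode) {
    if (sameMode)
        return (i + s - 1) / s;          // ceil(i / s) in integer arithmetic
    int kEff = (k - 1) * d + 1;          // effective kernel extent under dilation
    return (i + 2 * p - kEff) / s + 1;   // floor division, then +1
}
```

For example, a length-9 axis pooled with k=3, s=2 gives (9-3)/2+1 = 4 in VALID mode and ceil(9/2) = 5 in SAME mode, which matches the padding that `calcPadding2D` then distributes in the SAME branch of the ops above.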
@@ -175,27 +175,27 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {

    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if(!isNCHW) {
-       input = input->permute({0, 3, 1, 2});                   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-       gradI = gradI->permute({0, 3, 1, 2});                   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-       gradO = gradO->permute({0, 3, 1, 2});                   // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+       input = new NDArray(input->permute({0, 3, 1, 2}));      // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+       gradI = new NDArray(gradI->permute({0, 3, 1, 2}));      // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+       gradO = new NDArray(gradO->permute({0, 3, 1, 2}));      // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    if(isSameMode)                       // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    // NDArray<T> columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace());
    // NDArray<T>* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3});                 // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW]

    // input->template applyTransform<simdOps::Im2col<T>>(columns, std::vector<T>({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, (T)0.f, (T)0.f}).data());

    // NDArray<T>* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW});
    // NDArray<T>* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1});

    // columns2d->template applyTransform<simdOps::IsMax<T>>(std::vector<T>({(T)1., (T)1.}).data());
    // columns2d->muliColumnVector(gradOVector);

    // columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());

    ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.);

    if(!isNCHW) {

@@ -203,17 +203,14 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {

        delete gradI;
        delete gradO;
    }
-   // delete columns;
-   // delete columns2d;
-   // delete gradOVector;

    return Status::OK();
}
DECLARE_SYN(MaxPool2D_bp, maxpool2d_bp);
DECLARE_SYN(MaxPool_bp, maxpool2d_bp);

DECLARE_SHAPE_FN(maxpool2d_bp) {

    REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "MAXPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]);
    REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "MAXPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]);
@@ -30,10 +30,10 @@ namespace ops {

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {

    auto input  = INPUT_VARIABLE(0);     // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto output = OUTPUT_VARIABLE(0);    // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)

    int kD = INT_ARG(0);                 // filter(kernel) depth
    int kH = INT_ARG(1);                 // filter(kernel) height
    int kW = INT_ARG(2);                 // filter(kernel) width

@@ -48,9 +48,9 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {

    int dW = INT_ARG(11);                // dilations width
    int isSameMode = INT_ARG(12);        // 1-SAME, 0-VALID
    // int extraParam0 = INT_ARG(13);    // unnecessary for max case, required only for avg and pnorm cases
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;    // 1-NDHWC, 0-NCDHW

    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;      // batch size, input channels, input depth/height/width, output channels, output depth/height/width;

@@ -59,24 +59,24 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {

    std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
    REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
    // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
    // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);

    if(!isNCDHW) {
-       input  = input->permute({0, 4, 1, 2, 3});               // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-       output = output->permute({0, 4, 1, 2, 3});              // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
+       input  = new NDArray(input->permute({0, 4, 1, 2, 3}));  // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+       output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
    }

    if(isSameMode)                       // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);

    if(!isNCDHW) {
        delete input;
        delete output;
    }

    return Status::OK();
}

@@ -102,25 +102,25 @@ DECLARE_SHAPE_FN(maxpool3dnew) {

    int dW = INT_ARG(11);                // dilations width
    int isSameMode = INT_ARG(12);        // 1-SAME, 0-VALID
    // int extraParam0 = INT_ARG(13);
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;    // 1-NDHWC, 0-NCDHW

    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    Nd4jLong* inputShapeInfo = inputShape->at(0);

    int idxID, idxIC;
    if(isNCDHW) { idxID = 2; idxIC = 1;}
    else        { idxID = 1; idxIC = 4;}

    int bS = inputShapeInfo[1];          // batch size
    int iC = inputShapeInfo[idxIC+1];    // input channels
    int iD = inputShapeInfo[idxID+1];    // input depth
    int iH = inputShapeInfo[idxID+2];    // input height
    int iW = inputShapeInfo[idxID+3];    // input width

    int oD, oH, oW;                      // output depth, height, width
    ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);

    Nd4jLong outputShape[5];

@@ -167,9 +167,9 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {

    const int dW = INT_ARG(11);          // dilations width
    const int isSameMode = INT_ARG(12);  // 1-SAME, 0-VALID
    // int extraParam0 = INT_ARG(13);    // unnecessary for max case, required only for avg and pnorm cases
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;    // 1-NDHWC, 0-NCDHW

    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;      // batch size, input channels, input depth/height/width, output channels, output depth/height/width;

@@ -182,21 +182,21 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {

    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if(!isNCDHW) {
-       input = input->permute({0, 4, 1, 2, 3});                // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-       gradI = gradI->permute({0, 4, 1, 2, 3});                // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-       gradO = gradO->permute({0, 4, 1, 2, 3});                // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
+       input = new NDArray(input->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+       gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3}));   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+       gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3}));   // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
    }

    if(isSameMode)                       // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    // NDArray<T> columnsWrongShape(input->ordering(), {bS, iC, oD, oH, oW, kD, kH, kW}, input->getWorkspace());
    // NDArray<T>* columns = columnsWrongShape.permute({0, 1, 5, 6, 7, 2, 3, 4});          // [bS, iC, oD, oH, oW, kD, kH, kW] -> [bS, iC, kD, kH, kW, oD, oH, oW]

    // ConvolutionUtils<T>::vol2col(*input, *columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]

    // NDArray<T>* columns2d = columnsWrongShape.reshape('c', {bS*iC*oD*oH*oW, kD*kH*kW});
    // NDArray<T>* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1});
    // T extraParams[] = {(T)1., (T)1.};
    // columns2d->template applyTransform<simdOps::IsMax<T>>(extraParams);
    // columns2d->muliColumnVector(gradOVector);

@@ -211,10 +211,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {

        delete gradI;
        delete gradO;
    }
-   // delete columns;
-   // delete columns2d;
-   // delete gradOVector;

    return Status::OK();
}
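In SAME mode the op recomputes padding via `ConvolutionUtils::calcPadding3D` before pooling. A hedged per-axis sketch of what such a helper typically derives (the name `samePad` and the even split of odd totals are illustrative assumptions, not the libnd4j code):

```cpp
#include <algorithm>
#include <cassert>

// Illustrative SAME-mode padding for one spatial axis: total padding is
// whatever makes (o - 1) * s plus the dilated kernel span cover the input i.
int samePad(int o, int i, int k, int s, int d) {
    int span = (o - 1) * s + (k - 1) * d + 1;  // extent the windows must cover
    return std::max(0, (span - i) / 2);        // split across both sides
}
```

For instance, a 3-wide kernel at stride 1 over a 28-wide input needs one pixel of padding per side to keep the output at 28.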
@@ -52,11 +52,11 @@ namespace nd4j {

    int oY = 0;
    int oX = 0;

    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;     // 1-NHWC, 0-NCHW

-   if (!isNCHW) {
-       input  = input->permute({0, 3, 1, 2});                  // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-       output = output->permute({0, 3, 1, 2});                 // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+   if(!isNCHW) {
+       input  = new NDArray(input->permute({0, 3, 1, 2}));     // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+       output = new NDArray(output->permute({0, 3, 1, 2}));    // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    const auto inY = static_cast<int>(input->sizeAt(2));

@@ -70,7 +70,7 @@ namespace nd4j {

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
    ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);

-   if (!isNCHW) {
+   if(!isNCHW) {
        delete input;
        delete output;
    }

@@ -175,40 +175,40 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {

    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if(!isNCHW) {
-       input = input->permute({0, 3, 1, 2});                   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-       gradI = gradI->permute({0, 3, 1, 2});                   // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-       gradO = gradO->permute({0, 3, 1, 2});                   // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+       input = new NDArray(input->permute({0, 3, 1, 2}));      // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+       gradI = new NDArray(gradI->permute({0, 3, 1, 2}));      // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+       gradO = new NDArray(gradO->permute({0, 3, 1, 2}));      // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
    }

    // if(isSameMode)                    // SAME
    //     ConvolutionUtils<T>::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

    // NDArray<T> columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace());
    // NDArray<T>* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3});               // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW]
    // NDArray<T>* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1});
    // NDArray<T>* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW});
    // NDArray<T> pNorm(columns2d->getShapeInfo(), block.getWorkspace());

    // input->template applyTransform<simdOps::Im2col<T>>(columns, std::vector<T>({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, (T)0.f, (T)0.f}).data());

    // columns2d->template applyTransform<simdOps::Abs<T>>(&pNorm);
    // pNorm.template applyTransform<simdOps::Pow<T>>(&pNorm, std::vector<T>({(T)pnorm}).data());

    // NDArray<T>* denomVec = pNorm.sum({1});
    // denomVec->template applyTransform<simdOps::Pow<T>>(std::vector<T>({(T)1. - (T)1. / pnorm}).data());
    // denomVec->template applyScalar<simdOps::Max<T>>(eps);   // in case of 0
    // denomVec->template applyPairwiseTransform<simdOps::ReverseDivide<T>>(gradOVector, denomVec, nullptr);

    // if(pnorm != 2) {
    //     T extraParams[] = {(T)1. - (T)2. / pnorm};
    //     pNorm.template applyTransform<simdOps::Pow<T>>(std::vector<T>({(T)1. - (T)2. / pnorm}).data());
    //     *columns2d *= pNorm;
    // }

    // columns2d->muliColumnVector(denomVec);

    // columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());

    ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm);

    if(!isNCHW) {

@@ -216,16 +216,12 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {

        delete gradI;
        delete gradO;
    }
-   // delete columns;
-   // delete columns2d;
-   // delete gradOVector;
-   // delete denomVec;

    return Status::OK();
}

DECLARE_SHAPE_FN(pnormpool2d_bp) {

    REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "PNORMPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]);
    REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "PNORMPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]);
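The forward pass delegates to `pooling2d` with `PoolingType::PNORM_POOL`, which reduces each window to its p-norm. A scalar sketch of that reduction for one window (illustrative only; the real kernel operates over strided views with dilation):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Illustrative p-norm of one pooling window: (sum_i |x_i|^p)^(1/p).
// For p = 1 this is the sum of absolute values; as p grows it
// approaches max pooling.
double pnormWindow(const std::vector<double>& window, double p) {
    double acc = 0.0;
    for (double x : window)
        acc += std::pow(std::fabs(x), p);
    return std::pow(acc, 1.0 / p);
}
```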
@@ -29,7 +29,7 @@ namespace ops {

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {

    auto logits  = INPUT_VARIABLE(0);
    auto weights = INPUT_VARIABLE(1);
    auto labels  = INPUT_VARIABLE(2);

@@ -37,17 +37,17 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {

    int reductionMode = INT_ARG(0);      // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights"
    double labelsSmoothing = T_ARG(0);

    // input validation
    REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str());
    // only 4 possible reduction modes exist
    REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
    // smoothing is possible for rank of logits/labels > 1
    REQUIRE_TRUE(labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: smoothing is not possible when rank of labels/ logits = 1 !");

    if(!output->isScalar()) {
        // weights array can be single scalar or has the same shape as output, and must be broadcastable to output shape
        REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == output->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", weights->rankOf(), output->rankOf());
        // check whether broadcast operation is possible for weights array
        REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str());
    }

@@ -59,8 +59,8 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {

    if(labelsSmoothing != 0.) {
        newLabels = new NDArray(cLabels);
        *newLabels = (1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1);
    }

    // main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last dimension
    // softmax_i = exp(logits_i) / sum_j(exp(logits_j))
    // so result = sum_i( lables_i * (log(sum_j(exp(logits_j))) - logits_i) )
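The smoothing step above mixes each label vector with a uniform distribution: `(1 - s) * labels + s / numClasses`, where `numClasses` corresponds to `cLabels->sizeAt(1)`. A scalar sketch of that transformation:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Label smoothing as applied in the op: each (typically one-hot) label is
// pulled toward uniform, newLabel_i = (1 - s) * label_i + s / numClasses.
// Here numClasses is taken as labels.size() for a single label vector.
std::vector<double> smoothLabels(const std::vector<double>& labels, double s) {
    std::vector<double> out(labels.size());
    for (size_t i = 0; i < labels.size(); ++i)
        out[i] = (1.0 - s) * labels[i] + s / labels.size();
    return out;
}
```

With s = 0.1 and four classes, a one-hot 1 becomes 0.925 and each zero becomes 0.025, so the vector still sums to 1.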
@@ -73,24 +73,24 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {

    NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true);
    NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log);
    NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions);

    // perform weights broadcasting/tile to E if it is necessary
    auto weightsBroad = weights;
    if(!weights->isScalar() && !weights->isSameShape(&E)) {
        if(E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1)
-           weightsBroad = weights->reshape(weights->ordering(), {weights->lengthOf()});
+           weightsBroad = new NDArray(weights->reshape(weights->ordering(), {weights->lengthOf()}));
        else
            weightsBroad = new NDArray(weights->tileToShape(E.getShapeInfo()));
    }

    // multiply E on weights
    E *= *weightsBroad;

    switch (reductionMode) {
        case 0:     // 0 - "none", un-reduced weighted losses with the same shape as labels.
            output->assign(&E);
            break;

        case 1: {   // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array
            E.reduceNumber(reduce::Sum, *output);
            break;

@@ -99,12 +99,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {

            double sum;
            if (weights->isScalar())
                sum = weights->e<double>(0) * E.lengthOf();
            else
                sum = weightsBroad->reduceNumber(reduce::Sum).e<double>(0);

            if (sum == 0.)
                *output = 0.;
            else
                output->assign(E.reduceNumber(reduce::Sum) / sum);
            break;
        }
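The `shiftedLogits`/`logSumExp` lines implement the op's main formula in a numerically stable way: shift the logits by their max, then E = sum_i labels_i * (log(sum_j exp(l_j)) - l_i), which equals -sum_i labels_i * log(softmax_i). A scalar sketch of the same computation:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Stable softmax cross entropy for one label/logit vector, mirroring the
// op's shiftedLogits / logSumExp formulation. Shifting by the max keeps
// exp() from overflowing without changing the result.
double softmaxCrossEntropy(const std::vector<double>& labels,
                           const std::vector<double>& logits) {
    double mx = *std::max_element(logits.begin(), logits.end());
    double sumExp = 0.0;
    for (double l : logits)
        sumExp += std::exp(l - mx);
    double logSumExp = std::log(sumExp);
    double e = 0.0;
    for (size_t i = 0; i < labels.size(); ++i)
        e += labels[i] * (logSumExp - (logits[i] - mx));
    return e;
}
```

With two equal logits and a one-hot label, the loss is log(2), and thanks to the max shift the result is identical even for logits of 1000.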
@@ -132,15 +132,15 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {

    if(newLabels != cLabels)
        delete newLabels;

    delete cLabels;

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(softmax_cross_entropy_loss) {

    getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS})
                     ->setAllowedInputTypes(1, {ALL_FLOATS})
                     ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS})

@@ -149,12 +149,12 @@ DECLARE_TYPES(softmax_cross_entropy_loss) {

//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(softmax_cross_entropy_loss) {

    auto logitsShapeInfo  = inputShape->at(0);
    auto weightsShapeInfo = inputShape->at(1);
    auto labelsShapeInfo  = inputShape->at(2);

    // labels and logits must have the same shapes
    REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());

    DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));

@@ -165,14 +165,14 @@ DECLARE_SHAPE_FN(softmax_cross_entropy_loss) {

    else {   // in this case output has the shape as labels and logits minus last dimension
        std::vector<int> dimensions = {-1};
        outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, true, block.getWorkspace());

        // weights array can be single scalar or has the same rank as output, and must be broadcastable to output
        REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo));
        // check whether broadcast operation is possible for weights array
        REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(outShapeInfo).c_str());
    }

    return SHAPELIST(outShapeInfo);
}

@@ -185,15 +185,15 @@ DECLARE_SHAPE_FN(softmax_cross_entropy_loss) {

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {

    auto logits  = INPUT_VARIABLE(0);
    auto weights = INPUT_VARIABLE(1);
    auto labels  = INPUT_VARIABLE(2);

    auto dLdp = OUTPUT_VARIABLE(0);      // dL/dlogits
    auto dLdw = OUTPUT_VARIABLE(1);      // dL/dweights
    auto dLdl = OUTPUT_VARIABLE(2);      // dL/dlabels

    auto labelsSmoothing = T_ARG(0);

    int reductionMode = INT_ARG(0);      // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights"

@@ -203,13 +203,13 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {

    std::vector<int> dimensions = {-1};

    // input validation
    REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str());
    // only 4 possible reduction modes exist
    REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
    auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(logits->ordering(), dimensions, logits->getShapeInfo(), false, false, block.getWorkspace());
    // weights array can be single scalar or has the same shape as loss, and must be broadcastable to loss shape
    REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", weights->rankOf(), shape::rank(lossShapeInfo));
    // check whether broadcast operation is possible for weights array
    REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(weights->getShapeInfo(), lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str());
    // smoothing is possible for rank of logits/labels > 1

@@ -221,14 +221,14 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {

    auto newLabels = cLabels;
    if(labelsSmoothing != 0.) {
        newLabels = new NDArray(labels->getShapeInfo(), dLdl->dataType(), false, block.launchContext());
        newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1));
    }

    NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimensions, true)).transform(transform::Exp);
    softmax /= softmax.reduceAlongDims(reduce::Sum, dimensions, true);

    // dEdp = softmax * sum_i(lables_i) - labels
    dLdp->assign(softmax * newLabels->reduceAlongDims(reduce::Sum, dimensions, true) - *newLabels);

    // dEdl = -log(softmax)
    dLdl->assign(-softmax.transform(transform::Log)* (1.f - labelsSmoothing));

@@ -236,11 +236,11 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {

    NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true);
    NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log);
    NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions);

    // perform weights broadcasting/tile to E if it is necessary
    auto weightsBroad = weights;
    if(!weights->isScalar() && !weights->isSameShape(&E))
        weightsBroad = new NDArray(weights->tileToShape(E.getShapeInfo()));

    dimensions = ShapeUtils::evalDimsToExclude(dLdp->rankOf(), dimensions);

@@ -344,18 +344,18 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {

    if(weightsBroad != weights)
        delete weightsBroad;

    if(newLabels != cLabels)
        delete newLabels;

    delete cLabels;

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(softmax_cross_entropy_loss_grad) {

    getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS})
                     ->setAllowedInputTypes(1, {ALL_FLOATS})
                     ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS})

@@ -367,27 +367,27 @@ DECLARE_TYPES(softmax_cross_entropy_loss_grad) {

//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(softmax_cross_entropy_loss_grad) {

    auto logitsShapeInfo  = inputShape->at(0);
    auto weightsShapeInfo = inputShape->at(1);
    auto labelsShapeInfo  = inputShape->at(2);

    std::vector<int> dimensions = {-1};

    // labels and logits must have the same shapes
    REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
    auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, false, block.getWorkspace());
    // weights array can be single scalar or has the same rank as loss, and must be broadcastable to loss
    REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo));
    // check whether broadcast operation is possible for weights array
    REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str());

    auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));

    auto dLdpShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(logitsShapeInfo), shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo)));
    auto dLdwShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(weightsShapeInfo), shape::shapeOf(weightsShapeInfo), shape::rank(weightsShapeInfo)));
    auto dLdlShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo)));

    return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo);
}
@@ -74,7 +74,7 @@ namespace ops {
     }

     if(mask != nullptr){
-        NDArray* reshapedMask;
+        NDArray reshapedMask;
         if(weights->rankOf() == 4){
             reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
         }else{
@@ -87,8 +87,7 @@ namespace ops {
         // before going through the softmax, we effectively push all masked positions to zero after softmax.
         //
         // we are using 1e9 to mean effectively infinity
-        *weights += (*reshapedMask - 1) * 1e9;
-        delete reshapedMask;
+        *weights += (reshapedMask - 1) * 1e9;
     }

     nd4j::ops::softmax softmax;
@@ -175,14 +174,13 @@ namespace ops {
     preSoftmax /= factor;

     if(mask != nullptr){
-        NDArray* reshapedMask;
+        NDArray reshapedMask;
         if(preSoftmax.rankOf() == 4){
             reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
         }else{
             reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1});
         }
-        preSoftmax += (*reshapedMask - 1) * 1e9;
-        delete reshapedMask;
+        preSoftmax += (reshapedMask - 1) * 1e9;
     }

     NDArray weights('c', weightShape, values->dataType(), block.launchContext());
@@ -70,7 +70,7 @@ namespace nd4j {
     float beta = T_ARG(2);
     int depth = INT_ARG(0);

-    helpers::lrnBP(*input, *gradO, *gradI, depth, bias, alpha, beta);
+    helpers::lrnBP(block, *input, *gradO, *gradI, depth, bias, alpha, beta);

     return Status::OK();
 }
@@ -98,9 +98,9 @@ namespace ops {
     auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());

     // Apply Attention
-    NDArray attnResults('c', {projectedQueries->sizeAt(0), projectedValues->sizeAt(1), projectedValues->sizeAt(2), projectedQueries->sizeAt(3)}, projectedValues->dataType(), block.launchContext());
+    NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
     nd4j::ops::dot_product_attention attention;
-    attention.execute({projectedQueries, projectedKeys, projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {});
+    attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {});

     // Project attention results
     attnResults.permutei({0, 3, 1, 2});
@@ -111,11 +111,9 @@ namespace ops {
     mmul.execute({&attnResults, Wo},{&projRes}, {}, {}, {});
     projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize});
     projRes.permutei({0, 2, 1});
-    output->assign(projRes);
-
-    delete projectedQueries;
-    delete projectedKeys;
-    delete projectedValues;
+    // FIXME: bad for performance
+    output->assign(projRes);

     return Status::OK();
 }
@@ -227,9 +225,9 @@ namespace ops {
     auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());

     // Apply Attention
-    NDArray attnResults('c', {projectedQueries->sizeAt(0), projectedValues->sizeAt(1), projectedValues->sizeAt(2), projectedQueries->sizeAt(3)}, projectedValues->dataType(), block.launchContext());
+    NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
     nd4j::ops::dot_product_attention attention;
-    attention.execute({projectedQueries, projectedKeys, projectedValues, mask}, {&attnResults}, {}, {normalization, 0}, {});
+    attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults}, {}, {normalization, 0}, {});

     // Project attention results
     attnResults.permutei({0, 3, 1, 2});
@@ -237,31 +235,25 @@ namespace ops {

     // dLdWo
     auto epsPerm = eps->permute({0, 2, 1});
-    auto epsPostReshape = epsPerm->reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
+    auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
     nd4j::ops::matmul_bp matmulBp;
     NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext());
-    matmulBp.execute({&attnResults, Wo, epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {});
+    matmulBp.execute({&attnResults, Wo, &epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {});

     // dLdAttn
-    dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues->sizeAt(2)});
+    dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)});
     dLdPreWo.permutei({0, 2, 3, 1});

     nd4j::ops::dot_product_attention_bp attentionBp;
-    NDArray dLdProjectedQueries(projectedQueries->shapeInfo(), false, block.launchContext());
-    NDArray dLdProjectedKeys(projectedKeys->shapeInfo(), false, block.launchContext());
-    NDArray dLdProjectedValues(projectedValues->shapeInfo(), false, block.launchContext());
-    attentionBp.execute({projectedQueries, projectedKeys, projectedValues, &dLdPreWo, mask},{&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, {normalization}, {});
+    NDArray dLdProjectedQueries(projectedQueries.shapeInfo(), false, block.launchContext());
+    NDArray dLdProjectedKeys(projectedKeys.shapeInfo(), false, block.launchContext());
+    NDArray dLdProjectedValues(projectedValues.shapeInfo(), false, block.launchContext());
+    attentionBp.execute({&projectedQueries, &projectedKeys, &projectedValues, &dLdPreWo, mask},{&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, {normalization}, {});

     AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, dLdWq, block.launchContext());
     AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, block.launchContext());
     AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, dLdWv, block.launchContext());

-    delete projectedQueries;
-    delete projectedKeys;
-    delete projectedValues;
-    delete epsPerm;
-    delete epsPostReshape;
-
     return Status::OK();
 }
@@ -45,13 +45,13 @@ CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) {
     int arrLen = a->lengthOf();

     // FIXME: this stuff should be single op call. No sense rolling over couple of arrays twice
     for(int i = 0; i < arrLen; ++i ) {
         REQUIRE_TRUE(a->e<float>(i) > 0.f, 0, "BETAINC op: arrays a array must contain only elements > 0 !");
         REQUIRE_TRUE(b->e<float>(i) > 0.f, 0, "BETAINC op: arrays b array must contain only elements > 0 !");
         REQUIRE_TRUE(0.f <= x->e<float>(i) && x->e<float>(i) <= 1.f, 0, "BETAINC op: all elements of x array must be within [0, 1] range!");
     }

-    *output = helpers::betaInc(block.launchContext(), *a, *b, *x);
+    helpers::betaInc(block.launchContext(), *a, *b, *x, *output);

     return Status::OK();
 }
@@ -48,10 +48,7 @@ namespace nd4j {
             //nd4j_debug("Reshaping to: [%i, %i]\n", -1, (int) bias->lengthOf());
             auto tArr = input->reshape(input->ordering(), shape);
             auto zArr = z->reshape(z->ordering(), shape);
-            tArr->addRowVector(bias, zArr);
-
-            delete tArr;
-            delete zArr;
+            tArr.addRowVector(bias, &zArr);
         }

         STORE_RESULT(*z);
@@ -87,13 +84,12 @@ namespace nd4j {
             // cnn case
             if (input->rankOf() == 4) {
                 auto epsilonNext2d = epsilonNext->permute({1, 0, 2, 3});
-                epsilonNext2d->reshapei('c', {(int) bias->lengthOf(), -1});
+                epsilonNext2d.reshapei('c', {(int) bias->lengthOf(), -1});

-                auto sum = epsilonNext2d->reduceAlongDimension(reduce::Sum, {1});
+                auto sum = epsilonNext2d.reduceAlongDimension(reduce::Sum, {1});
                 gradB->assign(sum);

                 delete sum;
-                delete epsilonNext2d;
             } else if (input->rankOf() == 2) {
                 // regular fully-connected case
                 auto sum = epsilonNext->reduceAlongDimension(reduce::Sum, {0});
@@ -0,0 +1,56 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_check_numerics)

#include <ops/declarable/CustomOperations.h>

namespace nd4j {
    namespace ops {

        CUSTOM_OP_IMPL(check_numerics, 2, 1, true, 0, 0) {
            auto input = INPUT_VARIABLE(0);
            auto message = INPUT_VARIABLE(1);
            auto output = OUTPUT_VARIABLE(0);

            auto allFinite = input->reduceNumber(reduce::BoolOps::IsFinite);
            REQUIRE_TRUE(allFinite.e<bool>(0), 0, "CheckNumerics: %s", message->e<std::string>(0).c_str());

            if (!block.isInplace())
                output->assign(input);

            return Status::OK();
        }

        DECLARE_SHAPE_FN(check_numerics) {
            return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0))));
        }

        DECLARE_TYPES(check_numerics) {
            getOpDescriptor()
                    ->setAllowedInputTypes(0, {ALL_FLOATS})
                    ->setAllowedInputTypes(1, nd4j::DataType::UTF8)
                    ->setAllowedOutputTypes({ALL_FLOATS});
        }
    }
}

#endif
@@ -56,7 +56,7 @@ namespace nd4j {
         }

         DECLARE_SHAPE_FN(crop_and_resize) {
-            auto in = inputShape->at(0);
+            auto in = inputShape->at(1);

             Nd4jLong outputShape[4];
@@ -77,8 +77,13 @@ namespace nd4j {
         }
         DECLARE_TYPES(crop_and_resize) {
             getOpDescriptor()
-                    ->setAllowedInputTypes(nd4j::DataType::ANY)
-                    ->setAllowedOutputTypes({ALL_FLOATS});
+                    ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
+//                    ->setAllowedInputTypes(1, {ALL_FLOATS})
+                    ->setAllowedInputTypes(1, {FLOAT32}) // as TF
+                    ->setAllowedInputTypes(2, {ALL_INTS})
+                    ->setAllowedInputTypes(3, {ALL_INTS})
+                    ->setAllowedOutputTypes({FLOAT32}); // as TF
+//                    ->setAllowedOutputTypes({ALL_FLOATS});
         }
     }
 }
@@ -47,9 +47,9 @@ namespace ops {
     auto o = OUTPUT_VARIABLE(0);

     if (a->lengthOf() == 3) {
-        helpers::_cross(block.launchContext(), a, b, o);
+        helpers::cross(block.launchContext(), a, b, o);
     } else {
-        helpers::_crossBatched(block.launchContext(), a, b, o);
+        helpers::crossBatched(block.launchContext(), a, b, o);
     }

     return Status::OK();