Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14) * Modified strided_slice op to properly work with empty-like shapes. * Fixed test for reduce_mean with empty-like input. * [WIP] Last merge (#15) * correct logsoftmax looss (#2) * Small SameDiff listener fix (#4) * Various fixes (#6) * #7839 Fix for asXMatrix and tests * #7866 EmbeddingSequenceLayer dtype fix + test * #7856 SameDiff save/load stream methods * #7859 RegressionEvaluation rank 4 fix + tests + axis configuration * EvaluationBinary 3d/4d * More evaluation 3d/4d tests * #7847 Evaluation empty checks * Small test ifx * #7848 Fix median edge case * Improve DL4J samediff layer tests * [WIP] FastText wrapper implemented (#8) * FastText implemented * Some fixes * Fix shapes for wordsNearest * Validation of input vectors * Fixes * Fixed test * Thread tagged * Some tweaks * setContextClassLoader for DeallocatorServiceThread * Numpy format tests (#1) * Various fixes (#11) * #7852 SameDiff gather fix * #7892 SameDiff placeholder to constant conversion * #7890 validate input rank for MLN/CG init methods * Fix broken permute shape calculation * Permute and gather fixes * Tests * #7850 LogSumExp fix + test * Handful of test fixes * Empty arrays with non-scalar shapes (#10) * minor rearrangements for lambdas * empty tensors with non-scalar shapes * numpy empty tensors with non-scalar shapes * few more empty tweaks * Small fixes * conv3d signature update * micro fix in batchnorm mkldnn * Import fixes * Fix * MKL-DNN update * Small fill fix * fill with empty input + test * Fixes * Small error improvement * Fix * one special test * couple of fixes for lstm * Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone * Fixes * FP16 * Unsigned * BFloat16 * Fill op - empty tweaks * - couple of fixes for empty arrays construction - stack updated * strided slice fix * one transform test * provide method for reducing shapeInfo in case of input array is empty * Fixed reduceAlongDimensions to use empty input properly. * couple of broadcast tests * couple of tests broadcast tests + tweak to make them pass * add check of non-empty to methods producing sub-arrays * Fixed reshapeC with zeros in shape. * complete empty check in reduce_... legacy ops * Concat and cumsum/prod * Tweak to empty shape inference on import * add empty check to the rest of reduce legacy ops * one more test * correct typo in evalReduceShapeInfoEmpty * Added tests for reduce_* ops to tests with zero shapes. * few more tests for empty reductions * Fixed strided_slice op with empty case and tests. * one more empty reduction test * Fixed strided_slice test. * add empty check to NDArray::reshapei * infOrMax * empty min/max with infinity tests * made unstack working correctly with empty arrays * few IndexReduce tests + tweaks for empty shapes * add test for empty concat * few tests fixed * Validation fix for reductions on empty shapes * Reverse fix * Reduction shape calc fixes * SameDiff.generateOutputVariable: don't use shape function to determine number of outputs * Range fix * - NDArray constructor updated for scalars/empty arrays - few tests fixed * More fixes * Empty creator fixes * concat fix * concat fix * TF import tests: allow 'both all NaN' and 'both all inf' to pass * Slice, zero fraction, and reshape fixes * transpose, gather * Zero fraction * scalar cast fix * Empty reduction axis support * few more tests fixed * Fixed input checks conforming with TF for concat op and tests. * few tests fixed * matmul scalar shape fix * Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats. * broadcast bool fix * few more tests * few more tests * correct evalReduceShapeInfoEmpty * argmax/argmin + tests * one more empty edge case + one more test * argmax/argmin/realdiv_bp tweaks * empty reshape test + fix * Helper fixes * Small fixes * Gather test fix * Gather test fix * Small fixes * reduce scalar zero values * scalar mean workaround * Remove debug code * along dim mean workaround * one more test * - equalsTo() tweak for empty arrays - one more test * broadcast tweaks * [WIP] Fixing outstanding issues for NLP (#9) * Avoid using not-inited objects * Test fixed. * Redundant method avoided for models like FastText * KMeans++ implementation * KMeans++ implementation * Disable parallel execution * KMeans++ * Tests * Dev branch merge (#16) * SameDiff: convertDataType and gradient check util improvements (#12) * GradCheck util improvements * StopGradient constructor + test * SameDiff: Add datatype conversion * Javadoc and add DataType.isNumerical() * Small fix * Fix SameDiff TF import test cases intermediate naming (workaround for bad default) * TFGraphTestAllHelper: check intermediates in execution order * Add missing debug listener * [WIP] lstmBlock fix + other changes (#13) - fixes lstmBlock issue - changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer - CheckNumerics op - fixes for ReduceBool IsInfOrNan & IsFinite * Small test fix * CheckNumerics op wrapper * Fix some issues on master (#17) * Fix DataVec test issue * Fix issue with dl4j SameDiff output layer * Dtype fix for lambda layers * #7912 BertIterator dtype fix (use float32 not global default) * [WIP] Next set of CUDA stuff (#7) New CUDA implementations and improvements * bad file * Dev branch master merge (#23) * SameDiff: convertDataType and gradient check util improvements (#12) * GradCheck util improvements * StopGradient constructor + test * SameDiff: Add datatype conversion * Javadoc and add DataType.isNumerical() * Small fix * Fix SameDiff TF import test cases intermediate naming (workaround for bad default) * TFGraphTestAllHelper: check intermediates in execution order * Add missing debug listener * [WIP] lstmBlock fix + other changes (#13) - fixes lstmBlock issue - changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer - CheckNumerics op - fixes for ReduceBool IsInfOrNan & IsFinite * Small test fix * CheckNumerics op wrapper * Compatibility of deserialization (#18) Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * SameDiff: add activation gradient checking support for debugging (#19) * SameDiff gradient checker: first pass on activation gradient checks * Fixes + tests for activation gradient checking * Javadoc * [WIP] Some nd4j data type corrections (#20) * Adjust data type * Set correct Data type. * Size of proper data type. * fix averaged cpu load (#22) * SameDiff ops, TF import and fixes (#24) * CheckNumerics tests + fixes + misc fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fake quant Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * FakeQuantWithMinMaxArgs Signed-off-by: AlexDBlack <blacka101@gmail.com> * CheckNumerics fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types) Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Javadoc Signed-off-by: AlexDBlack <blacka101@gmail.com> * Exception tweak Signed-off-by: AlexDBlack <blacka101@gmail.com> * fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for out of scope stack allocated var use Signed-off-by: AlexDBlack <blacka101@gmail.com> * Ignores Signed-off-by: AlexDBlack <blacka101@gmail.com> * Ignore for known failing test (already logged issue) Signed-off-by: AlexDBlack <blacka101@gmail.com> * Merge upstream to fork (#25) * Add thousand-separator commas to TotalParams (#7915) * Add thousand-separator commas to TotalParams The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them. * Add thousand-separator commas to MultiLayerNetwork Corresponding change to MultiLayerNetwork Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com> * Update contributing and issue/PR templates (#7934) Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix link to AdaDelta paper (#7942) Fix link to AdaDelta paper hosted on matthewzeiler.com Signed-off-by: Jxtps * Fixes, and ignores for known/logged failing issues (#7943) Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff + DL4J/SameDiff: Multiple fixes (#28) * #7919 HDF5 attribute buffer length fix Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7909 Arbiter constructor exception ux improvements Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7925 RNN output layer length checks Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7939 Add listener for validating inputs are not incorrectly modified Signed-off-by: AlexDBlack <blacka101@gmail.com> * #7939 Integrate NonInplaceValidationListener into tests * #7844 DL4J SameDiff fixes for variable minibatch size * DL4J SameDiff fixes - ensure gradient for input placeholder is available Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweaks to ExternalErrorsFunction - use placeholders, make more robust * Another fix * More fixes * More SameDiff/DL4J fixes * Scope out scalar array creation in BaseScalarOp * Remove debug code Signed-off-by: AlexDBlack <blacka101@gmail.com> * [WIP] Final dev branch merge (#29) * SameDiff: convertDataType and gradient check util improvements (#12) * GradCheck util improvements * StopGradient constructor + test * SameDiff: Add datatype conversion * Javadoc and add DataType.isNumerical() * Small fix * Fix SameDiff TF import test cases intermediate naming (workaround for bad default) * TFGraphTestAllHelper: check intermediates in execution order * Add missing debug listener * [WIP] lstmBlock fix + other changes (#13) - fixes lstmBlock issue - changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer - CheckNumerics op - fixes for ReduceBool IsInfOrNan & IsFinite * Small test fix * CheckNumerics op wrapper * Compatibility of deserialization (#18) Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * SameDiff: add activation gradient checking support for debugging (#19) * SameDiff gradient checker: first pass on activation gradient checks * Fixes + tests for activation gradient checking * Javadoc * [WIP] Some nd4j data type corrections (#20) * Adjust data type * Set correct Data type. * Size of proper data type. * fix averaged cpu load (#22) * [WIP] Multiple dataset iterators (#27) * Splitting dataset into arbitrary number * Fixes * Multiple split of iterator * Test * Test * Some fixes * signature change * one more tweak Signed-off-by: raver119 <raver119@gmail.com> * one more test for sequential use of DataSetIteratorSplitter Signed-off-by: raver119 <raver119@gmail.com> * Fixes * Fixes * one more test for Alexander Signed-off-by: raver119 <raver119@gmail.com> * Some fixes * Some fixes * one more test for Alexander Signed-off-by: raver119 <raver119@gmail.com> * minor test fix Signed-off-by: raver119 <raver119@gmail.com> * Some fixes * Some fixes * couple of assertions tweaked Signed-off-by: raver119 <raver119@gmail.com> * MDS splitter test :/ Signed-off-by: raver119 <raver119@gmail.com> * Minor refactoring * Multi dataset * Some fixes * More tests * Small number of test fixes/improvements (failures on CI) (#31) Signed-off-by: AlexDBlack <blacka101@gmail.com> * [WIP] More CUDA stuff (#26) * initial commit Signed-off-by: raver119 <raver119@gmail.com> * LRN BP CUDA Signed-off-by: raver119 <raver119@gmail.com> * less memory Signed-off-by: raver119 <raver119@gmail.com> * Fixed bug with crop_and_resize op helper. * get rid of unnecessary index-calculation dunction Signed-off-by: Yurii <yurii@skymind.io> * Fixed sort with nth_element cuda-based helper. * Refactored nth_element. * Refactored nth_element op and tests. * Modified usage of dim array with sortTad routine. * Refactored main routine of helper for non_max_image_suppression op. * non_max_image_suppression op helper with cuda kernel implementation. Initial revision. * fix vol2col cuda kernel * meh Signed-off-by: raver119 <raver119@gmail.com> * topK concept Signed-off-by: raver119 <raver119@gmail.com> * unsorted topK with scanWitdh of 1 Signed-off-by: raver119 <raver119@gmail.com> * correct vol2col tests * sorted/unsorted topK Signed-off-by: raver119 <raver119@gmail.com> * implementation and fixing col2im/col2vol * Corrected usage flags with input/output with reverse op. * dup is const now Signed-off-by: raver119 <raver119@gmail.com> * percentile op Signed-off-by: raver119 <raver119@gmail.com> * group tests for mapool2d Signed-off-by: Yurii <yurii@skymind.io> * special test for george Signed-off-by: raver119 <raver119@gmail.com> * less threads for sortTad Signed-off-by: raver119 <raver119@gmail.com> * provide conv2d for cuda Signed-off-by: Yurii <yurii@skymind.io> * remove auther in sort tad kernel code Signed-off-by: Yurii <yurii@skymind.io> * provide depthwise_conv2d for cuda Signed-off-by: Yurii <yurii@skymind.io> * - max_pooling_with_argmax - null check for special use Signed-off-by: raver119 <raver119@gmail.com> * dts cuda Signed-off-by: raver119 <raver119@gmail.com> * provide sconv2d for cuda Signed-off-by: Yurii <yurii@skymind.io> * std cuda Signed-off-by: raver119 <raver119@gmail.com> * Refactored non_max_suppression op to conform TF implementation. * Improved suppression helper. * provide pooling3d for cuda Signed-off-by: Yurii <yurii@skymind.io> * minor lstm rearrangements Signed-off-by: raver119 <raver119@gmail.com> * more of minor lstm rearrangements Signed-off-by: raver119 <raver119@gmail.com> * (bi)dynamic_rnn Signed-off-by: raver119 <raver119@gmail.com> * templates init order Signed-off-by: raver119 <raver119@gmail.com> * Refactored non_max_suppression op. * Added cuda kernel for non_max_suppression. * CPU sort by key/value Signed-off-by: raver119 <raver119@gmail.com> * CPU sort TAD by key/value Signed-off-by: raver119 <raver119@gmail.com> * CPU sort TAD by key/value tests Signed-off-by: raver119 <raver119@gmail.com> * Eliminate compiler error with cuda implementation. * - repaired gradCheck in cuda - provide conv2d_bp for cuda Signed-off-by: Yurii <yurii@skymind.io> * missed signature Signed-off-by: raver119 <raver119@gmail.com> * provide depthwise_conv2d_bp for cuda Signed-off-by: Yurii <yurii@skymind.io> * Implementation of lup helper with cuda kernel. Initial commit. * further work on backprops for convolutions Signed-off-by: Yurii <yurii@skymind.io> * CUDA linear sort by key/val Signed-off-by: raver119 <raver119@gmail.com> * CUDA tad sort by key/val Signed-off-by: raver119 <raver119@gmail.com> * start providing of backprop for pooling2d/3d Signed-off-by: Yurii <yurii@skymind.io> * Added atomicAdd for bool datatype. * dynamic partition concept Signed-off-by: raver119 <raver119@gmail.com> * dynamic partition concept Signed-off-by: raver119 <raver119@gmail.com> * dynamic partition scalar CUDA Signed-off-by: raver119 <raver119@gmail.com> * important comment Signed-off-by: raver119 <raver119@gmail.com> * fix pooling2d/3d backprop helpers Signed-off-by: Yurii <yurii@skymind.io> * Added non-linear test with dynamic_partition. * Improved test for dynamic_partition. * dynamic_partition TAD concept Signed-off-by: raver119 <raver119@gmail.com> * - dynamic_partition TAD CUDA impl - dynamic_partition TAD CPU fix Signed-off-by: raver119 <raver119@gmail.com> * - rewrite cpu code for usampling2d/3d - write cuda code for usampling2d/3d Signed-off-by: Yurii <yurii@skymind.io> * dynamic_stitch CUDA vector case Signed-off-by: raver119 <raver119@gmail.com> * dynamic_stitch CUDA TAD case concept Signed-off-by: raver119 <raver119@gmail.com> * dynamic_stitch CUDA TAD case impl Signed-off-by: raver119 <raver119@gmail.com> * Added tests for dynamic_stitch 3D-4D cases. * minor tests tweaks Signed-off-by: raver119 <raver119@gmail.com> * Fixed type check for dynamic stitch. * min/max bp Signed-off-by: raver119 <raver119@gmail.com> * rewrite code for upsampling2d/3d cpu Signed-off-by: Yurii <yurii@skymind.io> * reduce min/max/norm_max bp Signed-off-by: raver119 <raver119@gmail.com> * lup implementation. Additional enhancements. * provide code for upsamling2d/3d backprop Signed-off-by: Yurii <yurii@skymind.io> * weightedCrossEntropyWithLogits Signed-off-by: raver119 <raver119@gmail.com> * Fixed template math atomicMul for 64bit ints. * Refactored dynamic_partition_bp op. * inverseBroadcast fix Signed-off-by: raver119 <raver119@gmail.com> * DynamicPartitionBP test datatype fixed. * - nd4j_atomicMul Windows fix - cpu/NDArrayLambda.hpp excluded from CUDA Signed-off-by: raver119 <raver119@gmail.com>master
parent
cae4fc9760
commit
1170827c18
|
@ -31,7 +31,7 @@ public class TaskCreatorProvider {
|
||||||
}
|
}
|
||||||
return c.newInstance();
|
return c.newInstance();
|
||||||
} catch (Exception e){
|
} catch (Exception e){
|
||||||
throw new RuntimeException("Could not create new instance of task creator class: " + c, e);
|
throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -83,7 +83,7 @@ public class DataSetIteratorFactoryProvider implements DataProvider {
|
||||||
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
|
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
|
||||||
return clazz.newInstance();
|
return clazz.newInstance();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,7 @@ public class DataSetIteratorFactoryProvider implements DataProvider {
|
||||||
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
|
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
|
||||||
return clazz.newInstance();
|
return clazz.newInstance();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,7 +54,7 @@ public abstract class BaseNetScoreFunction implements ScoreFunction {
|
||||||
ds.configure(dataSourceProperties);
|
ds.configure(dataSourceProperties);
|
||||||
}
|
}
|
||||||
} catch (Exception e){
|
} catch (Exception e){
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException("Error creating DataSource instance - missing no-arg constructor?", e);
|
||||||
}
|
}
|
||||||
return score(model, ds.testData());
|
return score(model, ds.testData());
|
||||||
}
|
}
|
||||||
|
|
|
@ -188,10 +188,15 @@ public class ComputationGraphTaskCreator implements TaskCreator {
|
||||||
//For DataSetIterator: wraps in a MultiDataSetIterator, hence method can be used for both
|
//For DataSetIterator: wraps in a MultiDataSetIterator, hence method can be used for both
|
||||||
MultiDataSetIterator iterator;
|
MultiDataSetIterator iterator;
|
||||||
if(dataSource != null){
|
if(dataSource != null){
|
||||||
|
try {
|
||||||
DataSource dsInstance = dataSource.newInstance();
|
DataSource dsInstance = dataSource.newInstance();
|
||||||
if(dataSourceProperties != null)
|
if (dataSourceProperties != null)
|
||||||
dsInstance.configure(dataSourceProperties);
|
dsInstance.configure(dataSourceProperties);
|
||||||
iterator = ScoreUtil.getMultiIterator(dsInstance.trainData());
|
iterator = ScoreUtil.getMultiIterator(dsInstance.trainData());
|
||||||
|
} catch (Exception e){
|
||||||
|
throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName() +
|
||||||
|
" - no zero-arg constructor?",e);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
iterator = ScoreUtil.getMultiIterator(dataProvider.trainData(candidate.getDataParameters()));
|
iterator = ScoreUtil.getMultiIterator(dataProvider.trainData(candidate.getDataParameters()));
|
||||||
}
|
}
|
||||||
|
|
|
@ -190,7 +190,8 @@ public class MultiLayerNetworkTaskCreator implements TaskCreator {
|
||||||
try{
|
try{
|
||||||
dsInstance = dataSource.newInstance();
|
dsInstance = dataSource.newInstance();
|
||||||
} catch (Exception e){
|
} catch (Exception e){
|
||||||
throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName());
|
throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName() +
|
||||||
|
" - no zero-arg constructor?",e);
|
||||||
}
|
}
|
||||||
if(dataSourceProperties != null)
|
if(dataSourceProperties != null)
|
||||||
dsInstance.configure(dataSourceProperties);
|
dsInstance.configure(dataSourceProperties);
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.datavec.api.writable.NDArrayWritable;
|
||||||
import org.datavec.api.writable.Text;
|
import org.datavec.api.writable.Text;
|
||||||
import org.datavec.api.writable.Writable;
|
import org.datavec.api.writable.Writable;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
@ -78,14 +79,14 @@ public class TestNDArrayWritableTransforms {
|
||||||
assertEquals(expColNames, tp.getFinalSchema().getColumnNames());
|
assertEquals(expColNames, tp.getFinalSchema().getColumnNames());
|
||||||
|
|
||||||
|
|
||||||
List<Writable> in = Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)),
|
List<Writable> in = Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)),
|
||||||
new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0)));
|
new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE)));
|
||||||
List<Writable> out = tp.execute(in);
|
List<Writable> out = tp.execute(in);
|
||||||
|
|
||||||
List<Writable> exp =
|
List<Writable> exp =
|
||||||
Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)),
|
Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)),
|
||||||
new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0)),
|
new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE)),
|
||||||
new NDArrayWritable(Nd4j.linspace(0, 9, 10).addi(2.0)));
|
new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE, 0, 10, 1).addi(2.0).reshape(1,10)));
|
||||||
|
|
||||||
assertEquals(exp, out);
|
assertEquals(exp, out);
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,9 +20,15 @@ import lombok.val;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class DataSetSplitterTests extends BaseDL4JTest {
|
public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
@Test
|
@Test
|
||||||
|
@ -39,7 +45,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
int gcntTest = 0;
|
int gcntTest = 0;
|
||||||
int global = 0;
|
int global = 0;
|
||||||
// emulating epochs here
|
// emulating epochs here
|
||||||
for (int e = 0; e < numEpochs; e++){
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
while (train.hasNext()) {
|
while (train.hasNext()) {
|
||||||
val data = train.next().getFeatures();
|
val data = train.next().getFeatures();
|
||||||
|
@ -79,7 +85,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
int gcntTest = 0;
|
int gcntTest = 0;
|
||||||
int global = 0;
|
int global = 0;
|
||||||
// emulating epochs here
|
// emulating epochs here
|
||||||
for (int e = 0; e < numEpochs; e++){
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
while (train.hasNext()) {
|
while (train.hasNext()) {
|
||||||
val data = train.next().getFeatures();
|
val data = train.next().getFeatures();
|
||||||
|
@ -117,7 +123,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
int gcntTest = 0;
|
int gcntTest = 0;
|
||||||
int global = 0;
|
int global = 0;
|
||||||
// emulating epochs here
|
// emulating epochs here
|
||||||
for (int e = 0; e < numEpochs; e++){
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
while (train.hasNext()) {
|
while (train.hasNext()) {
|
||||||
val data = train.next().getFeatures();
|
val data = train.next().getFeatures();
|
||||||
|
@ -144,4 +150,245 @@ public class DataSetSplitterTests extends BaseDL4JTest {
|
||||||
|
|
||||||
assertEquals(1000 * numEpochs, global);
|
assertEquals(1000 * numEpochs, global);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSplitter_4() {
|
||||||
|
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
val splitter = new DataSetIteratorSplitter(back, 1000, new double[]{0.5, 0.3, 0.2});
|
||||||
|
List<DataSetIterator> iteratorList = splitter.getIterators();
|
||||||
|
val numEpochs = 10;
|
||||||
|
int global = 0;
|
||||||
|
// emulating epochs here
|
||||||
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
|
int iterNo = 0;
|
||||||
|
int perEpoch = 0;
|
||||||
|
for (val partIterator : iteratorList) {
|
||||||
|
int cnt = 0;
|
||||||
|
partIterator.reset();
|
||||||
|
while (partIterator.hasNext()) {
|
||||||
|
val data = partIterator.next().getFeatures();
|
||||||
|
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
|
||||||
|
(float) perEpoch, data.getFloat(0), 1e-5);
|
||||||
|
//gcntTrain++;
|
||||||
|
global++;
|
||||||
|
cnt++;
|
||||||
|
++perEpoch;
|
||||||
|
}
|
||||||
|
++iterNo;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(1000* numEpochs, global);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSplitter_5() {
|
||||||
|
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
val splitter = new DataSetIteratorSplitter(back, new int[]{900, 100});
|
||||||
|
|
||||||
|
List<DataSetIterator> iteratorList = splitter.getIterators();
|
||||||
|
val numEpochs = 10;
|
||||||
|
|
||||||
|
int global = 0;
|
||||||
|
// emulating epochs here
|
||||||
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
|
int iterNo = 0;
|
||||||
|
int perEpoch = 0;
|
||||||
|
for (val partIterator : iteratorList) {
|
||||||
|
partIterator.reset();
|
||||||
|
while (partIterator.hasNext()) {
|
||||||
|
int cnt = 0;
|
||||||
|
val data = partIterator.next().getFeatures();
|
||||||
|
|
||||||
|
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
|
||||||
|
(float) perEpoch, data.getFloat(0), 1e-5);
|
||||||
|
//gcntTrain++;
|
||||||
|
global++;
|
||||||
|
cnt++;
|
||||||
|
++perEpoch;
|
||||||
|
}
|
||||||
|
++iterNo;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(1000 * numEpochs, global);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSplitter_6() {
|
||||||
|
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
// we're going to mimic train+test+validation split
|
||||||
|
val splitter = new DataSetIteratorSplitter(back, new int[]{800, 100, 100});
|
||||||
|
|
||||||
|
assertEquals(3, splitter.getIterators().size());
|
||||||
|
|
||||||
|
val trainIter = splitter.getIterators().get(0);
|
||||||
|
val testIter = splitter.getIterators().get(1);
|
||||||
|
val validationIter = splitter.getIterators().get(2);
|
||||||
|
|
||||||
|
// we're going to have multiple epochs
|
||||||
|
int numEpochs = 10;
|
||||||
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
|
int globalIter = 0;
|
||||||
|
trainIter.reset();
|
||||||
|
testIter.reset();
|
||||||
|
validationIter.reset();
|
||||||
|
|
||||||
|
boolean trained = false;
|
||||||
|
while (trainIter.hasNext()) {
|
||||||
|
trained = true;
|
||||||
|
val ds = trainIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
|
||||||
|
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
|
||||||
|
globalIter++;
|
||||||
|
}
|
||||||
|
assertTrue("Failed at epoch [" + e + "]", trained);
|
||||||
|
assertEquals(800, globalIter);
|
||||||
|
|
||||||
|
|
||||||
|
// test set is used every epoch
|
||||||
|
boolean tested = false;
|
||||||
|
//testIter.reset();
|
||||||
|
while (testIter.hasNext()) {
|
||||||
|
tested = true;
|
||||||
|
val ds = testIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
|
||||||
|
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
|
||||||
|
globalIter++;
|
||||||
|
}
|
||||||
|
assertTrue("Failed at epoch [" + e + "]", tested);
|
||||||
|
assertEquals(900, globalIter);
|
||||||
|
|
||||||
|
// validation set is used every 5 epochs
|
||||||
|
if (e % 5 == 0) {
|
||||||
|
boolean validated = false;
|
||||||
|
//validationIter.reset();
|
||||||
|
while (validationIter.hasNext()) {
|
||||||
|
validated = true;
|
||||||
|
val ds = validationIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
|
||||||
|
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
|
||||||
|
globalIter++;
|
||||||
|
}
|
||||||
|
assertTrue("Failed at epoch [" + e + "]", validated);
|
||||||
|
}
|
||||||
|
|
||||||
|
// all 3 iterators have exactly 1000 elements combined
|
||||||
|
if (e % 5 == 0)
|
||||||
|
assertEquals(1000, globalIter);
|
||||||
|
else
|
||||||
|
assertEquals(900, globalIter);
|
||||||
|
trainIter.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUnorderedSplitter_1() {
|
||||||
|
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
val splitter = new DataSetIteratorSplitter(back, new int[]{500, 500});
|
||||||
|
|
||||||
|
List<DataSetIterator> iteratorList = splitter.getIterators();
|
||||||
|
val numEpochs = 10;
|
||||||
|
|
||||||
|
int global = 0;
|
||||||
|
// emulating epochs here
|
||||||
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
|
|
||||||
|
// Get data from second part, then rewind for the first one.
|
||||||
|
int cnt = 0;
|
||||||
|
int partNumber = 1;
|
||||||
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
|
int farCnt = (1000 / 2) * (partNumber) + cnt;
|
||||||
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
|
|
||||||
|
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data.getFloat(0), 1e-5);
|
||||||
|
cnt++;
|
||||||
|
global++;
|
||||||
|
}
|
||||||
|
iteratorList.get(partNumber).reset();
|
||||||
|
partNumber = 0;
|
||||||
|
cnt = 0;
|
||||||
|
while (iteratorList.get(0).hasNext()) {
|
||||||
|
val data = iteratorList.get(0).next().getFeatures();
|
||||||
|
|
||||||
|
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
|
||||||
|
global++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUnorderedSplitter_2() {
|
||||||
|
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
val splitter = new DataSetIteratorSplitter(back, new int[]{2});
|
||||||
|
|
||||||
|
List<DataSetIterator> iteratorList = splitter.getIterators();
|
||||||
|
|
||||||
|
for (int partNumber = 0 ; partNumber < iteratorList.size(); ++partNumber) {
|
||||||
|
int cnt = 0;
|
||||||
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
|
|
||||||
|
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
|
||||||
|
cnt++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUnorderedSplitter_3() {
|
||||||
|
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
val splitter = new DataSetIteratorSplitter(back, new int[]{10});
|
||||||
|
|
||||||
|
List<DataSetIterator> iteratorList = splitter.getIterators();
|
||||||
|
Random random = new Random();
|
||||||
|
int[] indexes = new int[iteratorList.size()];
|
||||||
|
for (int i = 0; i < indexes.length; ++i) {
|
||||||
|
indexes[i] = random.nextInt(iteratorList.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int partNumber : indexes) {
|
||||||
|
int cnt = 0;
|
||||||
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
|
|
||||||
|
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
|
||||||
|
cnt++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUnorderedSplitter_4() {
|
||||||
|
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
// we're going to mimic train+test+validation split
|
||||||
|
val splitter = new DataSetIteratorSplitter(back, new int[]{80, 10, 5});
|
||||||
|
|
||||||
|
assertEquals(3, splitter.getIterators().size());
|
||||||
|
|
||||||
|
val trainIter = splitter.getIterators().get(0); // 0..79
|
||||||
|
val testIter = splitter.getIterators().get(1); // 80 ..89
|
||||||
|
val validationIter = splitter.getIterators().get(2); // 90..94
|
||||||
|
|
||||||
|
// we're skipping train/test and go for validation first. we're that crazy, right.
|
||||||
|
int valCnt = 0;
|
||||||
|
while (validationIter.hasNext()) {
|
||||||
|
val ds = validationIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
|
||||||
|
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5);
|
||||||
|
valCnt++;
|
||||||
|
}
|
||||||
|
assertEquals(5, valCnt);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,11 +18,17 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
|
||||||
import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator;
|
import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -150,4 +156,309 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
|
||||||
|
|
||||||
assertEquals(1000 * numEpochs, global);
|
assertEquals(1000 * numEpochs, global);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMultiSplitter_1() {
|
||||||
|
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
// we're going to mimic train+test+validation split
|
||||||
|
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100});
|
||||||
|
|
||||||
|
assertEquals(3, splitter.getIterators().size());
|
||||||
|
|
||||||
|
val trainIter = splitter.getIterators().get(0);
|
||||||
|
val testIter = splitter.getIterators().get(1);
|
||||||
|
val validationIter = splitter.getIterators().get(2);
|
||||||
|
|
||||||
|
// we're going to have multiple epochs
|
||||||
|
int numEpochs = 10;
|
||||||
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
|
int globalIter = 0;
|
||||||
|
trainIter.reset();
|
||||||
|
testIter.reset();
|
||||||
|
validationIter.reset();
|
||||||
|
|
||||||
|
boolean trained = false;
|
||||||
|
while (trainIter.hasNext()) {
|
||||||
|
trained = true;
|
||||||
|
val ds = trainIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
|
||||||
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
|
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
|
||||||
|
}
|
||||||
|
globalIter++;
|
||||||
|
}
|
||||||
|
assertTrue("Failed at epoch [" + e + "]", trained);
|
||||||
|
assertEquals(800, globalIter);
|
||||||
|
|
||||||
|
|
||||||
|
// test set is used every epoch
|
||||||
|
boolean tested = false;
|
||||||
|
//testIter.reset();
|
||||||
|
while (testIter.hasNext()) {
|
||||||
|
tested = true;
|
||||||
|
val ds = testIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
|
||||||
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
|
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
|
||||||
|
}
|
||||||
|
globalIter++;
|
||||||
|
}
|
||||||
|
assertTrue("Failed at epoch [" + e + "]", tested);
|
||||||
|
assertEquals(900, globalIter);
|
||||||
|
|
||||||
|
// validation set is used every 5 epochs
|
||||||
|
if (e % 5 == 0) {
|
||||||
|
boolean validated = false;
|
||||||
|
//validationIter.reset();
|
||||||
|
while (validationIter.hasNext()) {
|
||||||
|
validated = true;
|
||||||
|
val ds = validationIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
|
||||||
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
|
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
|
||||||
|
}
|
||||||
|
globalIter++;
|
||||||
|
}
|
||||||
|
assertTrue("Failed at epoch [" + e + "]", validated);
|
||||||
|
}
|
||||||
|
|
||||||
|
// all 3 iterators have exactly 1000 elements combined
|
||||||
|
if (e % 5 == 0)
|
||||||
|
assertEquals(1000, globalIter);
|
||||||
|
else
|
||||||
|
assertEquals(900, globalIter);
|
||||||
|
trainIter.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSplitter_5() {
|
||||||
|
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{900, 100});
|
||||||
|
|
||||||
|
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
|
||||||
|
val numEpochs = 10;
|
||||||
|
|
||||||
|
int global = 0;
|
||||||
|
// emulating epochs here
|
||||||
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
|
int iterNo = 0;
|
||||||
|
int perEpoch = 0;
|
||||||
|
for (val partIterator : iteratorList) {
|
||||||
|
partIterator.reset();
|
||||||
|
while (partIterator.hasNext()) {
|
||||||
|
int cnt = 0;
|
||||||
|
val data = partIterator.next().getFeatures();
|
||||||
|
|
||||||
|
for (int i = 0; i < data.length; ++i) {
|
||||||
|
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
|
||||||
|
(float) perEpoch, data[i].getFloat(0), 1e-5);
|
||||||
|
}
|
||||||
|
//gcntTrain++;
|
||||||
|
global++;
|
||||||
|
cnt++;
|
||||||
|
++perEpoch;
|
||||||
|
}
|
||||||
|
++iterNo;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assertEquals(1000 * numEpochs, global);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSplitter_6() {
|
||||||
|
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
// we're going to mimic train+test+validation split
|
||||||
|
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100});
|
||||||
|
|
||||||
|
assertEquals(3, splitter.getIterators().size());
|
||||||
|
|
||||||
|
val trainIter = splitter.getIterators().get(0);
|
||||||
|
val testIter = splitter.getIterators().get(1);
|
||||||
|
val validationIter = splitter.getIterators().get(2);
|
||||||
|
|
||||||
|
// we're going to have multiple epochs
|
||||||
|
int numEpochs = 10;
|
||||||
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
|
int globalIter = 0;
|
||||||
|
trainIter.reset();
|
||||||
|
testIter.reset();
|
||||||
|
validationIter.reset();
|
||||||
|
|
||||||
|
boolean trained = false;
|
||||||
|
while (trainIter.hasNext()) {
|
||||||
|
trained = true;
|
||||||
|
val ds = trainIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
|
||||||
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
|
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
|
||||||
|
ds.getFeatures()[i].getDouble(0), 1e-5f);
|
||||||
|
}
|
||||||
|
globalIter++;
|
||||||
|
}
|
||||||
|
assertTrue("Failed at epoch [" + e + "]", trained);
|
||||||
|
assertEquals(800, globalIter);
|
||||||
|
|
||||||
|
|
||||||
|
// test set is used every epoch
|
||||||
|
boolean tested = false;
|
||||||
|
//testIter.reset();
|
||||||
|
while (testIter.hasNext()) {
|
||||||
|
tested = true;
|
||||||
|
val ds = testIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
|
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
|
||||||
|
}
|
||||||
|
globalIter++;
|
||||||
|
}
|
||||||
|
assertTrue("Failed at epoch [" + e + "]", tested);
|
||||||
|
assertEquals(900, globalIter);
|
||||||
|
|
||||||
|
// validation set is used every 5 epochs
|
||||||
|
if (e % 5 == 0) {
|
||||||
|
boolean validated = false;
|
||||||
|
//validationIter.reset();
|
||||||
|
while (validationIter.hasNext()) {
|
||||||
|
validated = true;
|
||||||
|
val ds = validationIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
|
||||||
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
|
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
|
||||||
|
ds.getFeatures()[i].getDouble(0), 1e-5f);
|
||||||
|
}
|
||||||
|
globalIter++;
|
||||||
|
}
|
||||||
|
assertTrue("Failed at epoch [" + e + "]", validated);
|
||||||
|
}
|
||||||
|
|
||||||
|
// all 3 iterators have exactly 1000 elements combined
|
||||||
|
if (e % 5 == 0)
|
||||||
|
assertEquals(1000, globalIter);
|
||||||
|
else
|
||||||
|
assertEquals(900, globalIter);
|
||||||
|
trainIter.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUnorderedSplitter_1() {
|
||||||
|
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{500, 500});
|
||||||
|
|
||||||
|
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
|
||||||
|
val numEpochs = 10;
|
||||||
|
|
||||||
|
int global = 0;
|
||||||
|
// emulating epochs here
|
||||||
|
for (int e = 0; e < numEpochs; e++) {
|
||||||
|
|
||||||
|
// Get data from second part, then rewind for the first one.
|
||||||
|
int cnt = 0;
|
||||||
|
int partNumber = 1;
|
||||||
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
|
int farCnt = (1000 / 2) * (partNumber) + cnt;
|
||||||
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
|
for (int i = 0; i < data.length; ++i) {
|
||||||
|
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data[i].getFloat(0), 1e-5);
|
||||||
|
}
|
||||||
|
cnt++;
|
||||||
|
global++;
|
||||||
|
}
|
||||||
|
iteratorList.get(partNumber).reset();
|
||||||
|
partNumber = 0;
|
||||||
|
cnt = 0;
|
||||||
|
while (iteratorList.get(0).hasNext()) {
|
||||||
|
val data = iteratorList.get(0).next().getFeatures();
|
||||||
|
for (int i = 0; i < data.length; ++i) {
|
||||||
|
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++,
|
||||||
|
data[i].getFloat(0), 1e-5);
|
||||||
|
}
|
||||||
|
global++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUnorderedSplitter_2() {
|
||||||
|
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{2});
|
||||||
|
|
||||||
|
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
|
||||||
|
|
||||||
|
for (int partNumber = 0 ; partNumber < iteratorList.size(); ++partNumber) {
|
||||||
|
int cnt = 0;
|
||||||
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
|
for (int i = 0; i < data.length; ++i) {
|
||||||
|
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5);
|
||||||
|
}
|
||||||
|
cnt++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUnorderedSplitter_3() {
|
||||||
|
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{10});
|
||||||
|
|
||||||
|
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
|
||||||
|
Random random = new Random();
|
||||||
|
int[] indexes = new int[iteratorList.size()];
|
||||||
|
for (int i = 0; i < indexes.length; ++i) {
|
||||||
|
indexes[i] = random.nextInt(iteratorList.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int partNumber : indexes) {
|
||||||
|
int cnt = 0;
|
||||||
|
while (iteratorList.get(partNumber).hasNext()) {
|
||||||
|
val data = iteratorList.get(partNumber).next().getFeatures();
|
||||||
|
for (int i = 0; i < data.length; ++i) {
|
||||||
|
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt),
|
||||||
|
data[i].getFloat(0), 1e-5);
|
||||||
|
}
|
||||||
|
cnt++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUnorderedSplitter_4() {
|
||||||
|
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
|
||||||
|
|
||||||
|
// we're going to mimic train+test+validation split
|
||||||
|
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{80, 10, 5});
|
||||||
|
|
||||||
|
assertEquals(3, splitter.getIterators().size());
|
||||||
|
|
||||||
|
val trainIter = splitter.getIterators().get(0); // 0..79
|
||||||
|
val testIter = splitter.getIterators().get(1); // 80 ..89
|
||||||
|
val validationIter = splitter.getIterators().get(2); // 90..94
|
||||||
|
|
||||||
|
// we're skipping train/test and go for validation first. we're that crazy, right.
|
||||||
|
int valCnt = 0;
|
||||||
|
while (validationIter.hasNext()) {
|
||||||
|
val ds = validationIter.next();
|
||||||
|
assertNotNull(ds);
|
||||||
|
for (int i = 0; i < ds.getFeatures().length; ++i) {
|
||||||
|
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90,
|
||||||
|
ds.getFeatures()[i].getFloat(0), 1e-5);
|
||||||
|
}
|
||||||
|
valCnt++;
|
||||||
|
}
|
||||||
|
assertEquals(5, valCnt);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.dropout.TestDropout;
|
||||||
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
@ -196,4 +197,43 @@ public class TestRnnLayers extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMismatchedInputLabelLength(){
|
||||||
|
|
||||||
|
for( int i=0; i<2; i++ ){
|
||||||
|
|
||||||
|
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
|
||||||
|
|
||||||
|
.list()
|
||||||
|
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
|
||||||
|
|
||||||
|
switch (i){
|
||||||
|
case 0:
|
||||||
|
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build());
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException();
|
||||||
|
}
|
||||||
|
|
||||||
|
MultiLayerConfiguration conf = lb.build();
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
|
||||||
|
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
|
||||||
|
|
||||||
|
try{
|
||||||
|
net.fit(in,l);
|
||||||
|
} catch (Throwable t){
|
||||||
|
String msg = t.getMessage();
|
||||||
|
assertTrue(msg, msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -249,7 +249,6 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore("AB 2019/05/31 - Failing on CI and locally - see issues 7820 and 7657")
|
|
||||||
public void testCorrectness1() {
|
public void testCorrectness1() {
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||||
Nd4j.getRandom().setSeed(123);
|
Nd4j.getRandom().setSeed(123);
|
||||||
|
@ -270,30 +269,18 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
|
||||||
.useAdaGrad(false).build();
|
.useAdaGrad(false).build();
|
||||||
|
|
||||||
b.fit(data);
|
b.fit(data);
|
||||||
System.out.println(b.getData());
|
|
||||||
|
|
||||||
/*double[] expectedData = new double[]{15.5392794313924, 19.25226403656672, -5.194955746137196, -31.787679714614757, 48.8674725273665,
|
double[] expectedData = new double[]{ 63.8206, 80.4013, -19.4424, -140.4326, 198.7239,
|
||||||
24.92775755686273, -22.621939920239065, -29.790772278125395, 19.027362415188914, -16.013800175884274,
|
106.1148, -96.6273, -124.3634, 78.4174, -83.6621,
|
||||||
-27.454680593309185, 1.2929960811295493, -40.45000061571038, 61.23261682914338, 5.62278768938746,
|
-121.8706, 3.0888, -172.8560, 255.1262, 20.7021,
|
||||||
-28.16665244970911, -20.05502814088798, 12.803274346870865, -24.877262522905497, 45.115883138175874,
|
-120.7942, -78.1829, 56.6021, -112.3294, 185.4084,
|
||||||
21.597495694710616, 18.63254779638783, -4.029728632528419, -0.4596087279592638, -42.35340705500429,
|
88.5330, 78.0497, -18.8673, -11.0155, -175.1564,
|
||||||
-69.24727547461491, 40.94332685199673, -24.60866142208024, 17.689874972878723, -3.6779759693605314,
|
-297.8463, 174.2511, -103.8793, 72.5455, -15.8498,
|
||||||
-30.91803590368529, 10.645452930824145, 36.58583235020565, -64.74975614289316, -39.364099390585956,
|
-134.5235, 42.3300, 154.0391, -280.1010, -167.9765,
|
||||||
72.54886481127016, -35.30663155696714, 19.37116912936714, -7.790876543092118, 19.6586396288508,
|
306.9938, -150.9666, 83.4419, -36.0877, 83.9992,
|
||||||
58.1332709511154, -18.49217368496203, -3.5050200971182424, 5.662891294031322, 39.69533295638775,
|
245.1813, -81.5018, -14.8430, 16.1557, 166.8651,
|
||||||
-15.114610550011662, -32.42366951357609, 17.039297537056537, 42.25610885633673, -2.7013781552769904,
|
-65.9247, -138.1783, 72.5444, 176.3088, -25.6732,
|
||||||
-16.338582630617925, 41.734027526336874, 20.941332646863426, -3.2145240561108244, -45.36033539684912};*/
|
-69.6843, 167.3360, 87.6238, -18.5874, -187.3806};
|
||||||
double[] expectedData = {40.93810899235225, 50.90183660191448, -14.298857560948981, -86.2012232604988, 129.51281793466023,
|
|
||||||
66.29136854264247, -61.650213611972326, -80.42836756633497, 50.28325210727952, -44.29008119040566,
|
|
||||||
-74.82748570869279, 2.0170536250746807, -109.21462846594635, 162.3973196127918, 14.000621153511705,
|
|
||||||
-76.30892822919527, -54.251704596942275, 33.99763310539589, -67.6307009607032, 119.50868525237786,
|
|
||||||
57.17786598853867, 49.1489174572297, -11.25663463504983, -2.38899196609398, -114.27194947404686,
|
|
||||||
-185.93832011474473, 108.9022579845252, -66.14099037301474, 47.13683038425694, -10.037893631405792,
|
|
||||||
-83.88458799629637, 26.985651418254996, 96.68139337135332, -174.2832443285551, -106.0999118697521,
|
|
||||||
193.02622700008175, -94.88003359113081, 51.39502524568139, -20.96021960048648, 52.32291574424741,
|
|
||||||
154.33973608321477, -50.90644802585217, -10.345744416395354, 13.721222143380892, 105.2111073677489,
|
|
||||||
-41.339268919407345, -87.73042354938127, 45.306865238870046, 112.53877133856602, -8.44454352074299,
|
|
||||||
-44.660828600669056, 110.72662022978719, 55.74660833987147, -9.613556053471232, -122.19953914048916};
|
|
||||||
|
|
||||||
INDArray expectedArray = Nd4j.createFromArray(expectedData).reshape(11,5);
|
INDArray expectedArray = Nd4j.createFromArray(expectedData).reshape(11,5);
|
||||||
for (int i = 0; i < expectedArray.rows(); ++i)
|
for (int i = 0; i < expectedArray.rows(); ++i)
|
||||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.util;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
@ -30,7 +31,7 @@ public class TimeSeriesUtilsTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMovingAverage() {
|
public void testMovingAverage() {
|
||||||
INDArray a = Nd4j.arange(0, 20);
|
INDArray a = Nd4j.arange(0, 20).castTo(DataType.DOUBLE);
|
||||||
INDArray result = Nd4j.create(new double[] {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f,
|
INDArray result = Nd4j.create(new double[] {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f,
|
||||||
12.5f, 13.5f, 14.5f, 15.5f, 16.5f, 17.5f});
|
12.5f, 13.5f, 14.5f, 15.5f, 16.5f, 17.5f});
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
@ -42,14 +43,20 @@ public class DataSetIteratorSplitter {
|
||||||
protected DataSetIterator backedIterator;
|
protected DataSetIterator backedIterator;
|
||||||
protected final long totalExamples;
|
protected final long totalExamples;
|
||||||
protected final double ratio;
|
protected final double ratio;
|
||||||
|
protected final double[] ratios;
|
||||||
protected final long numTrain;
|
protected final long numTrain;
|
||||||
protected final long numTest;
|
protected final long numTest;
|
||||||
|
protected final long numArbitrarySets;
|
||||||
|
protected final int[] splits;
|
||||||
|
|
||||||
|
|
||||||
protected AtomicLong counter = new AtomicLong(0);
|
protected AtomicLong counter = new AtomicLong(0);
|
||||||
|
|
||||||
protected AtomicBoolean resetPending = new AtomicBoolean(false);
|
protected AtomicBoolean resetPending = new AtomicBoolean(false);
|
||||||
protected DataSet firstTrain = null;
|
protected DataSet firstTrain = null;
|
||||||
|
|
||||||
|
protected int partNumber = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The only constructor
|
* The only constructor
|
||||||
*
|
*
|
||||||
|
@ -71,17 +78,94 @@ public class DataSetIteratorSplitter {
|
||||||
this.backedIterator = baseIterator;
|
this.backedIterator = baseIterator;
|
||||||
this.totalExamples = totalBatches;
|
this.totalExamples = totalBatches;
|
||||||
this.ratio = ratio;
|
this.ratio = ratio;
|
||||||
|
this.ratios = null;
|
||||||
this.numTrain = (long) (totalExamples * ratio);
|
this.numTrain = (long) (totalExamples * ratio);
|
||||||
this.numTest = totalExamples - numTrain;
|
this.numTest = totalExamples - numTrain;
|
||||||
|
this.numArbitrarySets = 2;
|
||||||
|
this.splits = null;
|
||||||
|
|
||||||
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
|
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalBatches, double[] ratios) {
|
||||||
|
for (double ratio : ratios) {
|
||||||
|
if (!(ratio > 0.0 && ratio < 1.0))
|
||||||
|
throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (totalBatches < 0)
|
||||||
|
throw new ND4JIllegalStateException("totalExamples number should be positive value");
|
||||||
|
|
||||||
|
if (!baseIterator.resetSupported())
|
||||||
|
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
|
||||||
|
|
||||||
|
|
||||||
|
this.backedIterator = baseIterator;
|
||||||
|
this.totalExamples = totalBatches;
|
||||||
|
this.ratio = 0.0;
|
||||||
|
this.ratios = ratios;
|
||||||
|
this.numTrain = 0; //(long) (totalExamples * ratio);
|
||||||
|
this.numTest = 0; //totalExamples - numTrain;
|
||||||
|
this.numArbitrarySets = ratios.length;
|
||||||
|
|
||||||
|
this.splits = new int[this.ratios.length];
|
||||||
|
for (int i = 0; i < this.splits.length; ++i) {
|
||||||
|
this.splits[i] = (int)(totalExamples * ratios[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
|
||||||
|
}
|
||||||
|
|
||||||
|
public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, int[] splits) {
|
||||||
|
|
||||||
|
/*if (!(simpleRatio > 0.0 && simpleRatio < 1.0))
|
||||||
|
throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");*/
|
||||||
|
|
||||||
|
int totalBatches = 0;
|
||||||
|
for (val v:splits)
|
||||||
|
totalBatches += v;
|
||||||
|
|
||||||
|
if (totalBatches < 0)
|
||||||
|
throw new ND4JIllegalStateException("totalExamples number should be positive value");
|
||||||
|
|
||||||
|
if (!baseIterator.resetSupported())
|
||||||
|
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
|
||||||
|
|
||||||
|
|
||||||
|
this.backedIterator = baseIterator;
|
||||||
|
this.totalExamples = totalBatches;
|
||||||
|
this.ratio = 0.0;
|
||||||
|
this.ratios = null;
|
||||||
|
|
||||||
|
this.numTrain = 0; //(long) (totalExamples * ratio);
|
||||||
|
this.numTest = 0; //totalExamples - numTrain;
|
||||||
|
this.splits = splits;
|
||||||
|
this.numArbitrarySets = splits.length;
|
||||||
|
|
||||||
|
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<DataSetIterator> getIterators() {
|
||||||
|
List<DataSetIterator> retVal = new ArrayList<>();
|
||||||
|
int partN = 0;
|
||||||
|
int bottom = 0;
|
||||||
|
for (final int split : splits) {
|
||||||
|
ScrollableDataSetIterator partIterator =
|
||||||
|
new ScrollableDataSetIterator(partN++, backedIterator, counter, resetPending, firstTrain,
|
||||||
|
new int[]{bottom,split});
|
||||||
|
bottom += split;
|
||||||
|
retVal.add(partIterator);
|
||||||
|
}
|
||||||
|
return retVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns train iterator instance
|
* This method returns train iterator instance
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public DataSetIterator getTrainIterator() {
|
public DataSetIterator getTrainIterator() {
|
||||||
return new DataSetIterator() {
|
return new DataSetIterator() {
|
||||||
@Override
|
@Override
|
||||||
|
@ -184,6 +268,7 @@ public class DataSetIteratorSplitter {
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public DataSetIterator getTestIterator() {
|
public DataSetIterator getTestIterator() {
|
||||||
return new DataSetIterator() {
|
return new DataSetIterator() {
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -21,9 +21,12 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
|
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
|
||||||
|
@ -43,6 +46,9 @@ public class MultiDataSetIteratorSplitter {
|
||||||
protected final double ratio;
|
protected final double ratio;
|
||||||
protected final long numTrain;
|
protected final long numTrain;
|
||||||
protected final long numTest;
|
protected final long numTest;
|
||||||
|
protected final double[] ratios;
|
||||||
|
protected final long numArbitrarySets;
|
||||||
|
protected final int[] splits;
|
||||||
|
|
||||||
protected AtomicLong counter = new AtomicLong(0);
|
protected AtomicLong counter = new AtomicLong(0);
|
||||||
|
|
||||||
|
@ -71,15 +77,87 @@ public class MultiDataSetIteratorSplitter {
|
||||||
this.ratio = ratio;
|
this.ratio = ratio;
|
||||||
this.numTrain = (long) (totalExamples * ratio);
|
this.numTrain = (long) (totalExamples * ratio);
|
||||||
this.numTest = totalExamples - numTrain;
|
this.numTest = totalExamples - numTrain;
|
||||||
|
this.ratios = null;
|
||||||
|
this.numArbitrarySets = 0;
|
||||||
|
this.splits = null;
|
||||||
|
|
||||||
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
|
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double[] ratios) {
|
||||||
|
for (double ratio : ratios) {
|
||||||
|
if (!(ratio > 0.0 && ratio < 1.0))
|
||||||
|
throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (totalBatches < 0)
|
||||||
|
throw new ND4JIllegalStateException("totalExamples number should be positive value");
|
||||||
|
|
||||||
|
if (!baseIterator.resetSupported())
|
||||||
|
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
|
||||||
|
|
||||||
|
|
||||||
|
this.backedIterator = baseIterator;
|
||||||
|
this.totalExamples = totalBatches;
|
||||||
|
this.ratio = 0.0;
|
||||||
|
this.numTrain = (long) (totalExamples * ratio);
|
||||||
|
this.numTest = totalExamples - numTrain;
|
||||||
|
this.ratios = null;
|
||||||
|
this.numArbitrarySets = ratios.length;
|
||||||
|
|
||||||
|
this.splits = new int[this.ratios.length];
|
||||||
|
for (int i = 0; i < this.splits.length; ++i) {
|
||||||
|
this.splits[i] = (int)(totalExamples * ratios[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
|
||||||
|
}
|
||||||
|
|
||||||
|
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, int[] splits) {
|
||||||
|
|
||||||
|
int totalBatches = 0;
|
||||||
|
for (val v:splits)
|
||||||
|
totalBatches += v;
|
||||||
|
|
||||||
|
if (totalBatches < 0)
|
||||||
|
throw new ND4JIllegalStateException("totalExamples number should be positive value");
|
||||||
|
|
||||||
|
if (!baseIterator.resetSupported())
|
||||||
|
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
|
||||||
|
|
||||||
|
|
||||||
|
this.backedIterator = baseIterator;
|
||||||
|
this.totalExamples = totalBatches;
|
||||||
|
this.ratio = 0.0;
|
||||||
|
this.numTrain = (long) (totalExamples * ratio);
|
||||||
|
this.numTest = totalExamples - numTrain;
|
||||||
|
this.ratios = null;
|
||||||
|
this.numArbitrarySets = splits.length;
|
||||||
|
this.splits = splits;
|
||||||
|
|
||||||
|
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<MultiDataSetIterator> getIterators() {
|
||||||
|
List<MultiDataSetIterator> retVal = new ArrayList<>();
|
||||||
|
int partN = 0;
|
||||||
|
int bottom = 0;
|
||||||
|
for (final int split : splits) {
|
||||||
|
ScrollableMultiDataSetIterator partIterator =
|
||||||
|
new ScrollableMultiDataSetIterator(partN++, backedIterator, counter, firstTrain,
|
||||||
|
new int[]{bottom,split});
|
||||||
|
bottom += split;
|
||||||
|
retVal.add(partIterator);
|
||||||
|
}
|
||||||
|
return retVal;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns train iterator instance
|
* This method returns train iterator instance
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public MultiDataSetIterator getTrainIterator() {
|
public MultiDataSetIterator getTrainIterator() {
|
||||||
return new MultiDataSetIterator() {
|
return new MultiDataSetIterator() {
|
||||||
@Override
|
@Override
|
||||||
|
@ -162,6 +240,7 @@ public class MultiDataSetIteratorSplitter {
|
||||||
*
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@Deprecated
|
||||||
public MultiDataSetIterator getTestIterator() {
|
public MultiDataSetIterator getTestIterator() {
|
||||||
return new MultiDataSetIterator() {
|
return new MultiDataSetIterator() {
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -0,0 +1,158 @@
|
||||||
|
package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
|
import lombok.val;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
|
||||||
|
public class ScrollableDataSetIterator implements DataSetIterator {
|
||||||
|
private int thisPart = 0;
|
||||||
|
private int top = 0;
|
||||||
|
private int bottom = 0;
|
||||||
|
protected DataSetIterator backedIterator;
|
||||||
|
protected AtomicLong counter = new AtomicLong(0);
|
||||||
|
|
||||||
|
protected AtomicBoolean resetPending = new AtomicBoolean(false);
|
||||||
|
protected DataSet firstTrain = null;
|
||||||
|
protected MultiDataSet firstMultiTrain = null;
|
||||||
|
private double ratio;
|
||||||
|
private long totalExamples;
|
||||||
|
private long itemsPerPart;
|
||||||
|
private long current;
|
||||||
|
|
||||||
|
|
||||||
|
public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
|
||||||
|
AtomicBoolean resetPending, DataSet firstTrain, double ratio,
|
||||||
|
int totalExamples) {
|
||||||
|
this.thisPart = num;
|
||||||
|
this.backedIterator = backedIterator;
|
||||||
|
this.counter = counter;
|
||||||
|
this.resetPending = resetPending;
|
||||||
|
this.firstTrain = firstTrain;
|
||||||
|
this.ratio = ratio;
|
||||||
|
this.totalExamples = totalExamples;
|
||||||
|
this.itemsPerPart = (long)(totalExamples * ratio);
|
||||||
|
this.current = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
|
||||||
|
AtomicBoolean resetPending, DataSet firstTrain,
|
||||||
|
int[] itemsPerPart) {
|
||||||
|
this.thisPart = num;
|
||||||
|
this.bottom = itemsPerPart[0];
|
||||||
|
this.top = bottom + itemsPerPart[1];
|
||||||
|
this.itemsPerPart = top;
|
||||||
|
|
||||||
|
this.backedIterator = backedIterator;
|
||||||
|
this.counter = counter;
|
||||||
|
//this.resetPending = resetPending;
|
||||||
|
this.firstTrain = firstTrain;
|
||||||
|
//this.totalExamples = totalExamples;
|
||||||
|
this.current = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataSet next(int i) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> getLabels() {
|
||||||
|
return backedIterator.getLabels();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int inputColumns() {
|
||||||
|
return backedIterator.inputColumns();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void remove() {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int totalOutcomes() {
|
||||||
|
return backedIterator.totalOutcomes();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean resetSupported() {
|
||||||
|
return backedIterator.resetSupported();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean asyncSupported() {
|
||||||
|
return backedIterator.asyncSupported();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
resetPending.set(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int batch() {
|
||||||
|
return backedIterator.batch();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
|
||||||
|
backedIterator.setPreProcessor(dataSetPreProcessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataSetPreProcessor getPreProcessor() {
|
||||||
|
|
||||||
|
return backedIterator.getPreProcessor();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
if (resetPending.get()) {
|
||||||
|
if (resetSupported()) {
|
||||||
|
backedIterator.reset();
|
||||||
|
counter.set(0);
|
||||||
|
current = 0;
|
||||||
|
resetPending.set(false);
|
||||||
|
} else
|
||||||
|
throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean state = false;
|
||||||
|
if (current >= top)
|
||||||
|
return false;
|
||||||
|
state = backedIterator.hasNext();
|
||||||
|
if (!state)
|
||||||
|
return false;
|
||||||
|
if (state && counter.get() < itemsPerPart)
|
||||||
|
return true;
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DataSet next() {
|
||||||
|
counter.incrementAndGet();
|
||||||
|
if ((current == 0) && (bottom != 0)) {
|
||||||
|
backedIterator.reset();
|
||||||
|
long cnt = current;
|
||||||
|
for (; cnt < bottom; ++cnt) {
|
||||||
|
if (backedIterator.hasNext())
|
||||||
|
backedIterator.next();
|
||||||
|
}
|
||||||
|
current = cnt+1;
|
||||||
|
}
|
||||||
|
else current++;
|
||||||
|
val p = backedIterator.next();
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,121 @@
|
||||||
|
package org.deeplearning4j.datasets.iterator;
|
||||||
|
|
||||||
|
import lombok.val;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
|
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||||
|
|
||||||
|
import javax.naming.OperationNotSupportedException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
|
||||||
|
public class ScrollableMultiDataSetIterator implements MultiDataSetIterator {
|
||||||
|
private int thisPart = 0;
|
||||||
|
private int top = 0;
|
||||||
|
private int bottom = 0;
|
||||||
|
protected MultiDataSetIterator backedIterator;
|
||||||
|
protected AtomicLong counter = new AtomicLong(0);
|
||||||
|
|
||||||
|
protected AtomicBoolean resetPending = new AtomicBoolean(false);
|
||||||
|
protected DataSet firstTrain = null;
|
||||||
|
protected MultiDataSet firstMultiTrain = null;
|
||||||
|
private double ratio;
|
||||||
|
private long totalExamples;
|
||||||
|
private long itemsPerPart;
|
||||||
|
private long current;
|
||||||
|
|
||||||
|
public ScrollableMultiDataSetIterator(int num, MultiDataSetIterator backedIterator, AtomicLong counter,
|
||||||
|
MultiDataSet firstTrain, int[] itemsPerPart) {
|
||||||
|
this.thisPart = num;
|
||||||
|
this.bottom = itemsPerPart[0];
|
||||||
|
this.top = bottom + itemsPerPart[1];
|
||||||
|
this.itemsPerPart = top;
|
||||||
|
|
||||||
|
this.counter = counter;
|
||||||
|
//this.resetPending = resetPending;
|
||||||
|
this.firstTrain = null;
|
||||||
|
this.firstMultiTrain = firstTrain;
|
||||||
|
//this.totalExamples = totalExamples;
|
||||||
|
this.current = 0;
|
||||||
|
this.backedIterator = backedIterator;
|
||||||
|
this.resetPending = resetPending;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean resetSupported() {
|
||||||
|
return backedIterator.resetSupported();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean asyncSupported() {
|
||||||
|
return backedIterator.asyncSupported();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
resetPending.set(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setPreProcessor(MultiDataSetPreProcessor dataSetPreProcessor) {
|
||||||
|
backedIterator.setPreProcessor(dataSetPreProcessor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MultiDataSetPreProcessor getPreProcessor() {
|
||||||
|
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean hasNext() {
|
||||||
|
if (resetPending.get()) {
|
||||||
|
if (resetSupported()) {
|
||||||
|
backedIterator.reset();
|
||||||
|
counter.set(0);
|
||||||
|
current = 0;
|
||||||
|
resetPending.set(false);
|
||||||
|
} else
|
||||||
|
throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean state = false;
|
||||||
|
if (current >= top)
|
||||||
|
return false;
|
||||||
|
state = backedIterator.hasNext();
|
||||||
|
if (!state)
|
||||||
|
return false;
|
||||||
|
if (state && counter.get() < itemsPerPart)
|
||||||
|
return true;
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MultiDataSet next() {
|
||||||
|
counter.incrementAndGet();
|
||||||
|
if ((current == 0) && (bottom != 0)) {
|
||||||
|
backedIterator.reset();
|
||||||
|
long cnt = current;
|
||||||
|
for (; cnt < bottom; ++cnt) {
|
||||||
|
if (backedIterator.hasNext())
|
||||||
|
backedIterator.next();
|
||||||
|
}
|
||||||
|
current = cnt+1;
|
||||||
|
}
|
||||||
|
else current++;
|
||||||
|
val p = backedIterator.next();
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MultiDataSet next(int i) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
}
|
|
@ -47,6 +47,8 @@ import static org.bytedeco.hdf5.global.hdf5.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class Hdf5Archive implements Closeable {
|
public class Hdf5Archive implements Closeable {
|
||||||
|
|
||||||
|
public static final int MAX_BUFFER_SIZE_BYTES = (int)Math.pow(2, 28); //256 MB
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* HDF5 library is not thread safe - possible to crash if multiple reads etc are performed concurrently
|
* HDF5 library is not thread safe - possible to crash if multiple reads etc are performed concurrently
|
||||||
* in multiple threads. This object is used for locking read etc activity using synchronized blocks
|
* in multiple threads. This object is used for locking read etc activity using synchronized blocks
|
||||||
|
@ -338,7 +340,7 @@ public class Hdf5Archive implements Closeable {
|
||||||
private String readAttributeAsJson(Attribute attribute) throws UnsupportedKerasConfigurationException {
|
private String readAttributeAsJson(Attribute attribute) throws UnsupportedKerasConfigurationException {
|
||||||
synchronized (Hdf5Archive.LOCK_OBJECT) {
|
synchronized (Hdf5Archive.LOCK_OBJECT) {
|
||||||
VarLenType vl = attribute.getVarLenType();
|
VarLenType vl = attribute.getVarLenType();
|
||||||
int bufferSizeMult = 1;
|
int currBufferLength = 2048;
|
||||||
String s;
|
String s;
|
||||||
/* TODO: find a less hacky way to do this.
|
/* TODO: find a less hacky way to do this.
|
||||||
* Reading variable length strings (from attributes) is a giant
|
* Reading variable length strings (from attributes) is a giant
|
||||||
|
@ -349,8 +351,8 @@ public class Hdf5Archive implements Closeable {
|
||||||
* buffer and repeat.
|
* buffer and repeat.
|
||||||
*/
|
*/
|
||||||
while (true) {
|
while (true) {
|
||||||
byte[] attrBuffer = new byte[bufferSizeMult * 2000];
|
byte[] attrBuffer = new byte[currBufferLength];
|
||||||
BytePointer attrPointer = new BytePointer(attrBuffer);
|
BytePointer attrPointer = new BytePointer(currBufferLength);
|
||||||
attribute.read(vl, attrPointer);
|
attribute.read(vl, attrPointer);
|
||||||
attrPointer.get(attrBuffer);
|
attrPointer.get(attrBuffer);
|
||||||
s = new String(attrBuffer);
|
s = new String(attrBuffer);
|
||||||
|
@ -362,9 +364,11 @@ public class Hdf5Archive implements Closeable {
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
//OK - we don't know how long the buffer needs to be, so we'll try again with larger buffer
|
//OK - we don't know how long the buffer needs to be, so we'll try again with larger buffer
|
||||||
}
|
}
|
||||||
bufferSizeMult *= 2;
|
|
||||||
if (bufferSizeMult > 1024) {
|
if(currBufferLength == MAX_BUFFER_SIZE_BYTES){
|
||||||
throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute");
|
throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute: size exceeds " + currBufferLength + " bytes");
|
||||||
|
} else {
|
||||||
|
currBufferLength = (int)Math.min(MAX_BUFFER_SIZE_BYTES, currBufferLength * 4L);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
vl.deallocate();
|
vl.deallocate();
|
||||||
|
|
|
@ -21,6 +21,7 @@ import lombok.NoArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
|
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
|
||||||
import org.deeplearning4j.clustering.cluster.Cluster;
|
import org.deeplearning4j.clustering.cluster.Cluster;
|
||||||
import org.deeplearning4j.clustering.cluster.ClusterSet;
|
import org.deeplearning4j.clustering.cluster.ClusterSet;
|
||||||
import org.deeplearning4j.clustering.cluster.ClusterUtils;
|
import org.deeplearning4j.clustering.cluster.ClusterUtils;
|
||||||
|
@ -62,12 +63,13 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
|
||||||
private ClusterSet clusterSet;
|
private ClusterSet clusterSet;
|
||||||
private List<Point> initialPoints;
|
private List<Point> initialPoints;
|
||||||
private transient ExecutorService exec;
|
private transient ExecutorService exec;
|
||||||
|
private boolean useKmeansPlusPlus;
|
||||||
|
|
||||||
|
|
||||||
|
protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
|
||||||
protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy) {
|
|
||||||
this.clusteringStrategy = clusteringStrategy;
|
this.clusteringStrategy = clusteringStrategy;
|
||||||
this.exec = MultiThreadUtils.newExecutorService();
|
this.exec = MultiThreadUtils.newExecutorService();
|
||||||
|
this.useKmeansPlusPlus = useKmeansPlusPlus;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -75,8 +77,8 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
|
||||||
* @param clusteringStrategy
|
* @param clusteringStrategy
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy) {
|
public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
|
||||||
return new BaseClusteringAlgorithm(clusteringStrategy);
|
return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -86,7 +88,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
|
||||||
*/
|
*/
|
||||||
public ClusterSet applyTo(List<Point> points) {
|
public ClusterSet applyTo(List<Point> points) {
|
||||||
resetState(points);
|
resetState(points);
|
||||||
initClusters();
|
initClusters(useKmeansPlusPlus);
|
||||||
iterations();
|
iterations();
|
||||||
return clusterSet;
|
return clusterSet;
|
||||||
}
|
}
|
||||||
|
@ -130,7 +132,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
|
||||||
* Initialize the
|
* Initialize the
|
||||||
* cluster centers at random
|
* cluster centers at random
|
||||||
*/
|
*/
|
||||||
protected void initClusters() {
|
protected void initClusters(boolean kMeansPlusPlus) {
|
||||||
log.info("Generating initial clusters");
|
log.info("Generating initial clusters");
|
||||||
List<Point> points = new ArrayList<>(initialPoints);
|
List<Point> points = new ArrayList<>(initialPoints);
|
||||||
|
|
||||||
|
@ -152,7 +154,10 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
|
||||||
//Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
|
//Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
|
||||||
while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
|
while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
|
||||||
dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
|
dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
|
||||||
double r = random.nextFloat() * dxs.maxNumber().doubleValue();
|
double summed = Nd4j.sum(dxs).getDouble(0);
|
||||||
|
double r = kMeansPlusPlus ? random.nextDouble() * summed:
|
||||||
|
random.nextFloat() * dxs.maxNumber().doubleValue();
|
||||||
|
|
||||||
for (int i = 0; i < dxs.length(); i++) {
|
for (int i = 0; i < dxs.length(); i++) {
|
||||||
double distance = dxs.getDouble(i);
|
double distance = dxs.getDouble(i);
|
||||||
Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
|
Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
|
||||||
|
@ -170,6 +175,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
|
||||||
new IterationInfo(currentIteration, initialClusterSetInfo));
|
new IterationInfo(currentIteration, initialClusterSetInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected void applyClusteringStrategy() {
|
protected void applyClusteringStrategy() {
|
||||||
if (!isStrategyApplicableNow())
|
if (!isStrategyApplicableNow())
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -79,8 +79,8 @@ public class ClusterUtils {
|
||||||
int nClusters = clusterSet.getClusterCount();
|
int nClusters = clusterSet.getClusterCount();
|
||||||
for (int i = 0; i < nClusters; i++) {
|
for (int i = 0; i < nClusters; i++) {
|
||||||
final Cluster cluster = clusterSet.getClusters().get(i);
|
final Cluster cluster = clusterSet.getClusters().get(i);
|
||||||
tasks.add(new Runnable() {
|
//tasks.add(new Runnable() {
|
||||||
public void run() {
|
// public void run() {
|
||||||
try {
|
try {
|
||||||
final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
|
final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
|
||||||
refreshClusterCenter(cluster, clusterInfo);
|
refreshClusterCenter(cluster, clusterInfo);
|
||||||
|
@ -88,10 +88,10 @@ public class ClusterUtils {
|
||||||
} catch (Throwable t) {
|
} catch (Throwable t) {
|
||||||
log.warn("Error refreshing cluster centers", t);
|
log.warn("Error refreshing cluster centers", t);
|
||||||
}
|
}
|
||||||
|
// }
|
||||||
|
//});
|
||||||
}
|
}
|
||||||
});
|
//MultiThreadUtils.parallelTasks(tasks, executorService);
|
||||||
}
|
|
||||||
MultiThreadUtils.parallelTasks(tasks, executorService);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
|
public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
|
||||||
|
@ -146,28 +146,29 @@ public class ClusterUtils {
|
||||||
List<Runnable> tasks = new ArrayList<>();
|
List<Runnable> tasks = new ArrayList<>();
|
||||||
for (int i = 0; i < pointsCount; i++) {
|
for (int i = 0; i < pointsCount; i++) {
|
||||||
final int i2 = i;
|
final int i2 = i;
|
||||||
tasks.add(new Runnable() {
|
//tasks.add(new Runnable() {
|
||||||
public void run() {
|
// public void run() {
|
||||||
try {
|
try {
|
||||||
Point point = points.get(i2);
|
Point point = points.get(i2);
|
||||||
double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
|
double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
|
||||||
: Math.pow(newCluster.getDistanceToCenter(point), 2);
|
: Math.pow(newCluster.getDistanceToCenter(point), 2);
|
||||||
dxs.putScalar(i2, clusterSet.isInverse() ? dist : dist);
|
dxs.putScalar(i2, /*clusterSet.isInverse() ? dist :*/ dist);
|
||||||
} catch (Throwable t) {
|
} catch (Throwable t) {
|
||||||
log.warn("Error computing squared distance from nearest cluster", t);
|
log.warn("Error computing squared distance from nearest cluster", t);
|
||||||
}
|
}
|
||||||
}
|
// }
|
||||||
});
|
//});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiThreadUtils.parallelTasks(tasks, executorService);
|
//MultiThreadUtils.parallelTasks(tasks, executorService);
|
||||||
|
|
||||||
for (int i = 0; i < pointsCount; i++) {
|
for (int i = 0; i < pointsCount; i++) {
|
||||||
double previousMinDistance = previousDxs.getDouble(i);
|
double previousMinDistance = previousDxs.getDouble(i);
|
||||||
if (clusterSet.isInverse()) {
|
if (clusterSet.isInverse()) {
|
||||||
if (dxs.getDouble(i) < previousMinDistance)
|
if (dxs.getDouble(i) < previousMinDistance) {
|
||||||
|
|
||||||
dxs.putScalar(i, previousMinDistance);
|
dxs.putScalar(i, previousMinDistance);
|
||||||
|
}
|
||||||
} else if (dxs.getDouble(i) > previousMinDistance)
|
} else if (dxs.getDouble(i) > previousMinDistance)
|
||||||
dxs.putScalar(i, previousMinDistance);
|
dxs.putScalar(i, previousMinDistance);
|
||||||
}
|
}
|
||||||
|
@ -175,6 +176,23 @@ public class ClusterUtils {
|
||||||
return dxs;
|
return dxs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet,
|
||||||
|
final List<Point> points, INDArray previousDxs) {
|
||||||
|
final int pointsCount = points.size();
|
||||||
|
final INDArray dxs = Nd4j.create(pointsCount);
|
||||||
|
final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
|
||||||
|
|
||||||
|
Double sum = new Double(0);
|
||||||
|
for (int i = 0; i < pointsCount; i++) {
|
||||||
|
|
||||||
|
Point point = points.get(i);
|
||||||
|
double dist = Math.pow(newCluster.getDistanceToCenter(point), 2);
|
||||||
|
sum += dist;
|
||||||
|
dxs.putScalar(i, sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
return dxs;
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param clusterSet
|
* @param clusterSet
|
||||||
|
@ -194,27 +212,27 @@ public class ClusterUtils {
|
||||||
List<Runnable> tasks = new ArrayList<>();
|
List<Runnable> tasks = new ArrayList<>();
|
||||||
for (int i = 0; i < clusterCount; i++) {
|
for (int i = 0; i < clusterCount; i++) {
|
||||||
final Cluster cluster = clusterSet.getClusters().get(i);
|
final Cluster cluster = clusterSet.getClusters().get(i);
|
||||||
tasks.add(new Runnable() {
|
//tasks.add(new Runnable() {
|
||||||
public void run() {
|
// public void run() {
|
||||||
try {
|
try {
|
||||||
info.getClustersInfos().put(cluster.getId(),
|
info.getClustersInfos().put(cluster.getId(),
|
||||||
computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
|
computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
|
||||||
} catch (Throwable t) {
|
} catch (Throwable t) {
|
||||||
log.warn("Error computing cluster set info", t);
|
log.warn("Error computing cluster set info", t);
|
||||||
}
|
}
|
||||||
}
|
//}
|
||||||
});
|
//});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
MultiThreadUtils.parallelTasks(tasks, executorService);
|
//MultiThreadUtils.parallelTasks(tasks, executorService);
|
||||||
|
|
||||||
tasks = new ArrayList<>();
|
//tasks = new ArrayList<>();
|
||||||
for (int i = 0; i < clusterCount; i++) {
|
for (int i = 0; i < clusterCount; i++) {
|
||||||
final int clusterIdx = i;
|
final int clusterIdx = i;
|
||||||
final Cluster fromCluster = clusterSet.getClusters().get(i);
|
final Cluster fromCluster = clusterSet.getClusters().get(i);
|
||||||
tasks.add(new Runnable() {
|
//tasks.add(new Runnable() {
|
||||||
public void run() {
|
//public void run() {
|
||||||
try {
|
try {
|
||||||
for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
|
for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
|
||||||
Cluster toCluster = clusterSet.getClusters().get(k);
|
Cluster toCluster = clusterSet.getClusters().get(k);
|
||||||
|
@ -230,12 +248,12 @@ public class ClusterUtils {
|
||||||
} catch (Throwable t) {
|
} catch (Throwable t) {
|
||||||
log.warn("Error computing distances", t);
|
log.warn("Error computing distances", t);
|
||||||
}
|
}
|
||||||
}
|
// }
|
||||||
});
|
//});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MultiThreadUtils.parallelTasks(tasks, executorService);
|
//MultiThreadUtils.parallelTasks(tasks, executorService);
|
||||||
|
|
||||||
return info;
|
return info;
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,8 +37,8 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
|
||||||
*
|
*
|
||||||
* @param clusteringStrategy
|
* @param clusteringStrategy
|
||||||
*/
|
*/
|
||||||
protected KMeansClustering(ClusteringStrategy clusteringStrategy) {
|
protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) {
|
||||||
super(clusteringStrategy);
|
super(clusteringStrategy, useKMeansPlusPlus);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -50,11 +50,11 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction,
|
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction,
|
||||||
boolean inverse) {
|
boolean inverse, boolean useKMeansPlusPlus) {
|
||||||
ClusteringStrategy clusteringStrategy =
|
ClusteringStrategy clusteringStrategy =
|
||||||
FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse);
|
FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse);
|
||||||
clusteringStrategy.endWhenIterationCountEquals(maxIterationCount);
|
clusteringStrategy.endWhenIterationCountEquals(maxIterationCount);
|
||||||
return new KMeansClustering(clusteringStrategy);
|
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -66,10 +66,10 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
|
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
|
||||||
boolean inverse, boolean allowEmptyClusters) {
|
boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
|
||||||
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse)
|
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse)
|
||||||
.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
|
.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
|
||||||
return new KMeansClustering(clusteringStrategy);
|
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,8 +81,8 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
|
||||||
* @param distanceFunction the distance function to use for grouping
|
* @param distanceFunction the distance function to use for grouping
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction) {
|
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) {
|
||||||
return setup(clusterCount, maxIterationCount, distanceFunction, false);
|
return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -94,17 +94,17 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
|
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
|
||||||
boolean allowEmptyClusters) {
|
boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
|
||||||
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
|
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
|
||||||
clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
|
clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
|
||||||
return new KMeansClustering(clusteringStrategy);
|
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
|
public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
|
||||||
boolean allowEmptyClusters) {
|
boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
|
||||||
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
|
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
|
||||||
clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
|
clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
|
||||||
return new KMeansClustering(clusteringStrategy);
|
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.clustering.kmeans;
|
package org.deeplearning4j.clustering.kmeans;
|
||||||
|
|
||||||
|
import lombok.val;
|
||||||
import org.apache.commons.lang3.time.StopWatch;
|
import org.apache.commons.lang3.time.StopWatch;
|
||||||
import org.deeplearning4j.clustering.BaseDL4JTest;
|
import org.deeplearning4j.clustering.BaseDL4JTest;
|
||||||
import org.deeplearning4j.clustering.algorithm.Distance;
|
import org.deeplearning4j.clustering.algorithm.Distance;
|
||||||
|
@ -28,36 +29,40 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.fail;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by agibsonccc on 7/2/17.
|
* Created by agibsonccc on 7/2/17.
|
||||||
*/
|
*/
|
||||||
public class KMeansTest extends BaseDL4JTest {
|
public class KMeansTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
private boolean[] useKMeansPlusPlus = {true, false};
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testKMeans() {
|
public void testKMeans() {
|
||||||
Nd4j.getRandom().setSeed(7);
|
Nd4j.getRandom().setSeed(7);
|
||||||
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN);
|
for (boolean mode : useKMeansPlusPlus) {
|
||||||
|
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN, mode);
|
||||||
List<Point> points = Point.toPoints(Nd4j.randn(5, 5));
|
List<Point> points = Point.toPoints(Nd4j.randn(5, 5));
|
||||||
ClusterSet clusterSet = kMeansClustering.applyTo(points);
|
ClusterSet clusterSet = kMeansClustering.applyTo(points);
|
||||||
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
|
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
|
||||||
System.out.println(pointClassification);
|
System.out.println(pointClassification);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testKmeansCosine() {
|
public void testKmeansCosine() {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(7);
|
Nd4j.getRandom().setSeed(7);
|
||||||
int numClusters = 5;
|
int numClusters = 5;
|
||||||
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, true);
|
for (boolean mode : useKMeansPlusPlus) {
|
||||||
|
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode);
|
||||||
List<Point> points = Point.toPoints(Nd4j.rand(5, 300));
|
List<Point> points = Point.toPoints(Nd4j.rand(5, 300));
|
||||||
ClusterSet clusterSet = kMeansClustering.applyTo(points);
|
ClusterSet clusterSet = kMeansClustering.applyTo(points);
|
||||||
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
|
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
|
||||||
|
|
||||||
|
|
||||||
KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN);
|
KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode);
|
||||||
ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points);
|
ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points);
|
||||||
PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0));
|
PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0));
|
||||||
System.out.println("Cosine " + pointClassification);
|
System.out.println("Cosine " + pointClassification);
|
||||||
|
@ -66,6 +71,7 @@ public class KMeansTest extends BaseDL4JTest {
|
||||||
assertEquals(pointClassification.getCluster().getPoints().get(0),
|
assertEquals(pointClassification.getCluster().getPoints().get(0),
|
||||||
pointClassificationEuclidean.getCluster().getPoints().get(0));
|
pointClassificationEuclidean.getCluster().getPoints().get(0));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Ignore
|
@Ignore
|
||||||
@Test
|
@Test
|
||||||
|
@ -73,10 +79,11 @@ public class KMeansTest extends BaseDL4JTest {
|
||||||
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
||||||
Nd4j.getRandom().setSeed(7);
|
Nd4j.getRandom().setSeed(7);
|
||||||
int numClusters = 20;
|
int numClusters = 20;
|
||||||
|
for (boolean mode : useKMeansPlusPlus) {
|
||||||
StopWatch watch = new StopWatch();
|
StopWatch watch = new StopWatch();
|
||||||
watch.start();
|
watch.start();
|
||||||
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, true);
|
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode);
|
||||||
List<Point> points = Point.toPoints(Nd4j.linspace(0, 5000*300, 5000*300).reshape(5000,300 ));
|
List<Point> points = Point.toPoints(Nd4j.linspace(0, 5000 * 300, 5000 * 300).reshape(5000, 300));
|
||||||
|
|
||||||
ClusterSet clusterSet = kMeansClustering.applyTo(points);
|
ClusterSet clusterSet = kMeansClustering.applyTo(points);
|
||||||
watch.stop();
|
watch.stop();
|
||||||
|
@ -90,6 +97,7 @@ public class KMeansTest extends BaseDL4JTest {
|
||||||
watch.stop();
|
watch.stop();
|
||||||
System.out.println("Elapsed for search: " + watch);
|
System.out.println("Elapsed for search: " + watch);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
|
@ -97,11 +105,12 @@ public class KMeansTest extends BaseDL4JTest {
|
||||||
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
||||||
Nd4j.getRandom().setSeed(7);
|
Nd4j.getRandom().setSeed(7);
|
||||||
int numClusters = 20;
|
int numClusters = 20;
|
||||||
|
for (boolean mode : useKMeansPlusPlus) {
|
||||||
StopWatch watch = new StopWatch();
|
StopWatch watch = new StopWatch();
|
||||||
watch.start();
|
watch.start();
|
||||||
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false);
|
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false, mode);
|
||||||
|
|
||||||
List<Point> points = Point.toPoints(Nd4j.linspace(0, 10000*300, 10000*300).reshape(10000,300 ));
|
List<Point> points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
|
||||||
|
|
||||||
ClusterSet clusterSet = kMeansClustering.applyTo(points);
|
ClusterSet clusterSet = kMeansClustering.applyTo(points);
|
||||||
watch.stop();
|
watch.stop();
|
||||||
|
@ -117,9 +126,9 @@ public class KMeansTest extends BaseDL4JTest {
|
||||||
|
|
||||||
watch.reset();
|
watch.reset();
|
||||||
watch.start();
|
watch.start();
|
||||||
kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false);
|
kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false, mode);
|
||||||
|
|
||||||
points = Point.toPoints(Nd4j.linspace(0, 10000*300, 10000*300).reshape(10000,300 ));
|
points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
|
||||||
|
|
||||||
clusterSet = kMeansClustering.applyTo(points);
|
clusterSet = kMeansClustering.applyTo(points);
|
||||||
watch.stop();
|
watch.stop();
|
||||||
|
@ -133,6 +142,7 @@ public class KMeansTest extends BaseDL4JTest {
|
||||||
watch.stop();
|
watch.stop();
|
||||||
System.out.println("Elapsed for search: " + watch);
|
System.out.println("Elapsed for search: " + watch);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCorrectness() {
|
public void testCorrectness() {
|
||||||
|
@ -141,7 +151,8 @@ public class KMeansTest extends BaseDL4JTest {
|
||||||
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
||||||
Nd4j.getRandom().setSeed(7);
|
Nd4j.getRandom().setSeed(7);
|
||||||
int numClusters = 3;
|
int numClusters = 3;
|
||||||
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, true);
|
for (boolean mode : useKMeansPlusPlus) {
|
||||||
|
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode);
|
||||||
double[] data = new double[]{
|
double[] data = new double[]{
|
||||||
15, 16,
|
15, 16,
|
||||||
16, 18.5,
|
16, 18.5,
|
||||||
|
@ -181,6 +192,7 @@ public class KMeansTest extends BaseDL4JTest {
|
||||||
for (int i = 0; i < clusters.size(); ++i)
|
for (int i = 0; i < clusters.size(); ++i)
|
||||||
System.out.println("Choice: " + clusters.get(i).getCenter().getArray());
|
System.out.println("Choice: " + clusters.get(i).getCenter().getArray());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
/*assertEquals(Nd4j.createFromArray(new double[]{75.9348, 74.1990}),
|
/*assertEquals(Nd4j.createFromArray(new double[]{75.9348, 74.1990}),
|
||||||
pointClassification.getCluster().getCenter().getArray());*/
|
pointClassification.getCluster().getCenter().getArray());*/
|
||||||
|
|
||||||
|
@ -233,4 +245,39 @@ public class KMeansTest extends BaseDL4JTest {
|
||||||
System.out.println();
|
System.out.println();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInitClusters() {
|
||||||
|
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
|
||||||
|
Nd4j.getRandom().setSeed(7);
|
||||||
|
{
|
||||||
|
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 1, Distance.EUCLIDEAN, true);
|
||||||
|
|
||||||
|
double[][] dataArray = {{1000000.0, 2.8E7, 5.5E7, 8.2E7}, {2.8E7, 5.5E7, 8.2E7, 1.09E8}, {5.5E7, 8.2E7, 1.09E8, 1.36E8},
|
||||||
|
{8.2E7, 1.09E8, 1.36E8, 1.63E8}, {1.09E8, 1.36E8, 1.63E8, 1.9E8}, {1.36E8, 1.63E8, 1.9E8, 2.17E8},
|
||||||
|
{1.63E8, 1.9E8, 2.17E8, 2.44E8}, {1.9E8, 2.17E8, 2.44E8, 2.71E8}, {2.17E8, 2.44E8, 2.71E8, 2.98E8},
|
||||||
|
{2.44E8, 2.71E8, 2.98E8, 3.25E8}, {2.71E8, 2.98E8, 3.25E8, 3.52E8}, {2.98E8, 3.25E8, 3.52E8, 3.79E8},
|
||||||
|
{3.25E8, 3.52E8, 3.79E8, 4.06E8}, {3.52E8, 3.79E8, 4.06E8, 4.33E8}, {3.79E8, 4.06E8, 4.33E8, 4.6E8},
|
||||||
|
{4.06E8, 4.33E8, 4.6E8, 4.87E8}, {4.33E8, 4.6E8, 4.87E8, 5.14E8}, {4.6E8, 4.87E8, 5.14E8, 5.41E8},
|
||||||
|
{4.87E8, 5.14E8, 5.41E8, 5.68E8}, {5.14E8, 5.41E8, 5.68E8, 5.95E8}, {5.41E8, 5.68E8, 5.95E8, 6.22E8},
|
||||||
|
{5.68E8, 5.95E8, 6.22E8, 6.49E8}, {5.95E8, 6.22E8, 6.49E8, 6.76E8}, {6.22E8, 6.49E8, 6.76E8, 7.03E8},
|
||||||
|
{6.49E8, 6.76E8, 7.03E8, 7.3E8}, {6.76E8, 7.03E8, 7.3E8, 7.57E8}, {7.03E8, 7.3E8, 7.57E8, 7.84E8}};
|
||||||
|
INDArray data = Nd4j.createFromArray(dataArray);
|
||||||
|
List<Point> points = Point.toPoints(data);
|
||||||
|
|
||||||
|
ClusterSet clusterSet = kMeansClustering.applyTo(points);
|
||||||
|
|
||||||
|
double[] centroid1 = {2.44e8, 2.71e8, 2.98e8, 3.25e8};
|
||||||
|
double[] centroid2 = {5.14e8, 5.41e8, 5.68e8, 5.95e8};
|
||||||
|
double[] centroid3 = {1.63e8, 1.9e8, 2.17e8, 2.44e8};
|
||||||
|
double[] centroid4 = {6.76e8, 7.03e8, 7.3e8, 7.57e8};
|
||||||
|
double[] centroid5 = {4.06e8, 4.33e8, 4.6e8, 4.87e8};
|
||||||
|
|
||||||
|
assertArrayEquals(centroid1, clusterSet.getClusters().get(0).getCenter().getArray().toDoubleVector(), 1e-4);
|
||||||
|
assertArrayEquals(centroid2, clusterSet.getClusters().get(1).getCenter().getArray().toDoubleVector(), 1e-4);
|
||||||
|
assertArrayEquals(centroid3, clusterSet.getClusters().get(2).getCenter().getArray().toDoubleVector(), 1e-4);
|
||||||
|
assertArrayEquals(centroid4, clusterSet.getClusters().get(3).getCenter().getArray().toDoubleVector(), 1e-4);
|
||||||
|
assertArrayEquals(centroid5, clusterSet.getClusters().get(4).getCenter().getArray().toDoubleVector(), 1e-4);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,8 @@ import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.lang.ArrayUtils;
|
import org.apache.commons.lang.ArrayUtils;
|
||||||
import org.apache.commons.lang3.RandomUtils;
|
import org.apache.commons.lang3.RandomUtils;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||||
|
import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
import org.nd4j.linalg.io.ClassPathResource;
|
||||||
|
@ -857,4 +859,34 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBackwardsCompatibleWord2Vec() {
|
||||||
|
File model_v3 = Resources.asFile("deeplearning4j-nlp/model_beta3.zip");
|
||||||
|
File model_v4 = Resources.asFile("deeplearning4j-nlp/model_beta4.zip");
|
||||||
|
Word2Vec word2Vec1 = WordVectorSerializer.readWord2VecModel(model_v3, true);
|
||||||
|
Word2Vec word2Vec2 = WordVectorSerializer.readWord2VecModel(model_v4, true);
|
||||||
|
try {
|
||||||
|
assertEquals(word2Vec1.toJson(), word2Vec2.toJson());
|
||||||
|
} catch (Exception e) {
|
||||||
|
fail(e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBackwardsCompatibleSequenceVectors() {
|
||||||
|
File model_v3 = Resources.asFile("deeplearning4j-nlp/seqv_beta3.csv");
|
||||||
|
File model_v4 = Resources.asFile("deeplearning4j-nlp/seqv_beta4.csv");
|
||||||
|
try {
|
||||||
|
SequenceVectors vectors1 = WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), model_v3);
|
||||||
|
SequenceVectors vectors2 = WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), model_v4);
|
||||||
|
|
||||||
|
assertEquals(vectors1.vocab().numWords(), vectors2.vocab().numWords());
|
||||||
|
for (int i = 0; i < vectors1.vocab().numWords(); ++i) {
|
||||||
|
assertEquals(vectors1.vocab().words().toArray()[i], vectors2.vocab().words().toArray()[i]);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
fail(e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -249,7 +249,7 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
}
|
}
|
||||||
l[0] = Nd4j.create(Nd4j.defaultFloatingPointType(), mbPadded, numClasses);
|
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
|
||||||
for( int i=0; i<mb; i++ ){
|
for( int i=0; i<mb; i++ ){
|
||||||
l[0].putScalar(i, classLabels[i], 1.0);
|
l[0].putScalar(i, classLabels[i], 1.0);
|
||||||
}
|
}
|
||||||
|
@ -277,9 +277,9 @@ public class BertIterator implements MultiDataSetIterator {
|
||||||
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
|
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
|
||||||
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
|
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
|
||||||
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
|
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
|
||||||
labelArr = Nd4j.create(Nd4j.defaultFloatingPointType(), mbPadded, vocabSize, outLength);
|
labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
|
||||||
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
|
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
|
||||||
labelArr = Nd4j.create(Nd4j.defaultFloatingPointType(), outLength, mbPadded, vocabSize);
|
labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
|
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
|
||||||
}
|
}
|
||||||
|
|
|
@ -201,7 +201,7 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
|
||||||
List<String> tokens = new ArrayList<>();
|
List<String> tokens = new ArrayList<>();
|
||||||
while (t.hasMoreTokens()) {
|
while (t.hasMoreTokens()) {
|
||||||
String token = t.nextToken();
|
String token = t.nextToken();
|
||||||
if (!wordVectors.hasWord(token)) {
|
if (!wordVectors.outOfVocabularySupported() && !wordVectors.hasWord(token)) {
|
||||||
switch (unknownWordHandling) {
|
switch (unknownWordHandling) {
|
||||||
case RemoveWord:
|
case RemoveWord:
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -1312,10 +1312,12 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
|
||||||
int rest = batchSequences.size() % batchSize;
|
int rest = batchSequences.size() % batchSize;
|
||||||
int chunks = ((batchSequences.size() >= batchSize) ? batchSequences.size() / batchSize : 0) + ((rest > 0)? 1 : 0);
|
int chunks = ((batchSequences.size() >= batchSize) ? batchSequences.size() / batchSize : 0) + ((rest > 0)? 1 : 0);
|
||||||
for (int j = 0; j < chunks; ++j) {
|
for (int j = 0; j < chunks; ++j) {
|
||||||
|
if (trainElementsVectors) {
|
||||||
if (elementsLearningAlgorithm instanceof SkipGram)
|
if (elementsLearningAlgorithm instanceof SkipGram)
|
||||||
((SkipGram)elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
|
((SkipGram) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
|
||||||
else if (elementsLearningAlgorithm instanceof CBOW)
|
else if (elementsLearningAlgorithm instanceof CBOW)
|
||||||
((CBOW)elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
|
((CBOW) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
|
||||||
|
}
|
||||||
|
|
||||||
if (trainSequenceVectors) {
|
if (trainSequenceVectors) {
|
||||||
if (sequenceLearningAlgorithm instanceof DBOW)
|
if (sequenceLearningAlgorithm instanceof DBOW)
|
||||||
|
|
|
@ -32,7 +32,7 @@ import java.io.Serializable;
|
||||||
*
|
*
|
||||||
* @author Adam Gibson
|
* @author Adam Gibson
|
||||||
*/
|
*/
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", defaultImpl = VocabWord.class)
|
||||||
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
|
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
|
||||||
setterVisibility = JsonAutoDetect.Visibility.NONE)
|
setterVisibility = JsonAutoDetect.Visibility.NONE)
|
||||||
public class VocabWord extends SequenceElement implements Serializable {
|
public class VocabWord extends SequenceElement implements Serializable {
|
||||||
|
|
|
@ -224,6 +224,7 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test(timeout = 20000L)
|
@Test(timeout = 20000L)
|
||||||
public void testMinibatchPadding() throws Exception {
|
public void testMinibatchPadding() throws Exception {
|
||||||
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
String toTokenize1 = "I saw a girl with a telescope.";
|
String toTokenize1 = "I saw a girl with a telescope.";
|
||||||
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.nn.api;
|
package org.deeplearning4j.nn.api;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
|
||||||
|
@ -73,4 +74,6 @@ public interface TrainingConfig {
|
||||||
*/
|
*/
|
||||||
double getGradientNormalizationThreshold();
|
double getGradientNormalizationThreshold();
|
||||||
|
|
||||||
|
void setDataType(DataType dataType);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,4 +93,9 @@ public abstract class GraphVertex implements Cloneable, Serializable {
|
||||||
*/
|
*/
|
||||||
public abstract MemoryReport getMemoryReport(InputType... inputTypes);
|
public abstract MemoryReport getMemoryReport(InputType... inputTypes);
|
||||||
|
|
||||||
|
|
||||||
|
public void setDataType(DataType dataType) {
|
||||||
|
//No-op for most layers
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -146,4 +146,9 @@ public class LayerVertex extends GraphVertex {
|
||||||
//TODO preprocessor memory
|
//TODO preprocessor memory
|
||||||
return layerConf.getLayer().getMemoryReport(it);
|
return layerConf.getLayer().getMemoryReport(it);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setDataType(DataType dataType){
|
||||||
|
layerConf.getLayer().setDataType(dataType);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -223,6 +223,11 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable {
|
||||||
"Not supported: all layers with parameters should override this method");
|
"Not supported: all layers with parameters should override this method");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setDataType(DataType dataType) {
|
||||||
|
//No-op for most layers
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is a report of the estimated memory consumption for the given layer
|
* This is a report of the estimated memory consumption for the given layer
|
||||||
*
|
*
|
||||||
|
|
|
@ -96,7 +96,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
|
||||||
|
|
||||||
if (!map.containsKey(inputNum)) {
|
if (!map.containsKey(inputNum)) {
|
||||||
//Lazily define extra input variable as required
|
//Lazily define extra input variable as required
|
||||||
SDVariable var = sameDiff.var("var_" + inputNum, 1); //TODO is this shape safe?
|
SDVariable var = sameDiff.var("var_" + inputNum, dataType, -1); //TODO is this shape safe?
|
||||||
map.put(inputNum, var);
|
map.put(inputNum, var);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,7 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
|
||||||
protected IUpdater biasUpdater;
|
protected IUpdater biasUpdater;
|
||||||
protected GradientNormalization gradientNormalization;
|
protected GradientNormalization gradientNormalization;
|
||||||
protected double gradientNormalizationThreshold = Double.NaN;
|
protected double gradientNormalizationThreshold = Double.NaN;
|
||||||
|
protected DataType dataType;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Define the vertex
|
* Define the vertex
|
||||||
|
@ -234,4 +235,9 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
|
||||||
public double getGradientNormalizationThreshold() {
|
public double getGradientNormalizationThreshold() {
|
||||||
return gradientNormalizationThreshold;
|
return gradientNormalizationThreshold;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setDataType(DataType dataType) {
|
||||||
|
this.dataType = dataType;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.misc;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import org.deeplearning4j.nn.api.TrainingConfig;
|
import org.deeplearning4j.nn.api.TrainingConfig;
|
||||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.config.NoOp;
|
import org.nd4j.linalg.learning.config.NoOp;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
@ -63,4 +64,9 @@ public class DummyConfig implements TrainingConfig {
|
||||||
public double getGradientNormalizationThreshold() {
|
public double getGradientNormalizationThreshold() {
|
||||||
return 1.0;
|
return 1.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setDataType(DataType dataType) {
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -512,6 +512,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
for(; i<topologicalOrder.length; i++ ){
|
for(; i<topologicalOrder.length; i++ ){
|
||||||
String name = indices.getIdxToName().get(i);
|
String name = indices.getIdxToName().get(i);
|
||||||
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
|
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
|
||||||
|
n.setDataType(netDtype);
|
||||||
numParamsForVertex[i] = n.numParams(true);
|
numParamsForVertex[i] = n.numParams(true);
|
||||||
numParams += numParamsForVertex[i];
|
numParams += numParamsForVertex[i];
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.layers.BaseLayer;
|
import org.deeplearning4j.nn.layers.BaseLayer;
|
||||||
import org.deeplearning4j.util.TimeSeriesUtils;
|
import org.deeplearning4j.util.TimeSeriesUtils;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
@ -35,6 +36,7 @@ import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -60,10 +62,16 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
|
||||||
assertInputSet(true);
|
assertInputSet(true);
|
||||||
if (input.rank() != 3)
|
if (input.rank() != 3)
|
||||||
throw new UnsupportedOperationException(
|
throw new UnsupportedOperationException(
|
||||||
"Input is not rank 3. Got input with rank " + input.rank() + " " + layerId());
|
"Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " +
|
||||||
|
input.rank() + " with shape " + Arrays.toString(input.shape()) + " for layer " + layerId());
|
||||||
if (labels == null)
|
if (labels == null)
|
||||||
throw new IllegalStateException("Labels are not set (null)");
|
throw new IllegalStateException("Labels are not set (null)");
|
||||||
|
|
||||||
|
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
|
||||||
|
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
|
||||||
|
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);
|
||||||
|
|
||||||
|
|
||||||
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
|
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
|
||||||
INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.BP_WORKING_MEM);
|
INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.BP_WORKING_MEM);
|
||||||
INDArray maskReshaped;
|
INDArray maskReshaped;
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.layers.BaseOutputLayer;
|
import org.deeplearning4j.nn.layers.BaseOutputLayer;
|
||||||
import org.deeplearning4j.nn.params.DefaultParamInitializer;
|
import org.deeplearning4j.nn.params.DefaultParamInitializer;
|
||||||
import org.deeplearning4j.util.TimeSeriesUtils;
|
import org.deeplearning4j.util.TimeSeriesUtils;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -57,8 +58,13 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
|
||||||
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
|
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
|
||||||
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
|
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
|
||||||
}
|
}
|
||||||
|
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
|
||||||
|
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
|
||||||
|
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);
|
||||||
|
|
||||||
INDArray inputTemp = input;
|
INDArray inputTemp = input;
|
||||||
this.input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
|
this.input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
|
||||||
|
|
||||||
Pair<Gradient, INDArray> gradAndEpsilonNext = super.backpropGradient(epsilon, workspaceMgr); //Also applies dropout
|
Pair<Gradient, INDArray> gradAndEpsilonNext = super.backpropGradient(epsilon, workspaceMgr); //Also applies dropout
|
||||||
this.input = inputTemp;
|
this.input = inputTemp;
|
||||||
INDArray epsilon2d = gradAndEpsilonNext.getSecond();
|
INDArray epsilon2d = gradAndEpsilonNext.getSecond();
|
||||||
|
|
|
@ -39,9 +39,7 @@ import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.*;
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Implementation of a SameDiff graph vertex.
|
* Implementation of a SameDiff graph vertex.
|
||||||
|
@ -96,12 +94,11 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
|
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
|
||||||
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
if(sameDiff == null){
|
if(sameDiff == null){
|
||||||
doInit();
|
doInit();
|
||||||
}
|
}
|
||||||
|
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
|
||||||
// sameDiff.clearExecutionCache();
|
|
||||||
config.validateInput(inputs);
|
config.validateInput(inputs);
|
||||||
for(int i=0; i<inputs.length; i++ ){
|
for(int i=0; i<inputs.length; i++ ){
|
||||||
String name = config.getVertexParams().getInputs().get(i);
|
String name = config.getVertexParams().getInputs().get(i);
|
||||||
|
@ -121,6 +118,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
}
|
}
|
||||||
Map<String,INDArray> out = sameDiff.exec(null, outputKey);
|
Map<String,INDArray> out = sameDiff.exec(null, outputKey);
|
||||||
INDArray result = out.get(outputKey);
|
INDArray result = out.get(outputKey);
|
||||||
|
|
||||||
|
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||||
|
sameDiff.clearPlaceholders(true);
|
||||||
|
sameDiff.clearOpInputs();
|
||||||
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
|
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -131,27 +132,42 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
|
|
||||||
INDArray[] dLdIns;
|
INDArray[] dLdIns;
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
||||||
// sameDiff.clearExecutionCache();
|
if(sameDiff == null){
|
||||||
|
doInit();
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!sameDiff.hasGradientFunction()) {
|
||||||
|
//Create when scoped out, to ensure any arrays are not in WS
|
||||||
|
List<String> inputs = config.getVertexParams().getInputs();
|
||||||
|
String[] inArr = inputs.toArray(new String[inputs.size()]);
|
||||||
|
sameDiff.createGradFunction(inArr);
|
||||||
|
}
|
||||||
config.validateInput(inputs);
|
config.validateInput(inputs);
|
||||||
//Set inputs
|
Map<String,INDArray> phMap = new HashMap<>();
|
||||||
for(int i=0; i<inputs.length; i++ ){
|
List<String> inputs = config.getVertexParams().getInputs();
|
||||||
String name = config.getVertexParams().getInputs().get(i);
|
int i=0;
|
||||||
|
for(String s : inputs){
|
||||||
|
phMap.put(s, this.inputs[i++]);
|
||||||
|
}
|
||||||
|
if(maskArrays != null){
|
||||||
|
for( int j=0; j<maskArrays.length; j++ ){
|
||||||
|
String name = inputs.get(j);
|
||||||
final String maskName = name + "_mask";
|
final String maskName = name + "_mask";
|
||||||
sameDiff.associateArrayWithVariable(inputs[i].dup(), sameDiff.getVariable(name));
|
if(maskArrays[j] != null) {
|
||||||
if(maskArrays != null && maskArrays[i] != null) {
|
sameDiff.associateArrayWithVariable(maskArrays[j].dup(), maskName);
|
||||||
sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName);
|
|
||||||
}else{
|
|
||||||
sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn.updateVariable(outputVar.getVarName(), epsilon.dup());
|
}
|
||||||
|
String epsName = fn.getGradPlaceholderName();
|
||||||
|
phMap.put(epsName, epsilon);
|
||||||
|
|
||||||
|
|
||||||
for(String s : paramTable.keySet() ){
|
for(String s : paramTable.keySet() ){
|
||||||
//TODO this should only be necessary, in theory, once!
|
//TODO this should only be necessary, in theory, once!
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
sameDiff.execBackwards(null);
|
sameDiff.execBackwards(phMap);
|
||||||
for(String s : paramTable.keySet() ){
|
for(String s : paramTable.keySet() ){
|
||||||
INDArray sdGrad = sameDiff.grad(s).getArr();
|
INDArray sdGrad = sameDiff.grad(s).getArr();
|
||||||
INDArray dl4jGrad = gradTable.get(s);
|
INDArray dl4jGrad = gradTable.get(s);
|
||||||
|
@ -159,10 +175,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
g.gradientForVariable().put(s, dl4jGrad);
|
g.gradientForVariable().put(s, dl4jGrad);
|
||||||
}
|
}
|
||||||
|
|
||||||
dLdIns = new INDArray[inputs.length];
|
dLdIns = new INDArray[inputs.size()];
|
||||||
for(int i=0; i<inputs.length; i++ ){
|
for(int j=0; j<inputs.size(); j++ ){
|
||||||
String name = config.getVertexParams().getInputs().get(i);
|
String name = inputs.get(j);
|
||||||
dLdIns[i] = sameDiff.grad(name).getArr();
|
dLdIns[j] = sameDiff.grad(name).getArr();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,6 +187,9 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
|
dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||||
|
sameDiff.clearPlaceholders(true);
|
||||||
|
sameDiff.clearOpInputs();
|
||||||
return new Pair<>(g, dLdIns);
|
return new Pair<>(g, dLdIns);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
@ -78,25 +79,32 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
||||||
@Override
|
@Override
|
||||||
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
|
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
|
||||||
assertInputSet(false);
|
assertInputSet(false);
|
||||||
|
|
||||||
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
if(sameDiff == null){
|
if(sameDiff == null){
|
||||||
doInit();
|
doInit();
|
||||||
}
|
}
|
||||||
|
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
|
||||||
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
|
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
|
||||||
bl.validateInput(input);
|
bl.validateInput(input);
|
||||||
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
|
|
||||||
|
Map<String,INDArray> phMap = new HashMap<>();
|
||||||
|
phMap.put(INPUT_KEY, input);
|
||||||
if(maskArray != null){
|
if(maskArray != null){
|
||||||
sameDiff.associateArrayWithVariable(maskArray, sameDiff.getVariable(MASK_KEY));
|
phMap.put(MASK_KEY, maskArray);
|
||||||
}else{
|
|
||||||
sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for(String s : paramTable.keySet() ) {
|
for(String s : paramTable.keySet() ) {
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
Map<String,INDArray> out = sameDiff.exec(null, outputKey);
|
Map<String,INDArray> out = sameDiff.exec(phMap, outputKey);
|
||||||
INDArray result = out.get(outputKey);
|
INDArray result = out.get(outputKey);
|
||||||
|
|
||||||
|
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||||
|
sameDiff.clearPlaceholders(true);
|
||||||
|
sameDiff.clearOpInputs();
|
||||||
|
|
||||||
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
|
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -110,24 +118,36 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
||||||
|
|
||||||
INDArray dLdIn;
|
INDArray dLdIn;
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
||||||
// sameDiff.clearExecutionCache();
|
if(sameDiff == null){
|
||||||
|
doInit();
|
||||||
|
}
|
||||||
|
if(!sameDiff.hasGradientFunction()) {
|
||||||
|
//Create when scoped out, to ensure any arrays are not in WS
|
||||||
|
sameDiff.createGradFunction(INPUT_KEY);
|
||||||
|
}
|
||||||
|
|
||||||
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
|
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
|
||||||
bl.validateInput(input);
|
bl.validateInput(input);
|
||||||
|
|
||||||
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
|
|
||||||
if(maskArray != null){
|
|
||||||
sameDiff.associateArrayWithVariable(maskArray, sameDiff.getVariable(MASK_KEY));
|
|
||||||
}else{
|
|
||||||
sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY));
|
|
||||||
}
|
|
||||||
fn.updateVariable(outputVar.getVarName(), epsilon.dup());
|
|
||||||
|
|
||||||
for(String s : paramTable.keySet() ){
|
for(String s : paramTable.keySet() ){
|
||||||
//TODO this should only be necessary, in theory, once!
|
//TODO this should only be necessary, in theory, once!
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
sameDiff.execBackwards(Collections.<String, INDArray>emptyMap());
|
Map<String,INDArray> phMap = new HashMap<>();
|
||||||
|
phMap.put(INPUT_KEY, input);
|
||||||
|
phMap.put(fn.getGradPlaceholderName(), epsilon);
|
||||||
|
if(maskArray != null){
|
||||||
|
phMap.put(MASK_KEY, maskArray);
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
|
||||||
|
requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName());
|
||||||
|
for(String s : paramTable.keySet()){
|
||||||
|
requiredGrads.add(sameDiff.grad(s).getVarName());
|
||||||
|
}
|
||||||
|
|
||||||
|
sameDiff.execBackwards(phMap, requiredGrads);
|
||||||
for(String s : paramTable.keySet() ){
|
for(String s : paramTable.keySet() ){
|
||||||
INDArray sdGrad = sameDiff.grad(s).getArr();
|
INDArray sdGrad = sameDiff.grad(s).getArr();
|
||||||
INDArray dl4jGrad = gradTable.get(s);
|
INDArray dl4jGrad = gradTable.get(s);
|
||||||
|
@ -138,6 +158,11 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
||||||
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
|
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||||
|
sameDiff.clearPlaceholders(true);
|
||||||
|
sameDiff.clearOpInputs();
|
||||||
|
|
||||||
|
System.out.println(dLdIn);
|
||||||
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
|
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,8 +250,9 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
||||||
sameDiff = SameDiff.create();
|
sameDiff = SameDiff.create();
|
||||||
Map<String, INDArray> p = paramTable();
|
Map<String, INDArray> p = paramTable();
|
||||||
|
|
||||||
val inputShape = input.shape().clone();
|
long[] inputShape = input.shape().clone();
|
||||||
SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape);
|
inputShape[0] = -1;
|
||||||
|
SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape);
|
||||||
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
|
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
|
||||||
Map<String, SDVariable> params = new LinkedHashMap<>();
|
Map<String, SDVariable> params = new LinkedHashMap<>();
|
||||||
for (String s : paramShapes.keySet()) {
|
for (String s : paramShapes.keySet()) {
|
||||||
|
@ -235,7 +261,8 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
||||||
params.put(s, v);
|
params.put(s, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
SDVariable mask = sameDiff.constant(MASK_KEY, SameDiffGraphVertex.createMask(dataType, inputShape));
|
long[] maskShape = ArrayUtil.nTimes((long)inputShape.length, -1);
|
||||||
|
SDVariable mask = sameDiff.placeHolder(MASK_KEY, dataType, maskShape);
|
||||||
|
|
||||||
SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, params, mask);
|
SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, params, mask);
|
||||||
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
|
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
|
||||||
|
|
|
@ -87,35 +87,43 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
||||||
private INDArray activateHelper(boolean activations, LayerWorkspaceMgr workspaceMgr){
|
private INDArray activateHelper(boolean activations, LayerWorkspaceMgr workspaceMgr){
|
||||||
assertInputSet(false);
|
assertInputSet(false);
|
||||||
|
|
||||||
//Check where the output occors. If it's a simple loss layer (no params) this could
|
//Check where the output occurs. If it's a simple loss layer (no params) this could
|
||||||
// just be the input!
|
// just be the input!
|
||||||
if(activations && INPUT_KEY.equals(layerConf().activationsVertexName())){
|
if(activations && INPUT_KEY.equals(layerConf().activationsVertexName())){
|
||||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
|
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//TODO optimize
|
||||||
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
if(sameDiff == null){
|
if(sameDiff == null){
|
||||||
doInit();
|
doInit();
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO optimize
|
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
|
||||||
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
|
|
||||||
if(layerConf().labelsRequired() && labels != null) {
|
|
||||||
sameDiff.associateArrayWithVariable(labels.dup(), sameDiff.getVariable(LABELS_KEY));
|
|
||||||
}
|
|
||||||
for(String s : paramTable.keySet() ) {
|
for(String s : paramTable.keySet() ) {
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray score = sameDiff.execAndEndResult();
|
Map<String,INDArray> phMap = new HashMap<>();
|
||||||
|
phMap.put(INPUT_KEY, input);
|
||||||
|
if(!activations && layerConf().labelsRequired() && labels != null) {
|
||||||
|
phMap.put(LABELS_KEY, labels);
|
||||||
|
}
|
||||||
|
|
||||||
|
String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
|
||||||
|
|
||||||
|
INDArray out = sameDiff.execSingle(phMap, s);
|
||||||
|
|
||||||
|
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||||
|
sameDiff.clearPlaceholders(true);
|
||||||
|
sameDiff.clearOpInputs();
|
||||||
|
|
||||||
if(activations) {
|
if(activations) {
|
||||||
INDArray result = sameDiff.getArrForVarName(layerConf().activationsVertexName());
|
Preconditions.checkNotNull(out, "Activations (result) array for variable \"%s\" was " +
|
||||||
Preconditions.checkNotNull(result, "Activations (result) array for variable \"%s\" was " +
|
|
||||||
"null - error during execution or this variable (as defined by method activationsVertexName()) " +
|
"null - error during execution or this variable (as defined by method activationsVertexName()) " +
|
||||||
"does not exist", layerConf().activationsVertexName());
|
"does not exist", layerConf().activationsVertexName());
|
||||||
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
|
return workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
|
||||||
} else {
|
} else {
|
||||||
return score;
|
return out;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -127,23 +135,26 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
||||||
Preconditions.checkState(!layerConf().labelsRequired() || labels != null, "Cannot execute backprop: Labels are not set. " +
|
Preconditions.checkState(!layerConf().labelsRequired() || labels != null, "Cannot execute backprop: Labels are not set. " +
|
||||||
"If labels are not required for this SameDiff output layer, override SameDiffOutputLayer.labelsRequired()" +
|
"If labels are not required for this SameDiff output layer, override SameDiffOutputLayer.labelsRequired()" +
|
||||||
" to return false instead");
|
" to return false instead");
|
||||||
|
Gradient g = new DefaultGradient();
|
||||||
|
|
||||||
|
INDArray dLdIn;
|
||||||
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
||||||
if(sameDiff == null){
|
if(sameDiff == null){
|
||||||
//Usually doInit will be called in forward pass; not necessarily the case in output layers
|
//Usually doInit will be called in forward pass; not necessarily the case in output layers
|
||||||
// (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
|
// (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
|
||||||
doInit();
|
doInit();
|
||||||
}
|
}
|
||||||
|
if(!sameDiff.hasGradientFunction()) {
|
||||||
|
//Create when scoped out, to ensure any arrays are not in WS
|
||||||
|
sameDiff.createGradFunction(INPUT_KEY);
|
||||||
|
}
|
||||||
|
|
||||||
Gradient g = new DefaultGradient();
|
INDArray castInput = input.castTo(dataType);
|
||||||
|
|
||||||
INDArray dLdIn;
|
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
|
|
||||||
INDArray castInput = input.castTo(Nd4j.defaultFloatingPointType());
|
|
||||||
if(castInput.isAttached())
|
if(castInput.isAttached())
|
||||||
castInput = castInput.dup();
|
castInput = castInput.dup();
|
||||||
sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY));
|
sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY));
|
||||||
if(layerConf().labelsRequired()) {
|
if(layerConf().labelsRequired()) {
|
||||||
INDArray castLabels = labels.castTo(Nd4j.defaultFloatingPointType());
|
INDArray castLabels = labels.castTo(dataType);
|
||||||
if(castLabels.isAttached())
|
if(castLabels.isAttached())
|
||||||
castLabels = castLabels.dup();
|
castLabels = castLabels.dup();
|
||||||
sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY));
|
sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY));
|
||||||
|
@ -154,7 +165,17 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||||
}
|
}
|
||||||
|
|
||||||
sameDiff.execBackwards(Collections.<String, INDArray>emptyMap());
|
List<String> gradVarNames = new ArrayList<>();
|
||||||
|
for(String s : paramTable.keySet()){
|
||||||
|
gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName());
|
||||||
|
}
|
||||||
|
gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName());
|
||||||
|
|
||||||
|
Map<String,INDArray> phMap = new HashMap<>();
|
||||||
|
phMap.put(INPUT_KEY, input);
|
||||||
|
phMap.put(LABELS_KEY, labels);
|
||||||
|
|
||||||
|
sameDiff.execBackwards(phMap, gradVarNames);
|
||||||
for(String s : paramTable.keySet() ){
|
for(String s : paramTable.keySet() ){
|
||||||
INDArray sdGrad = sameDiff.grad(s).getArr();
|
INDArray sdGrad = sameDiff.grad(s).getArr();
|
||||||
INDArray dl4jGrad = gradTable.get(s);
|
INDArray dl4jGrad = gradTable.get(s);
|
||||||
|
@ -165,6 +186,10 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
||||||
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
|
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||||
|
sameDiff.clearPlaceholders(true);
|
||||||
|
sameDiff.clearOpInputs();
|
||||||
|
|
||||||
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
|
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,18 +277,20 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
||||||
sameDiff = SameDiff.create();
|
sameDiff = SameDiff.create();
|
||||||
Map<String, INDArray> p = paramTable();
|
Map<String, INDArray> p = paramTable();
|
||||||
|
|
||||||
val inputShape = input.shape().clone();
|
long[] inputShape = input.shape().clone();
|
||||||
SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape);
|
inputShape[0] = -1;
|
||||||
|
SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape);
|
||||||
SDVariable labelVar = null;
|
SDVariable labelVar = null;
|
||||||
if(layerConf().labelsRequired()){
|
if(layerConf().labelsRequired()){
|
||||||
long[] labelShape = labels == null ? new long[]{1} : labels.shape().clone();
|
long[] labelShape = labels == null ? new long[]{-1, -1} : labels.shape().clone();
|
||||||
labelVar = sameDiff.var(LABELS_KEY, dataType, labelShape);
|
labelShape[0] = -1;
|
||||||
|
labelVar = sameDiff.placeHolder(LABELS_KEY, dataType, labelShape);
|
||||||
}
|
}
|
||||||
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
|
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
|
||||||
Map<String, SDVariable> params = new LinkedHashMap<>();
|
Map<String, SDVariable> params = new LinkedHashMap<>();
|
||||||
for (String s : paramShapes.keySet()) {
|
for (String s : paramShapes.keySet()) {
|
||||||
val ps = paramShapes.get(s);
|
val ps = paramShapes.get(s);
|
||||||
SDVariable v = sameDiff.var(s, ps);
|
SDVariable v = sameDiff.var(s, dataType, ps);
|
||||||
params.put(s, v);
|
params.put(s, v);
|
||||||
}
|
}
|
||||||
SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, labelVar, params);
|
SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, labelVar, params);
|
||||||
|
|
|
@ -660,6 +660,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
val nParamsPerLayer = new long[nLayers];
|
val nParamsPerLayer = new long[nLayers];
|
||||||
for (int i = 0; i < nLayers; i++) {
|
for (int i = 0; i < nLayers; i++) {
|
||||||
NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i);
|
NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i);
|
||||||
|
conf.getLayer().setDataType(netDtype);
|
||||||
nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf);
|
nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf);
|
||||||
paramLength += nParamsPerLayer[i];
|
paramLength += nParamsPerLayer[i];
|
||||||
}
|
}
|
||||||
|
|
|
@ -152,7 +152,7 @@ public class HardwareMetric implements Serializable {
|
||||||
return builder.logicalProcessorCount(processor.getLogicalProcessorCount())
|
return builder.logicalProcessorCount(processor.getLogicalProcessorCount())
|
||||||
.physicalProcessorCount(processor.getPhysicalProcessorCount())
|
.physicalProcessorCount(processor.getPhysicalProcessorCount())
|
||||||
.name(name)
|
.name(name)
|
||||||
.averagedCpuLoad((long) processor.getSystemCpuLoad() * 100)
|
.averagedCpuLoad((long)(processor.getSystemCpuLoad() * 100))
|
||||||
.ioWaitTime(iowait).gpuMetrics(gpuMetric)
|
.ioWaitTime(iowait).gpuMetrics(gpuMetric)
|
||||||
.hostName(networkParams.getHostName()).diskInfo(diskInfoMap)
|
.hostName(networkParams.getHostName()).diskInfo(diskInfoMap)
|
||||||
.currentMemoryUse(globalMemory.getTotal() - globalMemory.getAvailable())
|
.currentMemoryUse(globalMemory.getTotal() - globalMemory.getAvailable())
|
||||||
|
|
|
@ -48,8 +48,6 @@ if(WIN32)
|
||||||
SET(CMAKE_NINJA_FORCE_RESPONSE_FILE 1 CACHE INTERNAL "")
|
SET(CMAKE_NINJA_FORCE_RESPONSE_FILE 1 CACHE INTERNAL "")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if ("${LIBND4J_ALL_OPS}")
|
if ("${LIBND4J_ALL_OPS}")
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true")
|
||||||
else()
|
else()
|
||||||
|
@ -234,21 +232,21 @@ if(CUDA_BLAS)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (NOT BUILD_TESTS)
|
|
||||||
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
|
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
|
||||||
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/*.cpp ../include/execution/*.h)
|
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
|
||||||
file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
|
file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
|
||||||
file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
|
file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
|
||||||
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
|
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/impl/*.cpp ../include/memory/cuda/*.cu ../include/memory/*.h)
|
||||||
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
|
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.cu ../include/graph/*.h)
|
||||||
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
|
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
|
||||||
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp)
|
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cuda/*.cu ../include/ops/declarable/helpers/impl/*.cpp)
|
||||||
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
|
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
|
||||||
|
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
|
||||||
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
|
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
|
||||||
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/cuda/*.cu ../include/helpers/*.h)
|
|
||||||
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
||||||
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
|
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
|
||||||
|
|
||||||
|
if (NOT BUILD_TESTS)
|
||||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
||||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||||
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
||||||
|
@ -258,20 +256,6 @@ if(CUDA_BLAS)
|
||||||
else()
|
else()
|
||||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
|
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
|
||||||
|
|
||||||
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
|
|
||||||
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
|
|
||||||
file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
|
|
||||||
file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
|
|
||||||
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/impl/*.cpp ../include/memory/cuda/*.cu ../include/memory/*.h)
|
|
||||||
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.cu ../include/graph/*.h)
|
|
||||||
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
|
|
||||||
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cuda/*.cu)
|
|
||||||
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
|
|
||||||
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
|
|
||||||
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
|
|
||||||
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
|
||||||
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
|
|
||||||
|
|
||||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
||||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||||
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
||||||
|
@ -308,7 +292,7 @@ elseif(CPU_BLAS)
|
||||||
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
|
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
|
||||||
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
|
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
|
||||||
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
|
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
|
||||||
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp)
|
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp ../include/ops/declarable/helpers/impl/*.cpp)
|
||||||
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
|
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
|
||||||
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
|
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
|
||||||
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h)
|
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h)
|
||||||
|
|
|
@ -372,8 +372,8 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* if _bufferD==nullptr return _buffer, else return _bufferD
|
* if _bufferD==nullptr return _buffer, else return _bufferD
|
||||||
*/
|
*/
|
||||||
FORCEINLINE void* specialBuffer();
|
void* specialBuffer();
|
||||||
FORCEINLINE void* getSpecialBuffer() const;
|
void* getSpecialBuffer() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns device buffer if compilation is for cuda case, otherwise returns host buffer
|
* returns device buffer if compilation is for cuda case, otherwise returns host buffer
|
||||||
|
@ -429,16 +429,16 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array
|
* permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array
|
||||||
*/
|
*/
|
||||||
NDArray* permute(const std::initializer_list<int>& dimensions) const;
|
NDArray permute(const std::initializer_list<int>& dimensions) const;
|
||||||
NDArray* permute(const std::vector<int>& dimensions) const;
|
NDArray permute(const std::vector<int>& dimensions) const;
|
||||||
NDArray* permute(const int* dimensions, const int rank) const;
|
NDArray permute(const int* dimensions, const int rank) const;
|
||||||
|
|
||||||
void permute(const int* dimensions, const int rank, NDArray& target) const;
|
void permute(const int* dimensions, const int rank, NDArray& target) const;
|
||||||
void permute(const std::vector<int>& dimensions, NDArray& target) const;
|
void permute(const std::vector<int>& dimensions, NDArray& target) const;
|
||||||
|
|
||||||
NDArray* permute(const std::initializer_list<Nd4jLong>& dimensions) const;
|
NDArray permute(const std::initializer_list<Nd4jLong>& dimensions) const;
|
||||||
NDArray* permute(const std::vector<Nd4jLong>& dimensions) const;
|
NDArray permute(const std::vector<Nd4jLong>& dimensions) const;
|
||||||
NDArray* permute(const Nd4jLong* dimensions, const int rank) const;
|
NDArray permute(const Nd4jLong* dimensions, const int rank) const;
|
||||||
|
|
||||||
void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const;
|
void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const;
|
||||||
void permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const;
|
void permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const;
|
||||||
|
@ -508,7 +508,7 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* returns new copy of this array, optionally in different order
|
* returns new copy of this array, optionally in different order
|
||||||
*/
|
*/
|
||||||
NDArray *dup(const char newOrder = 'a');
|
NDArray *dup(const char newOrder = 'a') const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns sum of all elements of array
|
* returns sum of all elements of array
|
||||||
|
@ -687,7 +687,7 @@ namespace nd4j {
|
||||||
void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;
|
void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;
|
||||||
|
|
||||||
|
|
||||||
#if defined(__CUDABLAS__) && defined(BUILD_TESTS)
|
#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS)
|
||||||
template <typename Lambda>
|
template <typename Lambda>
|
||||||
FORCEINLINE void applyLambda(Lambda func, NDArray* target = nullptr);
|
FORCEINLINE void applyLambda(Lambda func, NDArray* target = nullptr);
|
||||||
|
|
||||||
|
@ -790,8 +790,7 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* apply transpose operation to the copy of this array, that is this array remains unaffected
|
* apply transpose operation to the copy of this array, that is this array remains unaffected
|
||||||
*/
|
*/
|
||||||
NDArray* transpose() const;
|
NDArray transpose() const;
|
||||||
NDArray transp() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* perform transpose operation and store result in target, this array remains unaffected
|
* perform transpose operation and store result in target, this array remains unaffected
|
||||||
|
@ -915,7 +914,7 @@ namespace nd4j {
|
||||||
*
|
*
|
||||||
* if permute have been applied before or there are weird strides, then new buffer is allocated for new array
|
* if permute have been applied before or there are weird strides, then new buffer is allocated for new array
|
||||||
*/
|
*/
|
||||||
NDArray* reshape(const char order, const std::vector<Nd4jLong>& shape) const;
|
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* calculate strides and set given order
|
* calculate strides and set given order
|
||||||
|
@ -2093,15 +2092,6 @@ Nd4jLong* NDArray::shapeInfo() {
|
||||||
return _shapeInfo;
|
return _shapeInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
void* NDArray::specialBuffer() {
|
|
||||||
|
|
||||||
if (_buffer->special() == nullptr)
|
|
||||||
return getBuffer();
|
|
||||||
// FIXME: this should be fixed once CUDA backend added
|
|
||||||
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
Nd4jLong* NDArray::specialShapeInfo() {
|
Nd4jLong* NDArray::specialShapeInfo() {
|
||||||
if (_shapeInfoD == nullptr)
|
if (_shapeInfoD == nullptr)
|
||||||
|
@ -2110,14 +2100,6 @@ Nd4jLong* NDArray::specialShapeInfo() {
|
||||||
return _shapeInfoD;
|
return _shapeInfoD;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
void* NDArray::getSpecialBuffer() const {
|
|
||||||
if (_buffer->special() == nullptr)
|
|
||||||
return getBuffer();
|
|
||||||
// FIXME: this should be fixed once CUDA backend added
|
|
||||||
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
Nd4jLong NDArray::getBufferOffset() const {
|
Nd4jLong NDArray::getBufferOffset() const {
|
||||||
return _offset;
|
return _offset;
|
||||||
|
@ -2137,7 +2119,7 @@ Nd4jLong* NDArray::getSpecialShapeInfo() const{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#if defined(__CUDACC__) && defined(BUILD_TESTS)
|
#if defined(__CUDACC__) //&& defined(BUILD_TESTS)
|
||||||
// for CUDA we need stil stuff inline
|
// for CUDA we need stil stuff inline
|
||||||
#include "cuda/NDArrayLambda.hpp"
|
#include "cuda/NDArrayLambda.hpp"
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -39,9 +39,9 @@ NDArray* NDArray::asT() const{
|
||||||
auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
|
auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
|
||||||
auto l = this->lengthOf();
|
auto l = this->lengthOf();
|
||||||
|
|
||||||
prepareSpecialUse({result}, {this});
|
NDArray::prepareSpecialUse({result}, {this});
|
||||||
NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result->getBuffer(), result->getShapeInfo(), result->getSpecialBuffer(), result->getSpecialShapeInfo(), nullptr, nullptr, nullptr);
|
NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result->getBuffer(), result->getShapeInfo(), result->getSpecialBuffer(), result->getSpecialShapeInfo(), nullptr, nullptr, nullptr);
|
||||||
registerSpecialUse({result}, {this});
|
NDArray::registerSpecialUse({result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -583,117 +583,130 @@ void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCop
|
||||||
void NDArray::assign(const double value) {
|
void NDArray::assign(const double value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const float value) {
|
void NDArray::assign(const float value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const float16 value) {
|
void NDArray::assign(const float16 value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(value, this->getContext());
|
auto temp = NDArrayFactory::create(value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const bfloat16& value) {
|
void NDArray::assign(const bfloat16& value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(value, this->getContext());
|
auto temp = NDArrayFactory::create(value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const Nd4jLong value) {
|
void NDArray::assign(const Nd4jLong value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(value, this->getContext());
|
auto temp = NDArrayFactory::create(value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const int value) {
|
void NDArray::assign(const int value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const int16_t value) {
|
void NDArray::assign(const int16_t value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const uint8_t value) {
|
void NDArray::assign(const uint8_t value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const uint16_t value) {
|
void NDArray::assign(const uint16_t value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const uint32_t value) {
|
void NDArray::assign(const uint32_t value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const uint64_t value) {
|
void NDArray::assign(const uint64_t value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const int8_t value) {
|
void NDArray::assign(const int8_t value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::assign(const bool value) {
|
void NDArray::assign(const bool value) {
|
||||||
// just fire scalar
|
// just fire scalar
|
||||||
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
|
||||||
prepareSpecialUse({this}, {&temp});
|
|
||||||
|
NDArray::prepareSpecialUse({this}, {&temp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&temp});
|
NDArray::registerSpecialUse({this}, {&temp});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -716,9 +729,9 @@ NDArray NDArray::varianceNumber(nd4j::variance::Ops op, bool biasCorrected) {
|
||||||
|
|
||||||
NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext());
|
NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&res}, {this});
|
NDArray::prepareSpecialUse({&res}, {this});
|
||||||
NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected);
|
NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected);
|
||||||
registerSpecialUse({&res}, {this});
|
NDArray::registerSpecialUse({&res}, {this});
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
@ -918,9 +931,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::FloatOps op, void *extraParams) cons
|
||||||
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()));
|
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()));
|
||||||
NDArray result(shape, true, this->getContext());
|
NDArray result(shape, true, this->getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this});
|
NDArray::prepareSpecialUse({&result}, {this});
|
||||||
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
|
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
|
||||||
registerSpecialUse({&result}, {this});
|
NDArray::registerSpecialUse({&result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -932,9 +945,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::SameOps op, void *extraParams) const
|
||||||
|
|
||||||
NDArray result(dataType(), getContext());
|
NDArray result(dataType(), getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this});
|
NDArray::prepareSpecialUse({&result}, {this});
|
||||||
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
|
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
|
||||||
registerSpecialUse({&result}, {this});
|
NDArray::registerSpecialUse({&result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -947,9 +960,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::BoolOps op, void *extraParams) const
|
||||||
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL);
|
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL);
|
||||||
NDArray result(shape, true, this->getContext());
|
NDArray result(shape, true, this->getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this});
|
NDArray::prepareSpecialUse({&result}, {this});
|
||||||
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
|
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
|
||||||
registerSpecialUse({&result}, {this});
|
NDArray::registerSpecialUse({&result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -962,9 +975,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::LongOps op, void *extraParams) const
|
||||||
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64);
|
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64);
|
||||||
NDArray result(shape, true, this->getContext());
|
NDArray result(shape, true, this->getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this});
|
NDArray::prepareSpecialUse({&result}, {this});
|
||||||
NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
|
NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
|
||||||
registerSpecialUse({&result}, {this});
|
NDArray::registerSpecialUse({&result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -976,9 +989,9 @@ void NDArray::reduceNumber(nd4j::reduce::FloatOps op, NDArray& target, void *ext
|
||||||
if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType()))
|
if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType()))
|
||||||
throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!");
|
throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!");
|
||||||
|
|
||||||
prepareSpecialUse({&target}, {this});
|
NDArray::prepareSpecialUse({&target}, {this});
|
||||||
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
|
||||||
registerSpecialUse({&target}, {this});
|
NDArray::registerSpecialUse({&target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -989,9 +1002,9 @@ void NDArray::reduceNumber(nd4j::reduce::SameOps op, NDArray& target, void *extr
|
||||||
if(!target.isScalar() || target.dataType() != dataType())
|
if(!target.isScalar() || target.dataType() != dataType())
|
||||||
throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!");
|
throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!");
|
||||||
|
|
||||||
prepareSpecialUse({&target}, {this});
|
NDArray::prepareSpecialUse({&target}, {this});
|
||||||
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
|
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
|
||||||
registerSpecialUse({&target}, {this});
|
NDArray::registerSpecialUse({&target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1002,9 +1015,9 @@ void NDArray::reduceNumber(nd4j::reduce::BoolOps op, NDArray& target, void *extr
|
||||||
if(!target.isScalar() || target.dataType() != DataType::BOOL)
|
if(!target.isScalar() || target.dataType() != DataType::BOOL)
|
||||||
throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!");
|
throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!");
|
||||||
|
|
||||||
prepareSpecialUse({&target}, {this});
|
NDArray::prepareSpecialUse({&target}, {this});
|
||||||
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
|
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
|
||||||
registerSpecialUse({&target}, {this});
|
NDArray::registerSpecialUse({&target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1015,9 +1028,9 @@ void NDArray::reduceNumber(nd4j::reduce::LongOps op, NDArray& target, void *extr
|
||||||
if(!target.isScalar() || target.dataType() != DataType::INT64)
|
if(!target.isScalar() || target.dataType() != DataType::INT64)
|
||||||
throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!");
|
throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!");
|
||||||
|
|
||||||
prepareSpecialUse({&target}, {this});
|
NDArray::prepareSpecialUse({&target}, {this});
|
||||||
NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
|
NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
|
||||||
registerSpecialUse({&target}, {this});
|
NDArray::registerSpecialUse({&target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1027,9 +1040,9 @@ NDArray NDArray::indexReduceNumber(nd4j::indexreduce::Ops op, ExtraArguments *ex
|
||||||
|
|
||||||
auto res = NDArrayFactory::create<Nd4jLong>(0);
|
auto res = NDArrayFactory::create<Nd4jLong>(0);
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&res}, {this});
|
NDArray::NDArray::prepareSpecialUse({&res}, {this});
|
||||||
NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams == nullptr ? nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo());
|
NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams == nullptr ? nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo());
|
||||||
NDArray::registerSpecialUse({&res}, {this});
|
NDArray::NDArray::registerSpecialUse({&res}, {this});
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
@ -1240,17 +1253,10 @@ BUILD_SINGLE_TEMPLATE(template void* NDArray::templatedPointerShift, (const Nd4j
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected
|
// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected
|
||||||
NDArray* NDArray::transpose() const {
|
NDArray NDArray::transpose() const {
|
||||||
auto newArr = new NDArray(getBuffer(), getSpecialBuffer(), getShapeInfo(), getContext(), false, false);
|
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
|
||||||
newArr->transposei();
|
|
||||||
|
|
||||||
return newArr;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
NDArray NDArray::transp() const {
|
|
||||||
NDArray newArr(getBuffer(), getShapeInfo(), getContext(), false);
|
|
||||||
newArr.transposei();
|
newArr.transposei();
|
||||||
|
|
||||||
return newArr;
|
return newArr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1360,10 +1366,10 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
|
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
|
||||||
NDArray* NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const {
|
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const {
|
||||||
|
|
||||||
auto newArr = new NDArray(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext());
|
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
|
||||||
newArr->reshapei(order, shape);
|
newArr.reshapei(order, shape);
|
||||||
|
|
||||||
return newArr;
|
return newArr;
|
||||||
}
|
}
|
||||||
|
@ -1420,43 +1426,43 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray* NDArray::permute(const int* dimensions, const int rank) const {
|
NDArray NDArray::permute(const int* dimensions, const int rank) const {
|
||||||
|
|
||||||
// evaluate shapeInfo for output (permuted) array ret
|
// evaluate shapeInfo for output (permuted) array ret
|
||||||
auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace());
|
auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace());
|
||||||
auto ret = new NDArray(_buffer, ShapeDescriptor(shapeInfoPermuted), getContext(), getBufferOffset());
|
NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), getBufferOffset());
|
||||||
ret->_isView = true;
|
ret._isView = true;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////
|
||||||
NDArray* NDArray::permute(const Nd4jLong* dimensions, const int rank) const {
|
NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const {
|
||||||
int tempDims[MAX_RANK];
|
int tempDims[MAX_RANK];
|
||||||
shape::convertT<Nd4jLong, int>(const_cast<Nd4jLong *>(dimensions), tempDims, rank);
|
shape::convertT<Nd4jLong, int>(const_cast<Nd4jLong *>(dimensions), tempDims, rank);
|
||||||
return permute(tempDims, rank);
|
return permute(tempDims, rank);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray* NDArray::permute(const std::vector<int>& dimensions) const {
|
NDArray NDArray::permute(const std::vector<int>& dimensions) const {
|
||||||
auto data = dimensions.data();
|
auto data = dimensions.data();
|
||||||
auto size = dimensions.size();
|
auto size = dimensions.size();
|
||||||
return permute(data, size);
|
return permute(data, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray* NDArray::permute(const std::vector<Nd4jLong>& dimensions) const {
|
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const {
|
||||||
return permute(dimensions.data(), dimensions.size());
|
return permute(dimensions.data(), dimensions.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray* NDArray::permute(const std::initializer_list<int>& dimensions) const {
|
NDArray NDArray::permute(const std::initializer_list<int>& dimensions) const {
|
||||||
std::vector<int> vec(dimensions);
|
std::vector<int> vec(dimensions);
|
||||||
return permute(vec);
|
return permute(vec);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray* NDArray::permute(const std::initializer_list<Nd4jLong>& dimensions) const {
|
NDArray NDArray::permute(const std::initializer_list<Nd4jLong>& dimensions) const {
|
||||||
std::vector<Nd4jLong> vec(dimensions);
|
std::vector<Nd4jLong> vec(dimensions);
|
||||||
return permute(vec);
|
return permute(vec);
|
||||||
}
|
}
|
||||||
|
@ -1528,10 +1534,9 @@ bool NDArray::isUnitary() {
|
||||||
throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !");
|
throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !");
|
||||||
|
|
||||||
auto tr = this->transpose();
|
auto tr = this->transpose();
|
||||||
auto trMul = MmulHelper::mmul(this, tr, nullptr, 1.f, 0.f);
|
auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f);
|
||||||
|
|
||||||
bool result = trMul->isIdentityMatrix();
|
bool result = trMul->isIdentityMatrix();
|
||||||
delete tr;
|
|
||||||
delete trMul;
|
delete trMul;
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -1777,11 +1782,11 @@ NDArray NDArray::operator*(const T& scalar) const {
|
||||||
|
|
||||||
auto tmp = NDArrayFactory::create(dataType(), scalar, getContext());
|
auto tmp = NDArrayFactory::create(dataType(), scalar, getContext());
|
||||||
NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT<T>()), false, getContext());
|
NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT<T>()), false, getContext());
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&result}, {this, &tmp});
|
NDArray::prepareSpecialUse({&result}, {this, &tmp});
|
||||||
|
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&result}, {this, &tmp});
|
NDArray::registerSpecialUse({&result}, {this, &tmp});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
template NDArray NDArray::operator*(const double& scalar) const;
|
template NDArray NDArray::operator*(const double& scalar) const;
|
||||||
|
@ -1811,6 +1816,7 @@ NDArray NDArray::operator/(const T& scalar) const {
|
||||||
NDArray::prepareSpecialUse({&result}, {this, &tmp});
|
NDArray::prepareSpecialUse({&result}, {this, &tmp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
|
||||||
NDArray::registerSpecialUse({&result}, {this, &tmp});
|
NDArray::registerSpecialUse({&result}, {this, &tmp});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
template NDArray NDArray::operator/(const double& scalar) const;
|
template NDArray NDArray::operator/(const double& scalar) const;
|
||||||
|
@ -2050,14 +2056,14 @@ void NDArray::operator+=(const NDArray& other) {
|
||||||
throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());
|
throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());
|
||||||
|
|
||||||
if (!this->isScalar() && other.isScalar()) {
|
if (!this->isScalar() && other.isScalar()) {
|
||||||
prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
}
|
}
|
||||||
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||||
prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
Nd4jLong *bShape = nullptr;
|
Nd4jLong *bShape = nullptr;
|
||||||
|
@ -2084,14 +2090,14 @@ void NDArray::operator-=(const NDArray& other) {
|
||||||
throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());
|
throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());
|
||||||
|
|
||||||
if (!this->isScalar() && other.isScalar()) {
|
if (!this->isScalar() && other.isScalar()) {
|
||||||
prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
}
|
}
|
||||||
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||||
prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
Nd4jLong *bShape = nullptr;
|
Nd4jLong *bShape = nullptr;
|
||||||
|
@ -2117,14 +2123,14 @@ void NDArray::operator*=(const NDArray& other) {
|
||||||
throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());
|
throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());
|
||||||
|
|
||||||
if (!this->isScalar() && other.isScalar()) {
|
if (!this->isScalar() && other.isScalar()) {
|
||||||
prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
}
|
}
|
||||||
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||||
prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
Nd4jLong *bShape = nullptr;
|
Nd4jLong *bShape = nullptr;
|
||||||
|
@ -2154,14 +2160,14 @@ void NDArray::operator/=(const NDArray& other) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!this->isScalar() && other.isScalar()) {
|
if (!this->isScalar() && other.isScalar()) {
|
||||||
prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
}
|
}
|
||||||
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||||
prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
Nd4jLong *bShape = nullptr;
|
Nd4jLong *bShape = nullptr;
|
||||||
|
@ -2264,9 +2270,9 @@ NDArray NDArray::operator-(const NDArray& other) const {
|
||||||
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||||
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
|
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this, &other});
|
NDArray::prepareSpecialUse({&result}, {this, &other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({&result}, {this, &other});
|
NDArray::registerSpecialUse({&result}, {this, &other});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -2285,9 +2291,9 @@ NDArray NDArray::operator*(const NDArray& other) const {
|
||||||
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||||
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, this->getContext());
|
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, this->getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this, &other});
|
NDArray::prepareSpecialUse({&result}, {this, &other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({&result}, {this, &other});
|
NDArray::registerSpecialUse({&result}, {this, &other});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -2308,9 +2314,9 @@ NDArray NDArray::operator/(const NDArray& other) const {
|
||||||
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
|
||||||
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
|
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this, &other});
|
NDArray::prepareSpecialUse({&result}, {this, &other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({&result}, {this, &other});
|
NDArray::registerSpecialUse({&result}, {this, &other});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -2326,9 +2332,9 @@ NDArray NDArray::operator-() const {
|
||||||
|
|
||||||
NDArray result(getShapeInfo(), false, getContext());
|
NDArray result(getShapeInfo(), false, getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this});
|
NDArray::prepareSpecialUse({&result}, {this});
|
||||||
NativeOpExecutioner::execTransformSame(getContext(), nd4j::transform::Neg, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
|
NativeOpExecutioner::execTransformSame(getContext(), nd4j::transform::Neg, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
|
||||||
registerSpecialUse({&result}, {this});
|
NDArray::registerSpecialUse({&result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -2631,7 +2637,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector<int>& di
|
||||||
if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
|
if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
|
||||||
NDArray::prepareSpecialUse({result}, {this, other});
|
NDArray::prepareSpecialUse({result}, {this, other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({result}, {this, other});
|
NDArray::registerSpecialUse({result}, {this, other});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2688,7 +2694,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
|
||||||
if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
|
if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
|
||||||
NDArray::prepareSpecialUse({result}, {this, other});
|
NDArray::prepareSpecialUse({result}, {this, other});
|
||||||
NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({result}, {this, other});
|
NDArray::registerSpecialUse({result}, {this, other});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2896,7 +2902,7 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
||||||
Nd4jLong *shapeInfoNew;
|
Nd4jLong *shapeInfoNew;
|
||||||
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
||||||
|
|
||||||
bool canReshape = shape::reshapeC(this->rankOf(), this->_shapeInfo, shape.size(), shape.data(), shapeInfoNew);
|
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
|
||||||
|
|
||||||
// we can do this only if there was no permute applied, or there are no weird strides
|
// we can do this only if there was no permute applied, or there are no weird strides
|
||||||
if (canReshape) {
|
if (canReshape) {
|
||||||
|
@ -2948,11 +2954,9 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray* othe
|
||||||
if (target->dataType() != this->dataType() && target->dataType() != other->dataType())
|
if (target->dataType() != this->dataType() && target->dataType() != other->dataType())
|
||||||
throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !");
|
throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !");
|
||||||
|
|
||||||
prepareSpecialUse({target}, {this, other});
|
NDArray::prepareSpecialUse({target}, {this, other});
|
||||||
|
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
|
||||||
|
NDArray::registerSpecialUse({target}, {this, other});
|
||||||
registerSpecialUse({target}, {this, other});
|
|
||||||
|
|
||||||
if (extraParams != nullptr)
|
if (extraParams != nullptr)
|
||||||
synchronize("NDArray::applyPairwiseTransform");
|
synchronize("NDArray::applyPairwiseTransform");
|
||||||
|
@ -2969,9 +2973,9 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *
|
||||||
if (dataType() != other->dataType())
|
if (dataType() != other->dataType())
|
||||||
throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !");
|
throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !");
|
||||||
|
|
||||||
prepareSpecialUse({target}, {this, other});
|
NDArray::prepareSpecialUse({target}, {this, other});
|
||||||
NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
|
NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
|
||||||
registerSpecialUse({target}, {this, other});
|
NDArray::registerSpecialUse({target}, {this, other});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3070,22 +3074,23 @@ void NDArray::assign(const NDArray& other) {
|
||||||
if (other.isScalar()) {
|
if (other.isScalar()) {
|
||||||
|
|
||||||
if(this->isScalar()) {
|
if(this->isScalar()) {
|
||||||
preparePrimaryUse({this}, {&other});
|
NDArray::preparePrimaryUse({this}, {&other});
|
||||||
BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
registerPrimaryUse({this}, {&other});
|
NDArray::registerPrimaryUse({this}, {&other});
|
||||||
|
this->syncToDevice();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if (dataType() != other.dataType()) {
|
if (dataType() != other.dataType()) {
|
||||||
auto tmp = other.cast(dataType());
|
auto tmp = other.cast(dataType());
|
||||||
prepareSpecialUse({this}, {tmp});
|
NDArray::prepareSpecialUse({this}, {tmp});
|
||||||
NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {});
|
NDArray::registerSpecialUse({this}, {});
|
||||||
delete tmp;
|
delete tmp;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
prepareSpecialUse({this}, {&other});
|
NDArray::prepareSpecialUse({this}, {&other});
|
||||||
NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
||||||
registerSpecialUse({this}, {&other});
|
NDArray::registerSpecialUse({this}, {&other});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3101,16 +3106,16 @@ void NDArray::assign(const NDArray& other) {
|
||||||
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
||||||
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
|
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
|
||||||
else {
|
else {
|
||||||
prepareSpecialUse({this}, {&other});
|
NDArray::prepareSpecialUse({this}, {&other});
|
||||||
NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr);
|
NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr);
|
||||||
registerSpecialUse({this}, {&other});
|
NDArray::registerSpecialUse({this}, {&other});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
// This method returns new copy of this NDArray, optionally in different order
|
// This method returns new copy of this NDArray, optionally in different order
|
||||||
NDArray* NDArray::dup(const char newOrder) {
|
NDArray* NDArray::dup(const char newOrder) const {
|
||||||
|
|
||||||
if (isEmpty())
|
if (isEmpty())
|
||||||
return NDArrayFactory::empty_(dataType(), getContext());
|
return NDArrayFactory::empty_(dataType(), getContext());
|
||||||
|
@ -3170,7 +3175,7 @@ std::string NDArray::e(const Nd4jLong i) const {
|
||||||
if (!isS())
|
if (!isS())
|
||||||
throw std::runtime_error("Can't get std::string out of non-string array");
|
throw std::runtime_error("Can't get std::string out of non-string array");
|
||||||
|
|
||||||
preparePrimaryUse({}, {this});
|
NDArray::preparePrimaryUse({}, {this});
|
||||||
|
|
||||||
// getting "virtual" offset. it's not real though,since it doesn't take lengths into account
|
// getting "virtual" offset. it's not real though,since it doesn't take lengths into account
|
||||||
auto offset = getOffset(i);
|
auto offset = getOffset(i);
|
||||||
|
@ -3208,8 +3213,8 @@ T NDArray::e(const Nd4jLong i) const {
|
||||||
|
|
||||||
const auto rp = getOffset(i);
|
const auto rp = getOffset(i);
|
||||||
|
|
||||||
preparePrimaryUse({}, {this});
|
NDArray::preparePrimaryUse({}, {this});
|
||||||
registerPrimaryUse({}, {this});
|
NDArray::registerPrimaryUse({}, {this});
|
||||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), rp), LIBND4J_TYPES);
|
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), rp), LIBND4J_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -3226,8 +3231,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j) const {
|
||||||
const Nd4jLong coords[2] = {i, j};
|
const Nd4jLong coords[2] = {i, j};
|
||||||
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||||
|
|
||||||
preparePrimaryUse({}, {this});
|
NDArray::preparePrimaryUse({}, {this});
|
||||||
registerPrimaryUse({}, {this});
|
NDArray::registerPrimaryUse({}, {this});
|
||||||
|
|
||||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
|
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
@ -3246,8 +3251,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
|
||||||
const Nd4jLong coords[3] = {i, j, k};
|
const Nd4jLong coords[3] = {i, j, k};
|
||||||
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||||
|
|
||||||
preparePrimaryUse({}, {this});
|
NDArray::preparePrimaryUse({}, {this});
|
||||||
registerPrimaryUse({}, {this});
|
NDArray::registerPrimaryUse({}, {this});
|
||||||
|
|
||||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
|
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
@ -3266,8 +3271,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLon
|
||||||
const Nd4jLong coords[4] = {i, j, k, l};
|
const Nd4jLong coords[4] = {i, j, k, l};
|
||||||
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||||
|
|
||||||
preparePrimaryUse({}, {this});
|
NDArray::preparePrimaryUse({}, {this});
|
||||||
registerPrimaryUse({}, {this});
|
NDArray::registerPrimaryUse({}, {this});
|
||||||
|
|
||||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
|
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
@ -3300,9 +3305,9 @@ void NDArray::applyTransform(nd4j::transform::FloatOps op, NDArray *target, Extr
|
||||||
if (!target->isR())
|
if (!target->isR())
|
||||||
throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types");
|
throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types");
|
||||||
|
|
||||||
prepareSpecialUse({target}, {this});
|
NDArray::prepareSpecialUse({target}, {this});
|
||||||
NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
||||||
registerSpecialUse({target}, {this});
|
NDArray::registerSpecialUse({target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3314,9 +3319,9 @@ void NDArray::applyTransform(nd4j::transform::AnyOps op, NDArray *target, ExtraA
|
||||||
if (target == nullptr)
|
if (target == nullptr)
|
||||||
target = this;
|
target = this;
|
||||||
|
|
||||||
prepareSpecialUse({target}, {this});
|
NDArray::prepareSpecialUse({target}, {this});
|
||||||
NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
||||||
registerSpecialUse({target}, {this});
|
NDArray::registerSpecialUse({target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3331,9 +3336,9 @@ void NDArray::applyTransform(nd4j::transform::SameOps op, NDArray *target, Extra
|
||||||
if (target->dataType() != dataType())
|
if (target->dataType() != dataType())
|
||||||
throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array");
|
throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array");
|
||||||
|
|
||||||
prepareSpecialUse({target}, {this});
|
NDArray::prepareSpecialUse({target}, {this});
|
||||||
NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
||||||
registerSpecialUse({target}, {this});
|
NDArray::registerSpecialUse({target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3347,9 +3352,9 @@ void NDArray::applyTransform(nd4j::transform::StrictOps op, NDArray *target, Ext
|
||||||
if (!this->isR() || !target->isR() || (this->dataType() != target->dataType()))
|
if (!this->isR() || !target->isR() || (this->dataType() != target->dataType()))
|
||||||
throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !");
|
throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !");
|
||||||
|
|
||||||
registerSpecialUse({target}, {this});
|
NDArray::prepareSpecialUse({target}, {this});
|
||||||
NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
||||||
prepareSpecialUse({target}, {this});
|
NDArray::registerSpecialUse({target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3363,9 +3368,9 @@ void NDArray::applyTransform(nd4j::transform::BoolOps op, NDArray *target, Extra
|
||||||
if (!target->isB())
|
if (!target->isB())
|
||||||
throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types");
|
throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types");
|
||||||
|
|
||||||
prepareSpecialUse({target}, {this});
|
NDArray::prepareSpecialUse({target}, {this});
|
||||||
NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
|
||||||
registerSpecialUse({target}, {this});
|
NDArray::registerSpecialUse({target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3375,9 +3380,9 @@ NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) cons
|
||||||
|
|
||||||
NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext());
|
NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext());
|
||||||
|
|
||||||
registerSpecialUse({&result}, {this});
|
NDArray::prepareSpecialUse({&result}, {this});
|
||||||
NativeOpExecutioner::execTransformFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
|
NativeOpExecutioner::execTransformFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
|
||||||
prepareSpecialUse({&result}, {this});
|
NDArray::registerSpecialUse({&result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -3389,9 +3394,9 @@ NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const
|
||||||
|
|
||||||
NDArray result(getShapeInfo(), false, getContext());
|
NDArray result(getShapeInfo(), false, getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this});
|
NDArray::prepareSpecialUse({&result}, {this});
|
||||||
NativeOpExecutioner::execTransformSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
|
NativeOpExecutioner::execTransformSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
|
||||||
registerSpecialUse({&result}, {this});
|
NDArray::registerSpecialUse({&result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -3403,9 +3408,9 @@ NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) con
|
||||||
|
|
||||||
NDArray result(getShapeInfo(), false, getContext());
|
NDArray result(getShapeInfo(), false, getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this});
|
NDArray::prepareSpecialUse({&result}, {this});
|
||||||
NativeOpExecutioner::execTransformStrict(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
|
NativeOpExecutioner::execTransformStrict(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
|
||||||
registerSpecialUse({&result}, {this});
|
NDArray::registerSpecialUse({&result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -3417,9 +3422,9 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const
|
||||||
|
|
||||||
NDArray result(ordering(), getShapeAsVector(), nd4j::DataType::BOOL, getContext());
|
NDArray result(ordering(), getShapeAsVector(), nd4j::DataType::BOOL, getContext());
|
||||||
|
|
||||||
prepareSpecialUse({&result}, {this});
|
NDArray::prepareSpecialUse({&result}, {this});
|
||||||
NativeOpExecutioner::execTransformBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
|
NativeOpExecutioner::execTransformBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
|
||||||
registerSpecialUse({&result}, {this});
|
NDArray::registerSpecialUse({&result}, {this});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -3435,9 +3440,9 @@ void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArra
|
||||||
if(target->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar->getShapeInfo()) && !(target->dataType() == dataType() || target->dataType() == scalar->dataType()))
|
if(target->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar->getShapeInfo()) && !(target->dataType() == dataType() || target->dataType() == scalar->dataType()))
|
||||||
throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!");
|
throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!");
|
||||||
|
|
||||||
prepareSpecialUse({target}, {this, scalar});
|
NDArray::prepareSpecialUse({target}, {this, scalar});
|
||||||
NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
|
NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
|
||||||
registerSpecialUse({target}, {this, scalar});
|
NDArray::registerSpecialUse({target}, {this, scalar});
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3471,10 +3476,9 @@ void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, ND
|
||||||
throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!");
|
throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!");
|
||||||
}
|
}
|
||||||
|
|
||||||
prepareSpecialUse({target}, {this, scalar});
|
NDArray::prepareSpecialUse({target}, {this, scalar});
|
||||||
NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
|
NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
|
||||||
|
NDArray::registerSpecialUse({target}, {this, scalar});
|
||||||
registerSpecialUse({target}, {this, scalar});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -3557,7 +3561,7 @@ NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, cons
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({result}, {this, other});
|
NDArray::prepareSpecialUse({result}, {this, other});
|
||||||
NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo());
|
NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo());
|
||||||
registerSpecialUse({result}, {this, other});
|
NDArray::registerSpecialUse({result}, {this, other});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -3635,9 +3639,9 @@ NDArray* NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray *other, c
|
||||||
|
|
||||||
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||||
|
|
||||||
prepareSpecialUse({result}, {this, other});
|
NDArray::prepareSpecialUse({result}, {this, other});
|
||||||
NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
registerSpecialUse({result}, {this, other});
|
NDArray::registerSpecialUse({result}, {this, other});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -3780,9 +3784,9 @@ void NDArray::p(const Nd4jLong i, const T value) {
|
||||||
auto rp = getOffset(i);
|
auto rp = getOffset(i);
|
||||||
const void *pV = reinterpret_cast<const void*>(const_cast<T *>(&value));
|
const void *pV = reinterpret_cast<const void*>(const_cast<T *>(&value));
|
||||||
|
|
||||||
preparePrimaryUse({this}, {}, true);
|
NDArray::preparePrimaryUse({this}, {}, true);
|
||||||
BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->getBuffer(), rp, pV), LIBND4J_TYPES);
|
BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->getBuffer(), rp, pV), LIBND4J_TYPES);
|
||||||
registerPrimaryUse({this}, {});
|
NDArray::registerPrimaryUse({this}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
template void NDArray::p(const Nd4jLong i, const double value);
|
template void NDArray::p(const Nd4jLong i, const double value);
|
||||||
|
@ -3811,9 +3815,9 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) {
|
||||||
Nd4jLong coords[2] = {i, j};
|
Nd4jLong coords[2] = {i, j};
|
||||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||||
|
|
||||||
preparePrimaryUse({this}, {}, true);
|
NDArray::preparePrimaryUse({this}, {}, true);
|
||||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
|
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
|
||||||
registerPrimaryUse({this}, {});
|
NDArray::registerPrimaryUse({this}, {});
|
||||||
}
|
}
|
||||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value);
|
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value);
|
||||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value);
|
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value);
|
||||||
|
@ -3837,13 +3841,13 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T va
|
||||||
if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2])
|
if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2])
|
||||||
throw std::invalid_argument("NDArray:pe(i,j,k, value): one of input indexes is out of array length or rank!=3 !");
|
throw std::invalid_argument("NDArray:pe(i,j,k, value): one of input indexes is out of array length or rank!=3 !");
|
||||||
|
|
||||||
preparePrimaryUse({this}, {}, true);
|
NDArray::preparePrimaryUse({this}, {}, true);
|
||||||
|
|
||||||
void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
|
void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
|
||||||
Nd4jLong coords[3] = {i, j, k};
|
Nd4jLong coords[3] = {i, j, k};
|
||||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
|
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
|
||||||
registerPrimaryUse({this}, {});
|
NDArray::registerPrimaryUse({this}, {});
|
||||||
}
|
}
|
||||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value);
|
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value);
|
||||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value);
|
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value);
|
||||||
|
@ -3870,9 +3874,9 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4j
|
||||||
Nd4jLong coords[4] = {i, j, k, l};
|
Nd4jLong coords[4] = {i, j, k, l};
|
||||||
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
|
||||||
|
|
||||||
preparePrimaryUse({this}, {}, true);
|
NDArray::preparePrimaryUse({this}, {}, true);
|
||||||
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
|
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
|
||||||
registerPrimaryUse({this}, {});
|
NDArray::registerPrimaryUse({this}, {});
|
||||||
}
|
}
|
||||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value);
|
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value);
|
||||||
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value);
|
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value);
|
||||||
|
@ -3896,10 +3900,10 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
|
||||||
if (i >= _length)
|
if (i >= _length)
|
||||||
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
|
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
|
||||||
|
|
||||||
preparePrimaryUse({this}, {&scalar}, true);
|
NDArray::preparePrimaryUse({this}, {&scalar}, true);
|
||||||
auto rp = getOffset(i);
|
auto rp = getOffset(i);
|
||||||
BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (getBuffer(), rp, scalar.dataType(), scalar.getBuffer()), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (getBuffer(), rp, scalar.dataType(), scalar.getBuffer()), LIBND4J_TYPES);
|
||||||
registerPrimaryUse({this}, {&scalar});
|
NDArray::registerPrimaryUse({this}, {&scalar});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -4195,7 +4199,7 @@ ResultSet* NDArray::allTensorsAlongDimension(const std::vector<int> &dimensions)
|
||||||
|
|
||||||
|
|
||||||
auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_shapeInfo, const_cast<int*>(dimensions.data()), dimensions.size());
|
auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_shapeInfo, const_cast<int*>(dimensions.data()), dimensions.size());
|
||||||
auto numTads = lengthOf() / shape::length(pack.primaryShapeInfo());
|
auto numTads = pack.numberOfTads();
|
||||||
|
|
||||||
for (int idx = 0; idx < numTads; idx++ ) {
|
for (int idx = 0; idx < numTads; idx++ ) {
|
||||||
auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset());
|
auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset());
|
||||||
|
|
|
@ -1578,6 +1578,20 @@ public:
|
||||||
void *dx, Nd4jLong *dxShapeInfo,
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
bool descending);
|
bool descending);
|
||||||
|
|
||||||
|
void sortByKey(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
bool descending);
|
||||||
|
|
||||||
|
void sortByValue(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
bool descending);
|
||||||
|
|
||||||
void sortTad(Nd4jPointer *extraPointers,
|
void sortTad(Nd4jPointer *extraPointers,
|
||||||
void *x, Nd4jLong *xShapeInfo,
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
void *dx, Nd4jLong *dxShapeInfo,
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
|
@ -1587,6 +1601,24 @@ public:
|
||||||
Nd4jLong *tadOffsets,
|
Nd4jLong *tadOffsets,
|
||||||
bool descending);
|
bool descending);
|
||||||
|
|
||||||
|
void sortTadByKey(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
int *dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
bool descending);
|
||||||
|
|
||||||
|
void sortTadByValue(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
int *dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
bool descending);
|
||||||
|
|
||||||
|
|
||||||
// special sort impl for sorting out COO indices and values
|
// special sort impl for sorting out COO indices and values
|
||||||
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank);
|
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank);
|
||||||
|
|
|
@ -208,6 +208,23 @@ void* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void* NDArray::specialBuffer() {
|
||||||
|
if (_buffer->special() == nullptr)
|
||||||
|
return getBuffer();
|
||||||
|
// FIXME: this should be fixed once CUDA backend added
|
||||||
|
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void* NDArray::getSpecialBuffer() const {
|
||||||
|
if (_buffer->special() == nullptr)
|
||||||
|
return getBuffer();
|
||||||
|
// FIXME: this should be fixed once CUDA backend added
|
||||||
|
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// change an array by repeating it the number of times given by reps.
|
// change an array by repeating it the number of times given by reps.
|
||||||
NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
||||||
|
|
|
@ -27,6 +27,52 @@
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <>
|
||||||
|
NDArray NDArrayFactory::create<bool>(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context) {
|
||||||
|
|
||||||
|
if ((int) shape.size() > MAX_RANK)
|
||||||
|
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
|
||||||
|
|
||||||
|
ShapeDescriptor descriptor(nd4j::DataType::BOOL, order, shape);
|
||||||
|
|
||||||
|
if (descriptor.arrLength() != data.size()) {
|
||||||
|
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
|
||||||
|
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool* hostBuffer = nullptr;
|
||||||
|
ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool);
|
||||||
|
std::copy(data.begin(), data.end(), hostBuffer);
|
||||||
|
|
||||||
|
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(hostBuffer, data.size() * sizeof(bool), nd4j::DataType::BOOL, true, context->getWorkspace());
|
||||||
|
|
||||||
|
NDArray result(buffer, descriptor, context);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context) {
|
||||||
|
|
||||||
|
if ((int) shape.size() > MAX_RANK)
|
||||||
|
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
|
||||||
|
|
||||||
|
ShapeDescriptor descriptor(DataTypeUtils::fromT<T>(), order, shape);
|
||||||
|
|
||||||
|
if (descriptor.arrLength() != data.size()) {
|
||||||
|
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
|
||||||
|
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(data.data(), DataTypeUtils::fromT<T>(), descriptor.arrLength() * sizeof(T), context->getWorkspace());
|
||||||
|
|
||||||
|
NDArray result(buffer, descriptor, context);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
|
NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
|
||||||
std::string s(str);
|
std::string s(str);
|
||||||
|
@ -227,10 +273,13 @@ template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd
|
||||||
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<float16> &data, nd4j::LaunchContext * context);
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<float16> &data, nd4j::LaunchContext * context);
|
||||||
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bfloat16> &data, nd4j::LaunchContext * context);
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bfloat16> &data, nd4j::LaunchContext * context);
|
||||||
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int> &data, nd4j::LaunchContext * context);
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int> &data, nd4j::LaunchContext * context);
|
||||||
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<unsigned int> &data, nd4j::LaunchContext * context);
|
||||||
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<unsigned long> &data, nd4j::LaunchContext * context);
|
||||||
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<Nd4jLong> &data, nd4j::LaunchContext * context);
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<Nd4jLong> &data, nd4j::LaunchContext * context);
|
||||||
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int8_t> &data, nd4j::LaunchContext * context);
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int8_t> &data, nd4j::LaunchContext * context);
|
||||||
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t> &data, nd4j::LaunchContext * context);
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t> &data, nd4j::LaunchContext * context);
|
||||||
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int16_t> &data, nd4j::LaunchContext * context);
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int16_t> &data, nd4j::LaunchContext * context);
|
||||||
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint16_t> &data, nd4j::LaunchContext * context);
|
||||||
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context);
|
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context);
|
||||||
|
|
||||||
|
|
||||||
|
@ -391,6 +440,7 @@ template NDArray NDArrayFactory::create(const std::vector<bfloat16> &values, nd4
|
||||||
template NDArray NDArrayFactory::create(const std::vector<Nd4jLong> &values, nd4j::LaunchContext * context);
|
template NDArray NDArrayFactory::create(const std::vector<Nd4jLong> &values, nd4j::LaunchContext * context);
|
||||||
template NDArray NDArrayFactory::create(const std::vector<int> &values, nd4j::LaunchContext * context);
|
template NDArray NDArrayFactory::create(const std::vector<int> &values, nd4j::LaunchContext * context);
|
||||||
template NDArray NDArrayFactory::create(const std::vector<int16_t> &values, nd4j::LaunchContext * context);
|
template NDArray NDArrayFactory::create(const std::vector<int16_t> &values, nd4j::LaunchContext * context);
|
||||||
|
template NDArray NDArrayFactory::create(const std::vector<uint16_t> &values, nd4j::LaunchContext * context);
|
||||||
template NDArray NDArrayFactory::create(const std::vector<int8_t> &values, nd4j::LaunchContext * context);
|
template NDArray NDArrayFactory::create(const std::vector<int8_t> &values, nd4j::LaunchContext * context);
|
||||||
template NDArray NDArrayFactory::create(const std::vector<uint8_t> &values, nd4j::LaunchContext * context);
|
template NDArray NDArrayFactory::create(const std::vector<uint8_t> &values, nd4j::LaunchContext * context);
|
||||||
template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::LaunchContext * context);
|
template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::LaunchContext * context);
|
||||||
|
@ -452,53 +502,6 @@ template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::L
|
||||||
return new NDArray(order, shape, dataType, context);
|
return new NDArray(order, shape, dataType, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context) {
|
|
||||||
|
|
||||||
if ((int) shape.size() > MAX_RANK)
|
|
||||||
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
|
|
||||||
|
|
||||||
ShapeDescriptor descriptor(DataTypeUtils::fromT<T>(), order, shape);
|
|
||||||
|
|
||||||
if (descriptor.arrLength() != data.size()) {
|
|
||||||
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
|
|
||||||
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(data.data(), DataTypeUtils::fromT<T>(), descriptor.arrLength() * sizeof(T), context->getWorkspace());
|
|
||||||
|
|
||||||
NDArray result(buffer, descriptor, context);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
|
|
||||||
}
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
|
||||||
template <>
|
|
||||||
NDArray NDArrayFactory::create<bool>(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context) {
|
|
||||||
|
|
||||||
if ((int) shape.size() > MAX_RANK)
|
|
||||||
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
|
|
||||||
|
|
||||||
ShapeDescriptor descriptor(nd4j::DataType::BOOL, order, shape);
|
|
||||||
|
|
||||||
if (descriptor.arrLength() != data.size()) {
|
|
||||||
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
|
|
||||||
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
|
|
||||||
}
|
|
||||||
|
|
||||||
bool* hostBuffer = nullptr;
|
|
||||||
ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool);
|
|
||||||
std::copy(data.begin(), data.end(), hostBuffer);
|
|
||||||
|
|
||||||
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(hostBuffer, data.size() * sizeof(bool), nd4j::DataType::BOOL, true, context->getWorkspace());
|
|
||||||
|
|
||||||
NDArray result(buffer, descriptor, context);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context) {
|
NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context) {
|
||||||
|
|
|
@ -2736,6 +2736,60 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
|
||||||
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void NativeOps::sortByKey(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
bool descending) {
|
||||||
|
auto xType = ArrayOptions::dataType(xShapeInfo);
|
||||||
|
auto yType = ArrayOptions::dataType(yShapeInfo);
|
||||||
|
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
void NativeOps::sortByValue(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
bool descending) {
|
||||||
|
|
||||||
|
auto xType = ArrayOptions::dataType(xShapeInfo);
|
||||||
|
auto yType = ArrayOptions::dataType(yShapeInfo);
|
||||||
|
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
void NativeOps::sortTadByKey(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
int *dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
bool descending) {
|
||||||
|
auto xType = ArrayOptions::dataType(xShapeInfo);
|
||||||
|
auto yType = ArrayOptions::dataType(yShapeInfo);
|
||||||
|
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
void NativeOps::sortTadByValue(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dx, Nd4jLong *dxShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
int *dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
bool descending) {
|
||||||
|
auto xType = ArrayOptions::dataType(xShapeInfo);
|
||||||
|
auto yType = ArrayOptions::dataType(yShapeInfo);
|
||||||
|
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
||||||
|
|
|
@ -192,8 +192,8 @@ void NDArray::setIdentity() {
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!");
|
||||||
|
|
||||||
if (rankOf() != 2)
|
// if (rankOf() != 2)
|
||||||
throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given.");
|
// throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given.");
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
const int blocksPerGrid = (lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
const int blocksPerGrid = (lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
@ -234,22 +234,27 @@ void NDArray::synchronize(const char* msg) const {
|
||||||
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
|
|
||||||
for (const auto& a : readList)
|
for (const auto& a : readList)
|
||||||
|
if(a != nullptr)
|
||||||
a->syncToDevice();
|
a->syncToDevice();
|
||||||
|
|
||||||
for (const auto& a : writeList) {
|
for (const auto& a : writeList) {
|
||||||
|
if (a != nullptr) {
|
||||||
a->getDataBuffer()->allocateSpecial();
|
a->getDataBuffer()->allocateSpecial();
|
||||||
if (synchronizeWritables)
|
if (synchronizeWritables)
|
||||||
a->syncToDevice();
|
a->syncToDevice();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||||
|
|
||||||
for (const auto& p : readList)
|
for (const auto& p : readList)
|
||||||
|
if(p != nullptr)
|
||||||
p->tickReadDevice();
|
p->tickReadDevice();
|
||||||
|
|
||||||
for (const auto& p : writeList)
|
for (const auto& p : writeList)
|
||||||
|
if (p != nullptr)
|
||||||
p->tickWriteDevice();
|
p->tickWriteDevice();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,22 +262,27 @@ void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& wr
|
||||||
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
|
|
||||||
for (const auto& a : readList)
|
for (const auto& a : readList)
|
||||||
|
if(a != nullptr)
|
||||||
a->syncToHost();
|
a->syncToHost();
|
||||||
|
|
||||||
for (const auto& a : writeList) {
|
for (const auto& a : writeList) {
|
||||||
|
if (a != nullptr) {
|
||||||
a->getDataBuffer()->allocatePrimary();
|
a->getDataBuffer()->allocatePrimary();
|
||||||
if (synchronizeWritables)
|
if (synchronizeWritables)
|
||||||
a->syncToHost();
|
a->syncToHost();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||||
|
|
||||||
for (const auto& p : readList)
|
for (const auto& p : readList)
|
||||||
|
if(p != nullptr)
|
||||||
p->tickReadHost();
|
p->tickReadHost();
|
||||||
|
|
||||||
for (const auto& p : writeList)
|
for (const auto& p : writeList)
|
||||||
|
if (p != nullptr)
|
||||||
p->tickWriteHost();
|
p->tickWriteHost();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -427,9 +437,26 @@ void NDArray::repeat(int dimension, NDArray& target) const {
|
||||||
NDArray::registerSpecialUse({&target}, {this});
|
NDArray::registerSpecialUse({&target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void* NDArray::specialBuffer() {
|
||||||
|
|
||||||
|
if (_buffer->special() == nullptr)
|
||||||
|
return getBuffer();
|
||||||
|
// FIXME: this should be fixed once CUDA backend added
|
||||||
|
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void* NDArray::getSpecialBuffer() const {
|
||||||
|
if (_buffer->special() == nullptr)
|
||||||
|
return getBuffer();
|
||||||
|
// FIXME: this should be fixed once CUDA backend added
|
||||||
|
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const {\
|
void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const {
|
||||||
|
|
||||||
if(_length == 0)
|
if(_length == 0)
|
||||||
{ printf("NDArray::printActualBuffer: array length is zero !\n"); return; }
|
{ printf("NDArray::printActualBuffer: array length is zero !\n"); return; }
|
||||||
|
@ -477,7 +504,7 @@ template void NDArray::printCurrentBuffer<double>(const bool host, const char* m
|
||||||
|
|
||||||
#if defined(__CUDACC__) && !defined(BUILD_TESTS)
|
#if defined(__CUDACC__) && !defined(BUILD_TESTS)
|
||||||
|
|
||||||
#include <cpu/NDArrayLambda.hpp>
|
//#include <cpu/NDArrayLambda.hpp>
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
|
@ -2321,6 +2321,163 @@ void NativeOps::sort(Nd4jPointer *extraPointers,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void NativeOps::sortByKey(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dX, Nd4jLong *dXShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
bool descending) {
|
||||||
|
|
||||||
|
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||||
|
|
||||||
|
auto xLength = shape::length(xShapeInfo);
|
||||||
|
auto xEWS = shape::elementWiseStride(xShapeInfo);
|
||||||
|
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
||||||
|
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
|
||||||
|
|
||||||
|
|
||||||
|
// check if xLength is a power of 2, and use bitonic sort, if that's the case
|
||||||
|
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
|
||||||
|
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
|
||||||
|
int numBlocks = xLength / numThreads;
|
||||||
|
if (xLength % numThreads > 0 || numBlocks == 0)
|
||||||
|
numBlocks++;
|
||||||
|
|
||||||
|
dim3 launchDims(numBlocks, numThreads, 32768);
|
||||||
|
|
||||||
|
for (int k = 2; k <= xLength; k = 2*k) {
|
||||||
|
for (int j = k >> 1; j > 0; j = j >> 1) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
|
||||||
|
int numBlocks = xLength / numThreads;
|
||||||
|
if (xLength % numThreads > 0 || numBlocks == 0)
|
||||||
|
numBlocks++;
|
||||||
|
|
||||||
|
numBlocks = nd4j::math::nd4j_min<int>(512, numBlocks);
|
||||||
|
dim3 launchDims(numBlocks, numThreads, 32768);
|
||||||
|
|
||||||
|
int max = 2, dg = 0;
|
||||||
|
while (max < xLength) {
|
||||||
|
max <<= 1;
|
||||||
|
dg++;
|
||||||
|
}
|
||||||
|
max <<= 1;
|
||||||
|
|
||||||
|
for (int window = 2; window < max; window<<=1) {
|
||||||
|
int n = window;
|
||||||
|
int rev = 0;
|
||||||
|
do{
|
||||||
|
int half = n >> 1;
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
n>>=1;
|
||||||
|
rev = 1;
|
||||||
|
} while(n > 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void NativeOps::sortByValue(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dX, Nd4jLong *dXShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
bool descending) {
|
||||||
|
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||||
|
|
||||||
|
auto xLength = shape::length(xShapeInfo);
|
||||||
|
auto xEWS = shape::elementWiseStride(xShapeInfo);
|
||||||
|
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
||||||
|
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
|
||||||
|
|
||||||
|
|
||||||
|
// check if xLength is a power of 2, and use bitonic sort, if that's the case
|
||||||
|
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
|
||||||
|
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
|
||||||
|
int numBlocks = xLength / numThreads;
|
||||||
|
if (xLength % numThreads > 0 || numBlocks == 0)
|
||||||
|
numBlocks++;
|
||||||
|
|
||||||
|
dim3 launchDims(numBlocks, numThreads, 32768);
|
||||||
|
|
||||||
|
for (int k = 2; k <= xLength; k = 2*k) {
|
||||||
|
for (int j = k >> 1; j > 0; j = j >> 1) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
|
||||||
|
int numBlocks = xLength / numThreads;
|
||||||
|
if (xLength % numThreads > 0 || numBlocks == 0)
|
||||||
|
numBlocks++;
|
||||||
|
|
||||||
|
numBlocks = nd4j::math::nd4j_min<int>(512, numBlocks);
|
||||||
|
dim3 launchDims(numBlocks, numThreads, 32768);
|
||||||
|
|
||||||
|
int max = 2, dg = 0;
|
||||||
|
while (max < xLength) {
|
||||||
|
max <<= 1;
|
||||||
|
dg++;
|
||||||
|
}
|
||||||
|
max <<= 1;
|
||||||
|
|
||||||
|
for (int window = 2; window < max; window<<=1) {
|
||||||
|
int n = window;
|
||||||
|
int rev = 0;
|
||||||
|
do{
|
||||||
|
int half = n >> 1;
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
n>>=1;
|
||||||
|
rev = 1;
|
||||||
|
} while(n > 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
void NativeOps::sortTadByKey(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dX, Nd4jLong *dXShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
int *dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
bool descending) {
|
||||||
|
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||||
|
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
|
||||||
|
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||||
|
dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
|
||||||
|
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
||||||
|
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
|
||||||
|
nd4j::DebugHelper::checkErrorCode(stream, "sortTadKey(...) failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
void NativeOps::sortTadByValue(Nd4jPointer *extraPointers,
|
||||||
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
|
void *dX, Nd4jLong *dXShapeInfo,
|
||||||
|
void *y, Nd4jLong *yShapeInfo,
|
||||||
|
void *dy, Nd4jLong *dyShapeInfo,
|
||||||
|
int *dimension,
|
||||||
|
int dimensionLength,
|
||||||
|
bool descending) {
|
||||||
|
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||||
|
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
|
||||||
|
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||||
|
dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
|
||||||
|
auto xType = nd4j::ArrayOptions::dataType(yShapeInfo);
|
||||||
|
auto yType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
||||||
|
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
|
||||||
|
nd4j::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void NativeOps::sortTad(Nd4jPointer *extraPointers,
|
void NativeOps::sortTad(Nd4jPointer *extraPointers,
|
||||||
void *x, Nd4jLong *xShapeInfo,
|
void *x, Nd4jLong *xShapeInfo,
|
||||||
void *dX, Nd4jLong *dXShapeInfo,
|
void *dX, Nd4jLong *dXShapeInfo,
|
||||||
|
@ -2331,15 +2488,13 @@ void NativeOps::sortTad(Nd4jPointer *extraPointers,
|
||||||
bool descending) {
|
bool descending) {
|
||||||
// to be implemented
|
// to be implemented
|
||||||
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||||
|
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
|
||||||
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||||
|
dim3 launchDims((int) tadPack.numberOfTads(), 512, 33768);
|
||||||
dim3 launchDims(tadPack.numberOfTads(), 1024, 33768);
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
||||||
BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, (launchDims, stream, dX, dXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
|
||||||
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "sortTadFloat(...) failed");
|
nd4j::DebugHelper::checkErrorCode(stream, "sortTad(...) failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
void NativeOps::sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
|
void NativeOps::sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
|
||||||
|
|
|
@ -38,11 +38,11 @@ namespace nd4j {
|
||||||
ConstantDataBuffer() = default;
|
ConstantDataBuffer() = default;
|
||||||
~ConstantDataBuffer() = default;
|
~ConstantDataBuffer() = default;
|
||||||
|
|
||||||
Nd4jLong sizeOf();
|
Nd4jLong sizeOf() const;
|
||||||
Nd4jLong length();
|
Nd4jLong length() const;
|
||||||
|
|
||||||
Nd4jPointer primary();
|
Nd4jPointer primary() const;
|
||||||
Nd4jPointer special();
|
Nd4jPointer special() const;
|
||||||
|
|
||||||
ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default;
|
ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default;
|
||||||
ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default;
|
ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default;
|
||||||
|
|
|
@ -261,6 +261,8 @@ DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
|
||||||
|
|
||||||
allocateBuffers();
|
allocateBuffers();
|
||||||
copyBufferFrom(other);
|
copyBufferFrom(other);
|
||||||
|
|
||||||
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -285,6 +287,8 @@ DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
|
||||||
other._primaryBuffer = other._specialBuffer = nullptr;
|
other._primaryBuffer = other._specialBuffer = nullptr;
|
||||||
other.setAllocFlags(false, false);
|
other.setAllocFlags(false, false);
|
||||||
other._lenInBytes = 0;
|
other._lenInBytes = 0;
|
||||||
|
|
||||||
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -335,6 +335,8 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
||||||
return std::string("INT8");
|
return std::string("INT8");
|
||||||
case INT16:
|
case INT16:
|
||||||
return std::string("INT16");
|
return std::string("INT16");
|
||||||
|
case UINT16:
|
||||||
|
return std::string("UINT16");
|
||||||
case INT32:
|
case INT32:
|
||||||
return std::string("INT32");
|
return std::string("INT32");
|
||||||
case INT64:
|
case INT64:
|
||||||
|
@ -375,7 +377,7 @@ FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo,
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
// returns the difference between 1.0 and the next representable value of the given floating-point type
|
// returns the difference between 1.0 and the next representable value of the given floating-point type
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FORCEINLINE T DataTypeUtils::eps() {
|
FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
|
||||||
if (std::is_same<T, double>::value)
|
if (std::is_same<T, double>::value)
|
||||||
return std::numeric_limits<double>::epsilon();
|
return std::numeric_limits<double>::epsilon();
|
||||||
else if (std::is_same<T, float>::value)
|
else if (std::is_same<T, float>::value)
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <array/DataType.h>
|
#include <array/DataType.h>
|
||||||
#include <pointercast.h>
|
#include <pointercast.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
class ND4J_EXPORT ExtraArguments {
|
class ND4J_EXPORT ExtraArguments {
|
||||||
|
|
|
@ -35,21 +35,21 @@ namespace nd4j {
|
||||||
TadPack() = default;
|
TadPack() = default;
|
||||||
~TadPack() = default;
|
~TadPack() = default;
|
||||||
|
|
||||||
Nd4jLong* primaryShapeInfo();
|
Nd4jLong* primaryShapeInfo() const;
|
||||||
Nd4jLong* primaryOffsets();
|
Nd4jLong* primaryOffsets() const;
|
||||||
|
|
||||||
Nd4jLong* specialShapeInfo();
|
Nd4jLong* specialShapeInfo() const;
|
||||||
Nd4jLong* specialOffsets();
|
Nd4jLong* specialOffsets() const;
|
||||||
|
|
||||||
Nd4jLong numberOfTads();
|
Nd4jLong numberOfTads() const;
|
||||||
int shapeInfoLength();
|
int shapeInfoLength() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* These methods return either primary or special pointers depending on platform binaries were compiled for
|
* These methods return either primary or special pointers depending on platform binaries were compiled for
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
Nd4jLong *platformShapeInfo();
|
Nd4jLong *platformShapeInfo() const;
|
||||||
Nd4jLong *platformOffsets();
|
Nd4jLong *platformOffsets() const;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,19 +28,19 @@ namespace nd4j {
|
||||||
_sizeOf = sizeOf;
|
_sizeOf = sizeOf;
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jPointer ConstantDataBuffer::primary() {
|
Nd4jPointer ConstantDataBuffer::primary() const {
|
||||||
return _primaryBuffer;
|
return _primaryBuffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jPointer ConstantDataBuffer::special() {
|
Nd4jPointer ConstantDataBuffer::special() const {
|
||||||
return _specialBuffer;
|
return _specialBuffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong ConstantDataBuffer::sizeOf() {
|
Nd4jLong ConstantDataBuffer::sizeOf() const {
|
||||||
return _sizeOf;
|
return _sizeOf;
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong ConstantDataBuffer::length() {
|
Nd4jLong ConstantDataBuffer::length() const {
|
||||||
return _length;
|
return _length;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,7 @@ namespace nd4j {
|
||||||
NDArray* NDArrayList::readRaw(int idx) {
|
NDArray* NDArrayList::readRaw(int idx) {
|
||||||
if (_chunks.count(idx) < 1) {
|
if (_chunks.count(idx) < 1) {
|
||||||
nd4j_printf("Non-existent chunk requested: [%i]\n", idx);
|
nd4j_printf("Non-existent chunk requested: [%i]\n", idx);
|
||||||
throw std::runtime_error("Bad index");
|
throw std::invalid_argument("Bad index");
|
||||||
}
|
}
|
||||||
|
|
||||||
return _chunks[idx];
|
return _chunks[idx];
|
||||||
|
@ -120,7 +120,7 @@ namespace nd4j {
|
||||||
// storing reference
|
// storing reference
|
||||||
_chunks[idx] = array;
|
_chunks[idx] = array;
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Nd4jLong>& NDArrayList::shape() {
|
std::vector<Nd4jLong>& NDArrayList::shape() {
|
||||||
|
@ -152,8 +152,10 @@ namespace nd4j {
|
||||||
std::vector<bool> bargs;
|
std::vector<bool> bargs;
|
||||||
int numElements = _elements.load();
|
int numElements = _elements.load();
|
||||||
|
|
||||||
for (int e = 0; e < numElements; e++)
|
for (int e = 0; e < numElements; e++) {
|
||||||
|
_chunks[e]->syncToDevice();
|
||||||
inputs.emplace_back(_chunks[e]);
|
inputs.emplace_back(_chunks[e]);
|
||||||
|
}
|
||||||
|
|
||||||
iargs.push_back(_axis);
|
iargs.push_back(_axis);
|
||||||
|
|
||||||
|
|
|
@ -29,34 +29,34 @@ namespace nd4j {
|
||||||
_numTads = numTads;
|
_numTads = numTads;
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* TadPack::primaryShapeInfo() {
|
Nd4jLong* TadPack::primaryShapeInfo() const {
|
||||||
return reinterpret_cast<Nd4jLong *>(_tadShape.primary());
|
return reinterpret_cast<Nd4jLong *>(_tadShape.primary());
|
||||||
}
|
}
|
||||||
Nd4jLong* TadPack::primaryOffsets() {
|
Nd4jLong* TadPack::primaryOffsets() const {
|
||||||
return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary());
|
return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary());
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* TadPack::specialShapeInfo() {
|
Nd4jLong* TadPack::specialShapeInfo() const {
|
||||||
return reinterpret_cast<Nd4jLong *>(_tadShape.special());
|
return reinterpret_cast<Nd4jLong *>(_tadShape.special());
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* TadPack::specialOffsets() {
|
Nd4jLong* TadPack::specialOffsets() const {
|
||||||
return reinterpret_cast<Nd4jLong *>(_tadOffsets.special());
|
return reinterpret_cast<Nd4jLong *>(_tadOffsets.special());
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong TadPack::numberOfTads() {
|
Nd4jLong TadPack::numberOfTads() const {
|
||||||
return _numTads;
|
return _numTads;
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* TadPack::platformShapeInfo() {
|
Nd4jLong* TadPack::platformShapeInfo() const {
|
||||||
return nd4j::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo();
|
return nd4j::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo();
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* TadPack::platformOffsets() {
|
Nd4jLong* TadPack::platformOffsets() const {
|
||||||
return nd4j::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets();
|
return nd4j::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets();
|
||||||
}
|
}
|
||||||
|
|
||||||
int TadPack::shapeInfoLength() {
|
int TadPack::shapeInfoLength() const {
|
||||||
return (int) shape::shapeInfoLength(primaryShapeInfo());
|
return (int) shape::shapeInfoLength(primaryShapeInfo());
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -27,7 +27,7 @@ namespace nd4j {
|
||||||
class AttentionHelper {
|
class AttentionHelper {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
static nd4j::NDArray* multiHeadProject(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
static nd4j::NDArray multiHeadProject(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
static void multiHeadProjectBp(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, const nd4j::NDArray* eps, nd4j::NDArray* dLdInput, nd4j::NDArray* dLdProjectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
static void multiHeadProjectBp(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, const nd4j::NDArray* eps, nd4j::NDArray* dLdInput, nd4j::NDArray* dLdProjectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,10 +69,10 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void executeOnce() override {
|
void executeOnce() override {
|
||||||
auto xT = (_tA ? _x->transpose() : _x);
|
auto xT = (_tA ? _x->transpose() : *_x);
|
||||||
auto yT = (_tB ? _y->transpose() : _y);
|
auto yT = (_tB ? _y->transpose() : *_y);
|
||||||
|
|
||||||
MmulHelper::mmul(xT, yT, _z, _alpha, _beta);
|
MmulHelper::mmul(&xT, &yT, _z, _alpha, _beta);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string axis() override {
|
std::string axis() override {
|
||||||
|
|
|
@ -133,9 +133,9 @@ void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff
|
||||||
// if(matrix.rankOf() != 2)
|
// if(matrix.rankOf() != 2)
|
||||||
// throw "ops::helpers::Householder::mulLeft method: input array must be 2D matrix !";
|
// throw "ops::helpers::Householder::mulLeft method: input array must be 2D matrix !";
|
||||||
|
|
||||||
if(matrix.sizeAt(0) == 1)
|
if(matrix.sizeAt(0) == 1) {
|
||||||
matrix *= (T)1.f - coeff;
|
matrix *= (T) 1.f - coeff;
|
||||||
|
}
|
||||||
else if(coeff != (T)0.f) {
|
else if(coeff != (T)0.f) {
|
||||||
|
|
||||||
auto bottomPart = new NDArray(matrix({1,matrix.sizeAt(0), 0,0}, true));
|
auto bottomPart = new NDArray(matrix({1,matrix.sizeAt(0), 0,0}, true));
|
||||||
|
@ -145,13 +145,11 @@ void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff
|
||||||
|
|
||||||
auto column = tail;
|
auto column = tail;
|
||||||
auto row = tail.transpose();
|
auto row = tail.transpose();
|
||||||
auto resultingRow = mmul(*row, bottomPartCopy);
|
auto resultingRow = mmul(row, bottomPartCopy);
|
||||||
auto fistRow = matrix({0,1, 0,0}, true);
|
auto fistRow = matrix({0,1, 0,0}, true);
|
||||||
resultingRow += fistRow;
|
resultingRow += fistRow;
|
||||||
fistRow -= resultingRow * coeff;
|
fistRow -= resultingRow * coeff;
|
||||||
*bottomPart -= mmul(column, resultingRow) * coeff;
|
*bottomPart -= mmul(column, resultingRow) * coeff;
|
||||||
|
|
||||||
delete row;
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -161,9 +159,7 @@ void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff
|
||||||
auto fistRow = matrix({0,1, 0,0}, true);
|
auto fistRow = matrix({0,1, 0,0}, true);
|
||||||
resultingRow += fistRow;
|
resultingRow += fistRow;
|
||||||
fistRow -= resultingRow * coeff;
|
fistRow -= resultingRow * coeff;
|
||||||
*bottomPart -= mmul(*column, resultingRow) * coeff;
|
*bottomPart -= mmul(column, resultingRow) * coeff;
|
||||||
|
|
||||||
delete column;
|
|
||||||
}
|
}
|
||||||
delete bottomPart;
|
delete bottomPart;
|
||||||
}
|
}
|
||||||
|
@ -193,21 +189,16 @@ void Householder<T>::mulRight(NDArray& matrix, const NDArray& tail, const T coef
|
||||||
auto resultingCol = mmul(rightPartCopy, column);
|
auto resultingCol = mmul(rightPartCopy, column);
|
||||||
resultingCol += *fistCol;
|
resultingCol += *fistCol;
|
||||||
*fistCol -= resultingCol * coeff;
|
*fistCol -= resultingCol * coeff;
|
||||||
*rightPart -= mmul(resultingCol, *row) * coeff;
|
*rightPart -= mmul(resultingCol, row) * coeff;
|
||||||
|
|
||||||
delete row;
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
auto row = tail;
|
auto row = tail;
|
||||||
auto column = tail.transpose();
|
auto column = tail.transpose();
|
||||||
auto resultingCol = mmul(rightPartCopy, *column);
|
auto resultingCol = mmul(rightPartCopy, column);
|
||||||
resultingCol += *fistCol;
|
resultingCol += *fistCol;
|
||||||
*fistCol -= resultingCol * coeff;
|
*fistCol -= resultingCol * coeff;
|
||||||
*rightPart -= mmul(resultingCol, row) * coeff;
|
*rightPart -= mmul(resultingCol, row) * coeff;
|
||||||
|
|
||||||
delete column;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
delete rightPart;
|
delete rightPart;
|
||||||
delete fistCol;
|
delete fistCol;
|
||||||
|
|
|
@ -157,8 +157,7 @@ bool JacobiSVD<T>::isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem) {
|
||||||
|
|
||||||
if(_calcU) {
|
if(_calcU) {
|
||||||
auto temp2 = rotation.transpose();
|
auto temp2 = rotation.transpose();
|
||||||
mulRotationOnRight(p, q, _u, *temp2);
|
mulRotationOnRight(p, q, _u, temp2);
|
||||||
delete temp2;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,9 +250,7 @@ void JacobiSVD<T>::svd2x2(const NDArray& block, int p, int q, NDArray& left, NDA
|
||||||
m.p<T>(1, 1, _z);
|
m.p<T>(1, 1, _z);
|
||||||
|
|
||||||
auto temp = right.transpose();
|
auto temp = right.transpose();
|
||||||
left.assign(mmul(rotation, *temp));
|
left.assign(mmul(rotation, temp));
|
||||||
delete temp;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -289,7 +286,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
||||||
else if(_rows < _cols) {
|
else if(_rows < _cols) {
|
||||||
|
|
||||||
auto matrixT = matrix.transpose();
|
auto matrixT = matrix.transpose();
|
||||||
HHcolPivQR qr(*matrixT / scale);
|
HHcolPivQR qr(matrixT / scale);
|
||||||
_m.assign(qr._qr({0,_rows, 0,_rows}));
|
_m.assign(qr._qr({0,_rows, 0,_rows}));
|
||||||
_m.fillAsTriangular<T>(0., 0, 0, 'l');
|
_m.fillAsTriangular<T>(0., 0, 0, 'l');
|
||||||
_m.transposei();
|
_m.transposei();
|
||||||
|
@ -305,8 +302,6 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
||||||
|
|
||||||
if(_calcU)
|
if(_calcU)
|
||||||
_u.assign(qr._permut);
|
_u.assign(qr._permut);
|
||||||
|
|
||||||
delete matrixT;
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
|
@ -352,8 +347,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
|
||||||
|
|
||||||
if(_calcU) {
|
if(_calcU) {
|
||||||
auto temp = rotLeft.transpose();
|
auto temp = rotLeft.transpose();
|
||||||
mulRotationOnRight(p, q, _u, *temp);
|
mulRotationOnRight(p, q, _u, temp);
|
||||||
delete temp;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mulRotationOnRight(p, q, _m, rotRight);
|
mulRotationOnRight(p, q, _m, rotRight);
|
||||||
|
|
|
@ -920,7 +920,7 @@ void SVD<T>::evalData(const NDArray& matrix) {
|
||||||
auto temp1 = biDiag._HHbidiag.transpose();
|
auto temp1 = biDiag._HHbidiag.transpose();
|
||||||
auto temp2 = _m({0,_diagSize, 0,0}, true);
|
auto temp2 = _m({0,_diagSize, 0,0}, true);
|
||||||
temp2.assign(temp1);
|
temp2.assign(temp1);
|
||||||
delete temp1;
|
|
||||||
|
|
||||||
auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true);
|
auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true);
|
||||||
temp3.assign(0.);
|
temp3.assign(0.);
|
||||||
|
|
|
@ -184,9 +184,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
|
|
||||||
if(pC->ordering() != 'f') {
|
if(pC->ordering() != 'f') {
|
||||||
auto temp = pA;
|
auto temp = pA;
|
||||||
pA = pB ->permute({1,0});
|
pA = new NDArray(pB ->permute({1,0}));
|
||||||
pB = temp->permute({1,0});
|
pB = new NDArray(temp->permute({1,0}));
|
||||||
pC = pC ->permute({1,0});
|
pC = new NDArray(pC ->permute({1,0}));
|
||||||
toDelete.push_back(pA);
|
toDelete.push_back(pA);
|
||||||
toDelete.push_back(pB);
|
toDelete.push_back(pB);
|
||||||
toDelete.push_back(pC);
|
toDelete.push_back(pC);
|
||||||
|
@ -251,7 +251,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
blocksPerGrid.y = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.y); // rows
|
blocksPerGrid.y = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.y); // rows
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
//BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||||
|
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);
|
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);
|
||||||
|
@ -339,7 +340,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
|
||||||
threadsPerBlock.x = 512;
|
threadsPerBlock.x = 512;
|
||||||
blocksPerGrid.x = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.x); // rows
|
blocksPerGrid.x = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.x); // rows
|
||||||
}
|
}
|
||||||
BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
//BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||||
|
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
|
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
|
||||||
|
@ -396,7 +398,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({Z}, {X, Y});
|
NDArray::prepareSpecialUse({Z}, {X, Y});
|
||||||
|
|
||||||
BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||||
|
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES)
|
||||||
|
|
||||||
auto cudaResult = cudaStreamSynchronize(*stream);
|
auto cudaResult = cudaStreamSynchronize(*stream);
|
||||||
if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
|
if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
|
||||||
|
@ -406,8 +409,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
|
||||||
return Z;
|
return Z;
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||||
BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||||
BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
}
|
}
|
|
@ -28,33 +28,27 @@
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
|
||||||
nd4j::NDArray *
|
nd4j::NDArray AttentionHelper::multiHeadProject(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix, nd4j::LaunchContext * context) {
|
||||||
AttentionHelper::multiHeadProject(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix, nd4j::LaunchContext * context) {
|
|
||||||
auto miniBatchSize = input->sizeAt(0);
|
auto miniBatchSize = input->sizeAt(0);
|
||||||
auto seqLength = input->sizeAt(2);
|
auto seqLength = input->sizeAt(2);
|
||||||
auto numHeads = projectionMatrix->sizeAt(0);
|
auto numHeads = projectionMatrix->sizeAt(0);
|
||||||
auto projectedSize = projectionMatrix->sizeAt(1);
|
auto projectedSize = projectionMatrix->sizeAt(1);
|
||||||
|
|
||||||
auto inputPerm = input->permute({1, 0, 2});
|
auto inputPerm = input->permute({1, 0, 2});
|
||||||
auto inputPrep = inputPerm->reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
|
auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
|
||||||
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
|
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
|
||||||
|
|
||||||
NDArray* projected = new NDArray('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);
|
NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);
|
||||||
nd4j::ops::matmul mmul;
|
nd4j::ops::matmul mmul;
|
||||||
mmul.execute({projectionPrep, inputPrep}, {projected}, {}, {}, {});
|
mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
|
||||||
|
|
||||||
projected->reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
|
projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
|
||||||
projected->permutei({2, 0, 1, 3});
|
projected.permutei({2, 0, 1, 3});
|
||||||
|
|
||||||
delete inputPerm;
|
|
||||||
delete inputPrep;
|
|
||||||
delete projectionPrep;
|
|
||||||
|
|
||||||
return projected;
|
return projected;
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void AttentionHelper::multiHeadProjectBp(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix,
|
||||||
AttentionHelper::multiHeadProjectBp(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix,
|
|
||||||
const nd4j::NDArray *eps, nd4j::NDArray *dLdInput,
|
const nd4j::NDArray *eps, nd4j::NDArray *dLdInput,
|
||||||
nd4j::NDArray *dLdProjectionMatrix, nd4j::LaunchContext * context) {
|
nd4j::NDArray *dLdProjectionMatrix, nd4j::LaunchContext * context) {
|
||||||
auto miniBatchSize = input->sizeAt(0);
|
auto miniBatchSize = input->sizeAt(0);
|
||||||
|
@ -63,16 +57,16 @@ namespace nd4j {
|
||||||
auto projectedSize = projectionMatrix->sizeAt(1);
|
auto projectedSize = projectionMatrix->sizeAt(1);
|
||||||
|
|
||||||
auto epsPerm = eps->permute({1, 2, 0, 3});
|
auto epsPerm = eps->permute({1, 2, 0, 3});
|
||||||
auto epsReshaped = epsPerm->reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength});
|
auto epsReshaped = epsPerm.reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength});
|
||||||
|
|
||||||
auto inputPerm = input->permute({1, 0, 2});
|
auto inputPerm = input->permute({1, 0, 2});
|
||||||
auto inputPrep = inputPerm->reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
|
auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
|
||||||
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
|
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
|
||||||
|
|
||||||
nd4j::ops::matmul_bp mmulBp;
|
nd4j::ops::matmul_bp mmulBp;
|
||||||
NDArray dLdProjectionPrep(projectionPrep->shapeInfo(), false, context);
|
NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context);
|
||||||
NDArray dLdInputPrep(inputPrep->shapeInfo(), false, context);
|
NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context);
|
||||||
mmulBp.execute({projectionPrep, inputPrep, epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});
|
mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});
|
||||||
|
|
||||||
dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
|
dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
|
||||||
dLdProjectionMatrix->assign(dLdProjectionPrep);
|
dLdProjectionMatrix->assign(dLdProjectionPrep);
|
||||||
|
@ -80,12 +74,6 @@ namespace nd4j {
|
||||||
dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength});
|
dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength});
|
||||||
dLdInputPrep.permutei({1, 0, 2});
|
dLdInputPrep.permutei({1, 0, 2});
|
||||||
dLdInput->assign(dLdInputPrep);
|
dLdInput->assign(dLdInputPrep);
|
||||||
|
|
||||||
delete inputPerm;
|
|
||||||
delete inputPrep;
|
|
||||||
delete epsPerm;
|
|
||||||
delete epsReshaped;
|
|
||||||
delete projectionPrep;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -53,7 +53,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
|
||||||
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
|
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
|
||||||
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) {
|
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) {
|
||||||
|
|
||||||
const int numInArrsFF = argsHolderFF.getNumInArrs(); // also numInArrsFF = number of output arrays in opBP
|
const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP
|
||||||
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
|
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
|
||||||
const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
|
const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
|
||||||
const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();
|
const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();
|
||||||
|
@ -65,6 +65,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
|
ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
|
||||||
|
|
||||||
NDArray tmpScalar(nd4j::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
|
NDArray tmpScalar(nd4j::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
|
||||||
|
|
||||||
for(int i = 0; i < numInArrsFF; ++i) { // loop through input array
|
for(int i = 0; i < numInArrsFF; ++i) { // loop through input array
|
||||||
|
|
||||||
if(!whatArrsToCheck.empty() && static_cast<bool>(whatArrsToCheck[i]) == false)
|
if(!whatArrsToCheck.empty() && static_cast<bool>(whatArrsToCheck[i]) == false)
|
||||||
|
@ -75,39 +76,39 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
|
|
||||||
for(Nd4jLong j = idxStart; j < idxEnd; ++j) { // loop through all elements for current array
|
for(Nd4jLong j = idxStart; j < idxEnd; ++j) { // loop through all elements for current array
|
||||||
|
|
||||||
double& elem = inArrsFF[i]->t<double>(j);
|
const double orig = inArrsFF[i]->e<double>(j);
|
||||||
const double orig = elem;
|
|
||||||
|
|
||||||
// add epsilon, feed forward
|
// add epsilon, feed forward
|
||||||
elem = orig + EPSILON;
|
inArrsFF[i]->p<double>(j, orig + EPSILON);
|
||||||
ResultSet* outArrsFF = opFF.execute(argsHolderFF);
|
ResultSet* outArrsFF = opFF.execute(argsHolderFF);
|
||||||
int numOutArrs = outArrsFF->size();
|
int numOutArrs = outArrsFF->size();
|
||||||
double scorePlus = 0.;
|
double scorePlus = 0.;
|
||||||
for(int k = 0; k < numOutArrs; ++k) { // loop through output array
|
|
||||||
|
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
||||||
if(loss == SUM)
|
if(loss == SUM)
|
||||||
NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), reduce::Sum, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
|
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||||
else
|
else
|
||||||
NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), reduce::Mean, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
|
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||||
scorePlus += tmpScalar.e<double>(0);
|
scorePlus += tmpScalar.e<double>(0);
|
||||||
}
|
}
|
||||||
delete outArrsFF;
|
delete outArrsFF;
|
||||||
|
|
||||||
// subtract epsilon, feed forward
|
// subtract epsilon, feed forward
|
||||||
elem = orig - EPSILON;
|
inArrsFF[i]->p<double>(j, orig - EPSILON);
|
||||||
outArrsFF = opFF.execute(argsHolderFF);
|
outArrsFF = opFF.execute(argsHolderFF);
|
||||||
double scoreMinus = 0.;
|
double scoreMinus = 0.;
|
||||||
|
|
||||||
for(int k = 0; k < numOutArrs; ++k) { // loop through output array
|
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
|
||||||
if(loss == SUM)
|
if(loss == SUM)
|
||||||
NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), reduce::Sum, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
|
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
|
||||||
else
|
else
|
||||||
NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), reduce::Mean, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
|
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
|
||||||
scoreMinus += tmpScalar.e<double>(0);
|
scoreMinus += tmpScalar.e<double>(0);
|
||||||
}
|
}
|
||||||
delete outArrsFF;
|
delete outArrsFF;
|
||||||
|
|
||||||
// restore initial element value
|
// restore initial element value
|
||||||
elem = orig;
|
inArrsFF[i]->p<double>(j, orig);
|
||||||
|
|
||||||
// calculate numerical gradient
|
// calculate numerical gradient
|
||||||
const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON);
|
const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON);
|
||||||
|
|
|
@ -43,22 +43,19 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N
|
||||||
|
|
||||||
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
||||||
|
|
||||||
NDArray* aPR = a->permute(permutAt);
|
NDArray aPR = a->permute(permutAt);
|
||||||
NDArray* bPR = b->permute(permutBt);
|
NDArray bPR = b->permute(permutBt);
|
||||||
|
|
||||||
// check whether reshape is necessary
|
// check whether reshape is necessary
|
||||||
if(!aPR->isSameShape(shapeAt))
|
if(!aPR.isSameShape(shapeAt))
|
||||||
aPR->reshapei( shapeAt);
|
aPR.reshapei( shapeAt);
|
||||||
if(!bPR->isSameShape(shapeBt))
|
if(!bPR.isSameShape(shapeBt))
|
||||||
bPR->reshapei( shapeBt);
|
bPR.reshapei( shapeBt);
|
||||||
|
|
||||||
NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
|
NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
|
||||||
|
|
||||||
c->reshapei(outShape);
|
c->reshapei(outShape);
|
||||||
|
|
||||||
delete aPR;
|
|
||||||
delete bPR;
|
|
||||||
|
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,21 +71,21 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
|
||||||
|
|
||||||
// check whether permutation is required
|
// check whether permutation is required
|
||||||
if(!permutForC.empty())
|
if(!permutForC.empty())
|
||||||
cP = c->permute(permutForC);
|
cP = new NDArray(c->permute(permutForC));
|
||||||
|
|
||||||
auto aPR = a->permute(permutAt);
|
auto aPR = a->permute(permutAt);
|
||||||
auto bPR = b->permute(permutBt);
|
auto bPR = b->permute(permutBt);
|
||||||
|
|
||||||
// check whether reshape is necessary
|
// check whether reshape is necessary
|
||||||
if(!aPR->isSameShape(shapeAt))
|
if(!aPR.isSameShape(shapeAt))
|
||||||
aPR->reshapei(shapeAt);
|
aPR.reshapei(shapeAt);
|
||||||
if(!bPR->isSameShape(shapeBt))
|
if(!bPR.isSameShape(shapeBt))
|
||||||
bPR->reshapei(shapeBt);
|
bPR.reshapei(shapeBt);
|
||||||
|
|
||||||
if(!cP->isSameShape({aPR->sizeAt(0), bPR->sizeAt(1)}))
|
if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
|
||||||
cPR = cP->reshape(cP->ordering(), {aPR->sizeAt(0), bPR->sizeAt(1)});
|
cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
|
||||||
|
|
||||||
mmul(aPR, bPR, cPR, 1.0, 0.0);
|
mmul(&aPR, &bPR, cPR, 1.0, 0.0);
|
||||||
|
|
||||||
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
|
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
|
||||||
cP->assign(cPR);
|
cP->assign(cPR);
|
||||||
|
@ -97,40 +94,42 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
|
||||||
delete cPR;
|
delete cPR;
|
||||||
if(cP != c)
|
if(cP != c)
|
||||||
delete cP;
|
delete cP;
|
||||||
delete aPR;
|
|
||||||
delete bPR;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#ifndef __JAVACPP_HACK__
|
#ifndef __JAVACPP_HACK__
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB, const std::vector<std::vector<Nd4jLong>>& modifC) {
|
void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB, const std::vector<std::vector<Nd4jLong>>& modifC) {
|
||||||
|
|
||||||
NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
|
NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
|
||||||
std::string whatToDoWithA, whatToDoWithB, whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" - permutation+reshaping; "rp" - reshaping/permutation, and so on; if another string is produced - throw exception
|
std::string whatToDoWithA, whatToDoWithB, whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" - permutation+reshaping; "rp" - reshaping/permutation, and so on; if another string is produced - throw exception
|
||||||
|
|
||||||
for(const auto& arr : modifA)
|
for(const auto& arr : modifA)
|
||||||
whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
|
whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
|
||||||
for(const auto& arr : modifB)
|
for(const auto& arr : modifB)
|
||||||
whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";
|
whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";
|
||||||
for(const auto& arr : modifC)
|
for(const auto& arr : modifC)
|
||||||
whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithC + "p" : whatToDoWithC + "r";
|
whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithC + "p" : whatToDoWithC + "r";
|
||||||
|
|
||||||
// first step for a array
|
// first step for a array
|
||||||
if(!whatToDoWithA.empty())
|
if(!whatToDoWithA.empty())
|
||||||
aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]);
|
aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0]));
|
||||||
// first step for b array
|
// first step for b array
|
||||||
if(!whatToDoWithB.empty())
|
if(!whatToDoWithB.empty())
|
||||||
bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]);
|
bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0]));
|
||||||
// rest steps for a array
|
// rest steps for a array
|
||||||
for(int i = 1; i < whatToDoWithA.size(); ++i)
|
for(int i = 1; i < whatToDoWithA.size(); ++i)
|
||||||
if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
|
if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
|
||||||
// rest steps for b array
|
// rest steps for b array
|
||||||
for(int i = 1; i < whatToDoWithB.size(); ++i)
|
for(int i = 1; i < whatToDoWithB.size(); ++i)
|
||||||
if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]);
|
if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]);
|
||||||
|
|
||||||
// now work with c array
|
// now work with c array
|
||||||
std::vector<NDArray*> cArrs = {c};
|
std::vector<NDArray*> cArrs = {c};
|
||||||
if(!whatToDoWithC.empty()) {
|
if(!whatToDoWithC.empty()) {
|
||||||
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
|
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
|
||||||
for(int i = 0; i < cArrs.size()-1; ++i)
|
for(int i = 0; i < cArrs.size()-1; ++i)
|
||||||
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? cArrs[i]->permute(modifC[i]) : cArrs[i]->reshape(c->ordering(), modifC[i]); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
|
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
|
||||||
}
|
}
|
||||||
|
|
||||||
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
|
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
|
||||||
|
@ -152,18 +151,21 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB) {
|
NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB) {
|
||||||
|
|
||||||
NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
|
NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
|
||||||
std::string whatToDoWithA, whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping only; "pr" - permutation+reshaping; "rp" - reshaping/permutation; another string - throw exception
|
std::string whatToDoWithA, whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping only; "pr" - permutation+reshaping; "rp" - reshaping/permutation; another string - throw exception
|
||||||
|
|
||||||
for(const auto& arr : modifA)
|
for(const auto& arr : modifA)
|
||||||
whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
|
whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
|
||||||
for(const auto& arr : modifB)
|
for(const auto& arr : modifB)
|
||||||
whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";
|
whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";
|
||||||
|
|
||||||
// first step for a array
|
// first step for a array
|
||||||
if(!whatToDoWithA.empty())
|
if(!whatToDoWithA.empty())
|
||||||
aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]);
|
aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0]));
|
||||||
// first step for b array
|
// first step for b array
|
||||||
if(!whatToDoWithB.empty())
|
if(!whatToDoWithB.empty())
|
||||||
bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]);
|
bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0]));
|
||||||
// rest steps for a array
|
// rest steps for a array
|
||||||
for(int i = 1; i < whatToDoWithA.size(); ++i)
|
for(int i = 1; i < whatToDoWithA.size(); ++i)
|
||||||
if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
|
if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
|
||||||
|
@ -293,17 +295,17 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
|
||||||
permut[rank-1] = rank - 2;
|
permut[rank-1] = rank - 2;
|
||||||
|
|
||||||
if(transX)
|
if(transX)
|
||||||
xT = x->permute(permut);
|
xT = new NDArray(x->permute(permut));
|
||||||
|
|
||||||
if(transY)
|
if(transY)
|
||||||
yT = y->permute(permut);
|
yT = new NDArray(y->permute(permut));
|
||||||
}
|
}
|
||||||
|
|
||||||
if(xRank <= 2 && yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector (2Dx1D), matrix-matrix (2Dx2D) product cases
|
if(xRank <= 2 && yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector (2Dx1D), matrix-matrix (2Dx2D) product cases
|
||||||
|
|
||||||
if(xRank == 1 && yRank == 2) { // reduce vector-matrix to matrix-matrix case
|
if(xRank == 1 && yRank == 2) { // reduce vector-matrix to matrix-matrix case
|
||||||
xT = x->reshape(x->ordering(), {1, x->lengthOf()}); // please note x is not transposed in this case (since xRank=1)
|
xT = new NDArray(x->reshape(x->ordering(), {1, x->lengthOf()})); // please note x is not transposed in this case (since xRank=1)
|
||||||
zT = z->reshape(z->ordering(), {1, z->lengthOf()});
|
zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
|
||||||
}
|
}
|
||||||
|
|
||||||
mmul(xT, yT, zT, 1., 0.);
|
mmul(xT, yT, zT, 1., 0.);
|
||||||
|
|
|
@ -473,19 +473,9 @@ bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool
|
||||||
// FIXME: get rid of memcpy here
|
// FIXME: get rid of memcpy here
|
||||||
memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank));
|
memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank));
|
||||||
for (int i = 0; i < minRank; ++i)
|
for (int i = 0; i < minRank; ++i)
|
||||||
if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i])
|
if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0)
|
||||||
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
|
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
|
||||||
|
|
||||||
// nullify zero axis
|
|
||||||
for (int e = 0; e < maxRank; e++)
|
|
||||||
if (maxShapeInfo[e+1] == 0)
|
|
||||||
tmpShapeInfo[e+1] = 0;
|
|
||||||
|
|
||||||
int delta = maxRank - minRank;
|
|
||||||
for (int e = minRank - 1; e >= 0; e--)
|
|
||||||
if (minShapeInfo[e + 1] == 0)
|
|
||||||
tmpShapeInfo[e + 1 + delta] = 0;
|
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
|
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
|
||||||
|
|
||||||
if (shape::isEmpty(max) || shape::isEmpty(min)) {
|
if (shape::isEmpty(max) || shape::isEmpty(min)) {
|
||||||
|
|
|
@ -40,7 +40,7 @@ namespace nd4j {
|
||||||
#ifdef __CUDACC__
|
#ifdef __CUDACC__
|
||||||
__host__
|
__host__
|
||||||
#endif
|
#endif
|
||||||
void Logger::printv(const char *format, std::vector<int>& vec) {
|
void Logger::printv(const char *format, const std::vector<int>& vec) {
|
||||||
printf("%s: {", format);
|
printf("%s: {", format);
|
||||||
for(int e = 0; e < vec.size(); e++) {
|
for(int e = 0; e < vec.size(); e++) {
|
||||||
auto v = vec[e];
|
auto v = vec[e];
|
||||||
|
@ -55,7 +55,7 @@ namespace nd4j {
|
||||||
#ifdef __CUDACC__
|
#ifdef __CUDACC__
|
||||||
__host__
|
__host__
|
||||||
#endif
|
#endif
|
||||||
void Logger::printv(const char *format, std::vector<Nd4jLong>& vec) {
|
void Logger::printv(const char *format, const std::vector<Nd4jLong>& vec) {
|
||||||
printf("%s: {", format);
|
printf("%s: {", format);
|
||||||
for(int e = 0; e < vec.size(); e++) {
|
for(int e = 0; e < vec.size(); e++) {
|
||||||
auto v = vec[e];
|
auto v = vec[e];
|
||||||
|
|
|
@ -55,8 +55,8 @@ namespace nd4j {
|
||||||
|
|
||||||
static void _CUDA_H info(const char *format, ...);
|
static void _CUDA_H info(const char *format, ...);
|
||||||
|
|
||||||
static void _CUDA_H printv(const char *format, std::vector<int>& vec);
|
static void _CUDA_H printv(const char *format, const std::vector<int>& vec);
|
||||||
static void _CUDA_H printv(const char *format, std::vector<Nd4jLong>& vec);
|
static void _CUDA_H printv(const char *format, const std::vector<Nd4jLong>& vec);
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1023,23 +1023,6 @@ namespace shape {
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
|
ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
|
||||||
|
|
||||||
/**
|
|
||||||
* insert dimension at shape[axis] position
|
|
||||||
* 1) for example: for given rank = 3, shape = {2,4,5}, axis = 1, dimension = 10 result is -> shape = {2,10,4,5}
|
|
||||||
* 2) for example: for given rank = 3, shape = {2,4,5}, axis = 3, dimension = 10 result is -> shape = {2,4,5,10}
|
|
||||||
* so be careful and provide shape buffer with enough (at least rank+1) length
|
|
||||||
* axis should be within [0, rank] range
|
|
||||||
*/
|
|
||||||
ND4J_EXPORT _CUDA_HD void insertDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis, const Nd4jLong dimension);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* erase dimension at shape[axis] position
|
|
||||||
* 1) for example: for given rank = 3, shape = {2,4,5}, axis = 1, result is -> shape = {2,5}
|
|
||||||
* 2) for example: for given rank = 3, shape = {2,4,5}, axis = 2, result is -> shape = {2,4}
|
|
||||||
* axis should be within [0, rank-1] range
|
|
||||||
*/
|
|
||||||
ND4J_EXPORT _CUDA_HD void eraseDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -4932,21 +4915,6 @@ INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
INLINEDEF _CUDA_HD void insertDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis, const Nd4jLong dimension) {
|
|
||||||
|
|
||||||
for (int i = rank; i > axis; --i)
|
|
||||||
shape[i] = shape[i - 1];
|
|
||||||
|
|
||||||
shape[axis] = dimension;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
INLINEDEF _CUDA_HD void eraseDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis) {
|
|
||||||
|
|
||||||
for (int i = axis; i < rank - 1; ++i)
|
|
||||||
shape[i] = shape[i + 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -244,8 +244,9 @@ namespace functions {
|
||||||
auto xi = x + threadOffset;
|
auto xi = x + threadOffset;
|
||||||
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
|
||||||
|
|
||||||
for (Nd4jLong i = 0; i < ulen; i++)
|
for (Nd4jLong i = 0; i < ulen; i++) {
|
||||||
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
|
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
|
||||||
|
}
|
||||||
|
|
||||||
PRAGMA_OMP_CRITICAL
|
PRAGMA_OMP_CRITICAL
|
||||||
startingVal = OpType::update(startingVal, local, extraParams);
|
startingVal = OpType::update(startingVal, local, extraParams);
|
||||||
|
|
|
@ -122,7 +122,7 @@ namespace functions {
|
||||||
|
|
||||||
tadLength = shape::length(tadOnlyShapeInfo);
|
tadLength = shape::length(tadOnlyShapeInfo);
|
||||||
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
|
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
|
||||||
numTads = shape::length(xShapeInfo) / tadLength;
|
numTads = shape::length(yShapeInfo) / tadLength;
|
||||||
xEWS = shape::elementWiseStride(xShapeInfo);
|
xEWS = shape::elementWiseStride(xShapeInfo);
|
||||||
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
|
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,12 +21,165 @@
|
||||||
|
|
||||||
#include <ops/specials_cuda.h>
|
#include <ops/specials_cuda.h>
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y>
|
||||||
|
__global__ void bitonicArbitraryStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
|
||||||
|
auto x = static_cast<X*>(vx);
|
||||||
|
auto y = static_cast<Y*>(vy);
|
||||||
|
|
||||||
|
int tid = threadIdx.x + blockDim.x * blockIdx.x;
|
||||||
|
int half = window>>1;
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLength;
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLength = shape::length(xShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
//for (int i = 0; i < length; i+= window)
|
||||||
|
/*
|
||||||
|
if window == 4;
|
||||||
|
iterations will be: 0; 4; 8; 12; 16; 20
|
||||||
|
if gridDim = 3;
|
||||||
|
on first iteration we'll have: 0; 4; 8;
|
||||||
|
on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
|
||||||
|
*/
|
||||||
|
int firstPosition;
|
||||||
|
int firstStep;
|
||||||
|
int secondPosition;
|
||||||
|
int secondStep;
|
||||||
|
|
||||||
|
int WARP_SIZE = 32;
|
||||||
|
int numWarps = (gridDim.x * blockDim.x) / 32;
|
||||||
|
int warpId = tid / WARP_SIZE;
|
||||||
|
int warpIdx = tid % WARP_SIZE;
|
||||||
|
|
||||||
|
if (half >= 128) {
|
||||||
|
firstPosition = blockIdx.x * window;
|
||||||
|
firstStep = gridDim.x * window;
|
||||||
|
|
||||||
|
secondPosition = threadIdx.x;
|
||||||
|
secondStep = blockDim.x;
|
||||||
|
} else if (half >= 32) {
|
||||||
|
firstPosition = warpId * window;
|
||||||
|
firstStep = numWarps * window;
|
||||||
|
|
||||||
|
secondPosition = warpIdx;
|
||||||
|
secondStep = WARP_SIZE;
|
||||||
|
} else {
|
||||||
|
firstPosition = tid * window;
|
||||||
|
firstStep = blockDim.x * gridDim.x * window;
|
||||||
|
|
||||||
|
secondPosition = 0;
|
||||||
|
secondStep = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for (int i = firstPosition; i < length; i += firstStep) {
|
||||||
|
for (int j = secondPosition; j < half; j += secondStep) {
|
||||||
|
int it = (reverse) ? i + j + half : i + window - j - 1;
|
||||||
|
int ij = i+j;
|
||||||
|
if (it < length && ij < length ) {
|
||||||
|
int posIT = shape::getIndexOffset(it, yShapeInfo, xLength);
|
||||||
|
int posIJ = shape::getIndexOffset(ij, yShapeInfo, xLength);
|
||||||
|
|
||||||
|
Y v0 = y[posIJ];
|
||||||
|
Y v1 = y[posIT];
|
||||||
|
|
||||||
|
if(!descending == (v0 > v1)) {
|
||||||
|
y[posIJ] = v1;
|
||||||
|
y[posIT] = v0;
|
||||||
|
|
||||||
|
X xtemp = x[posIJ];
|
||||||
|
x[posIJ] = x[posIT];
|
||||||
|
x[posIT] = xtemp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y>
|
||||||
|
__global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
|
||||||
|
auto x = static_cast<X*>(vx);
|
||||||
|
auto y = static_cast<Y*>(vy);
|
||||||
|
|
||||||
|
int tid = threadIdx.x + blockDim.x * blockIdx.x;
|
||||||
|
int half = window>>1;
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLength;
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLength = shape::length(xShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
//for (int i = 0; i < length; i+= window)
|
||||||
|
/*
|
||||||
|
if window == 4;
|
||||||
|
iterations will be: 0; 4; 8; 12; 16; 20
|
||||||
|
if gridDim = 3;
|
||||||
|
on first iteration we'll have: 0; 4; 8;
|
||||||
|
on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
|
||||||
|
*/
|
||||||
|
int firstPosition;
|
||||||
|
int firstStep;
|
||||||
|
int secondPosition;
|
||||||
|
int secondStep;
|
||||||
|
|
||||||
|
int WARP_SIZE = 32;
|
||||||
|
int numWarps = (gridDim.x * blockDim.x) / 32;
|
||||||
|
int warpId = tid / WARP_SIZE;
|
||||||
|
int warpIdx = tid % WARP_SIZE;
|
||||||
|
|
||||||
|
if (half >= 128) {
|
||||||
|
firstPosition = blockIdx.x * window;
|
||||||
|
firstStep = gridDim.x * window;
|
||||||
|
|
||||||
|
secondPosition = threadIdx.x;
|
||||||
|
secondStep = blockDim.x;
|
||||||
|
} else if (half >= 32) {
|
||||||
|
firstPosition = warpId * window;
|
||||||
|
firstStep = numWarps * window;
|
||||||
|
|
||||||
|
secondPosition = warpIdx;
|
||||||
|
secondStep = WARP_SIZE;
|
||||||
|
} else {
|
||||||
|
firstPosition = tid * window;
|
||||||
|
firstStep = blockDim.x * gridDim.x * window;
|
||||||
|
|
||||||
|
secondPosition = 0;
|
||||||
|
secondStep = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for (int i = firstPosition; i < length; i += firstStep) {
|
||||||
|
for (int j = secondPosition; j < half; j += secondStep) {
|
||||||
|
int it = (reverse) ? i + j + half : i + window - j - 1;
|
||||||
|
int ij = i+j;
|
||||||
|
if (it < length && ij < length ) {
|
||||||
|
int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
|
||||||
|
int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);
|
||||||
|
|
||||||
|
X v0 = x[posIJ];
|
||||||
|
X v1 = x[posIT];
|
||||||
|
|
||||||
|
if(!descending == (v0 > v1)) {
|
||||||
|
x[posIJ] = v1;
|
||||||
|
x[posIT] = v0;
|
||||||
|
|
||||||
|
Y ytemp = y[posIJ];
|
||||||
|
y[posIJ] = y[posIT];
|
||||||
|
y[posIT] = ytemp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__device__
|
__global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
|
||||||
void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
|
|
||||||
|
|
||||||
auto x = static_cast<T*>(vx);
|
auto x = static_cast<T*>(vx);
|
||||||
|
|
||||||
int tid = threadIdx.x + blockDim.x * blockIdx.x;
|
int tid = threadIdx.x + blockDim.x * blockIdx.x;
|
||||||
|
@ -85,8 +238,8 @@ void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int
|
||||||
int it = (reverse) ? i + j + half : i + window - j - 1;
|
int it = (reverse) ? i + j + half : i + window - j - 1;
|
||||||
int ij = i+j;
|
int ij = i+j;
|
||||||
if (it < length && ij < length ) {
|
if (it < length && ij < length ) {
|
||||||
int posIT = getDevicePosition(xShapeInfo,it, xLength);
|
int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
|
||||||
int posIJ = getDevicePosition(xShapeInfo, ij, xLength);
|
int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);
|
||||||
|
|
||||||
shmem[threadIdx.x] = x[posIJ];
|
shmem[threadIdx.x] = x[posIJ];
|
||||||
shmem[threadIdx.x + blockDim.x] = x[posIT];
|
shmem[threadIdx.x + blockDim.x] = x[posIT];
|
||||||
|
@ -100,18 +253,22 @@ void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
__global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
|
|
||||||
|
|
||||||
bitonicArbitraryStepKernel<T>(vx, xShapeInfo, window, length, reverse, descending);
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
|
__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
|
||||||
|
|
||||||
execBitonicArbitraryStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, window, length, reverse, descending);
|
execBitonicArbitraryStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, window, length, reverse, descending);
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "bitonicArbitrary(...) failed");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename X, typename Y>
|
||||||
|
__host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
|
||||||
|
bitonicArbitraryStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename X, typename Y>
|
||||||
|
__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
|
||||||
|
bitonicArbitraryStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
|
||||||
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES);
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
|
|
@ -21,9 +21,119 @@
|
||||||
|
|
||||||
#include <ops/specials_cuda.h>
|
#include <ops/specials_cuda.h>
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y>
|
||||||
|
__global__ void bitonicSortStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
|
||||||
|
|
||||||
|
auto x = static_cast<X*>(vx);
|
||||||
|
auto y = static_cast<Y*>(vy);
|
||||||
|
|
||||||
|
unsigned int i, ixj; /* Sorting partners: i and ixj */
|
||||||
|
i = threadIdx.x + blockDim.x * blockIdx.x;
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLength;
|
||||||
|
if (threadIdx.x == 0)
|
||||||
|
xLength = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
|
||||||
|
if (i >= length)
|
||||||
|
return;
|
||||||
|
|
||||||
|
ixj = i^j;
|
||||||
|
|
||||||
|
/* The threads with the lowest ids sort the array. */
|
||||||
|
if ((ixj)>i) {
|
||||||
|
int posI = shape::getIndexOffset(i, yShapeInfo, xLength);
|
||||||
|
int posIXJ = shape::getIndexOffset(ixj, yShapeInfo, xLength);
|
||||||
|
|
||||||
|
if ((i&k)==0) {
|
||||||
|
/* Sort ascending */
|
||||||
|
if (!descending == (y[posI]>y[posIXJ])) {
|
||||||
|
/* exchange(i,ixj); */
|
||||||
|
X temp = x[posI];
|
||||||
|
x[posI] = x[posIXJ];
|
||||||
|
x[posIXJ] = temp;
|
||||||
|
|
||||||
|
Y ytemp = y[posI];
|
||||||
|
y[posI] = y[posIXJ];
|
||||||
|
y[posIXJ] = ytemp;
|
||||||
|
}
|
||||||
|
} else if ((i&k)!=0) {
|
||||||
|
/* Sort descending */
|
||||||
|
if (!descending == (y[posI]<y[posIXJ])) {
|
||||||
|
/* exchange(i,ixj); */
|
||||||
|
X temp = x[posI];
|
||||||
|
x[posI] = x[posIXJ];
|
||||||
|
x[posIXJ] = temp;
|
||||||
|
|
||||||
|
Y ytemp = y[posI];
|
||||||
|
y[posI] = y[posIXJ];
|
||||||
|
y[posIXJ] = ytemp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y>
|
||||||
|
__global__ void bitonicSortStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
|
||||||
|
|
||||||
|
auto x = static_cast<X*>(vx);
|
||||||
|
auto y = static_cast<Y*>(vy);
|
||||||
|
|
||||||
|
unsigned int i, ixj; /* Sorting partners: i and ixj */
|
||||||
|
i = threadIdx.x + blockDim.x * blockIdx.x;
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xLength;
|
||||||
|
if (threadIdx.x == 0)
|
||||||
|
xLength = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
|
||||||
|
if (i >= length)
|
||||||
|
return;
|
||||||
|
|
||||||
|
ixj = i^j;
|
||||||
|
|
||||||
|
/* The threads with the lowest ids sort the array. */
|
||||||
|
if ((ixj)>i) {
|
||||||
|
int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
|
||||||
|
int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);
|
||||||
|
|
||||||
|
if ((i&k)==0) {
|
||||||
|
/* Sort ascending */
|
||||||
|
if (!descending == (x[posI]>x[posIXJ])) {
|
||||||
|
/* exchange(i,ixj); */
|
||||||
|
X temp = x[posI];
|
||||||
|
x[posI] = x[posIXJ];
|
||||||
|
x[posIXJ] = temp;
|
||||||
|
|
||||||
|
Y ytemp = y[posI];
|
||||||
|
y[posI] = y[posIXJ];
|
||||||
|
y[posIXJ] = ytemp;
|
||||||
|
}
|
||||||
|
} else if ((i&k)!=0) {
|
||||||
|
/* Sort descending */
|
||||||
|
if (!descending == (x[posI]<x[posIXJ])) {
|
||||||
|
/* exchange(i,ixj); */
|
||||||
|
X temp = x[posI];
|
||||||
|
x[posI] = x[posIXJ];
|
||||||
|
x[posIXJ] = temp;
|
||||||
|
|
||||||
|
Y ytemp = y[posI];
|
||||||
|
y[posI] = y[posIXJ];
|
||||||
|
y[posIXJ] = ytemp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
|
__global__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
|
||||||
|
|
||||||
auto x = static_cast<T*>(vx);
|
auto x = static_cast<T*>(vx);
|
||||||
|
|
||||||
|
@ -44,8 +154,8 @@ __device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int
|
||||||
|
|
||||||
/* The threads with the lowest ids sort the array. */
|
/* The threads with the lowest ids sort the array. */
|
||||||
if ((ixj)>i) {
|
if ((ixj)>i) {
|
||||||
int posI = getDevicePosition(xShapeInfo, i, xLength);
|
int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
|
||||||
int posIXJ = getDevicePosition(xShapeInfo, ixj, xLength);
|
int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);
|
||||||
|
|
||||||
if ((i&k)==0) {
|
if ((i&k)==0) {
|
||||||
/* Sort ascending */
|
/* Sort ascending */
|
||||||
|
@ -69,16 +179,23 @@ __device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__global__ void execBitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
|
__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
|
||||||
|
bitonicSortStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, j, k, length, descending);
|
||||||
bitonicSortStepKernel<T>(vx, xShapeInfo, j, k, length, descending);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template <typename X, typename Y>
|
||||||
__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
|
__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
|
||||||
|
bitonicSortStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
|
||||||
execBitonicSortStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, j, k, length, descending);
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "bitonicSortStep(...) failed");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y>
|
||||||
|
__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
|
||||||
|
bitonicSortStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES);
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
|
|
@ -16,15 +16,86 @@
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
// @author Yurii Shyrma, created on 28.11.2018
|
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/specials_cuda.h>
|
#include <ops/specials_cuda.h>
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y>
|
||||||
|
__global__ void execOesTadKernelKey(void *vx, Nd4jLong *xShapeInfo,
|
||||||
|
void *vy, Nd4jLong *yShapeInfo,
|
||||||
|
int *dimension, int dimensionLength,
|
||||||
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
|
bool descending) {
|
||||||
|
|
||||||
|
auto x = static_cast<X*>(vx);
|
||||||
|
auto y = static_cast<Y*>(vy);
|
||||||
|
|
||||||
|
__shared__ int xLength;
|
||||||
|
__shared__ int xTadLength;
|
||||||
|
__shared__ int numTads;
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLength = shape::length(xShapeInfo);
|
||||||
|
xTadLength = shape::length(tadShapeInfo);
|
||||||
|
numTads = xLength / xTadLength;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
|
||||||
|
auto dx = x + tadOffsets[r];
|
||||||
|
auto dy = y + tadOffsets[r];
|
||||||
|
|
||||||
|
// this is general loop, we go uncached
|
||||||
|
int iterations = xTadLength;
|
||||||
|
|
||||||
|
for (int i = 0; i < iterations; i++) {
|
||||||
|
|
||||||
|
if (i % 2 == 0) {
|
||||||
|
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
||||||
|
auto top = 2 * tid + 1;
|
||||||
|
if (top < xTadLength) {
|
||||||
|
auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
|
||||||
|
auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);
|
||||||
|
|
||||||
|
if (!descending == (dx[t0] > dx[t1])) {
|
||||||
|
X dt0 = dx[t0];
|
||||||
|
dx[t0] = dx[t1];
|
||||||
|
dx[t1] = dt0;
|
||||||
|
|
||||||
|
Y dy0 = dy[t0];
|
||||||
|
dy[t0] = dy[t1];
|
||||||
|
dy[t1] = dy0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
||||||
|
auto top = 2 * tid + 2;
|
||||||
|
if (top < xTadLength) {
|
||||||
|
auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
|
||||||
|
auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);
|
||||||
|
|
||||||
|
if (!descending == (dx[t0] > dx[t1])) {
|
||||||
|
X dt0 = dx[t0];
|
||||||
|
dx[t0] = dx[t1];
|
||||||
|
dx[t1] = dt0;
|
||||||
|
|
||||||
|
Y dy0 = dy[t0];
|
||||||
|
dy[t0] = dy[t1];
|
||||||
|
dy[t1] = dy0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__device__
|
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
|
||||||
void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
|
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
bool descending) {
|
bool descending) {
|
||||||
|
@ -56,7 +127,7 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
|
||||||
int iterations = xTadLength;
|
int iterations = xTadLength;
|
||||||
if (cached) {
|
if (cached) {
|
||||||
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
||||||
auto t0 = getDevicePosition(tadShapeInfo, tid, xTadLength);
|
auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
|
||||||
shmem[tid] = dx[t0];
|
shmem[tid] = dx[t0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,8 +141,8 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
|
||||||
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
||||||
auto top = 2 * tid + 1;
|
auto top = 2 * tid + 1;
|
||||||
if (top < xTadLength) {
|
if (top < xTadLength) {
|
||||||
auto t0 = cached ? top - 1 : getDevicePosition(tadShapeInfo, top - 1, xTadLength);
|
auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
|
||||||
auto t1 = cached ? top : getDevicePosition(tadShapeInfo, top, xTadLength);
|
auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);
|
||||||
|
|
||||||
if (!descending == (dx[t0] > dx[t1])) {
|
if (!descending == (dx[t0] > dx[t1])) {
|
||||||
T dt0 = dx[t0];
|
T dt0 = dx[t0];
|
||||||
|
@ -84,8 +155,8 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
|
||||||
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
||||||
auto top = 2 * tid + 2;
|
auto top = 2 * tid + 2;
|
||||||
if (top < xTadLength) {
|
if (top < xTadLength) {
|
||||||
auto t0 = cached ? top - 1 : getDevicePosition(tadShapeInfo, top - 1, xTadLength);
|
auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
|
||||||
auto t1 = cached ? top : getDevicePosition(tadShapeInfo, top, xTadLength);
|
auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);
|
||||||
|
|
||||||
if (!descending == (dx[t0] > dx[t1])) {
|
if (!descending == (dx[t0] > dx[t1])) {
|
||||||
T dt0 = dx[t0];
|
T dt0 = dx[t0];
|
||||||
|
@ -102,23 +173,13 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
|
||||||
if (cached) {
|
if (cached) {
|
||||||
dx = x + tadOffsets[r];
|
dx = x + tadOffsets[r];
|
||||||
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
|
||||||
auto t0 = getDevicePosition(tadShapeInfo, tid, xTadLength);
|
auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
|
||||||
dx[t0] = shmem[tid];
|
dx[t0] = shmem[tid];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
|
|
||||||
int *dimension, int dimensionLength,
|
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
|
||||||
bool descending) {
|
|
||||||
|
|
||||||
oesTadKernel<T>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream,
|
__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream,
|
||||||
|
@ -128,6 +189,18 @@ __host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream,
|
||||||
bool descending) {
|
bool descending) {
|
||||||
|
|
||||||
execOesTadKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
|
execOesTadKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "oesTad(...) failed");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename X, typename Y>
|
||||||
|
__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream,
|
||||||
|
void *vx, Nd4jLong *xShapeInfo,
|
||||||
|
void *vy, Nd4jLong *yShapeInfo,
|
||||||
|
int *dimension, int dimensionLength,
|
||||||
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
|
bool descending) {
|
||||||
|
|
||||||
|
execOesTadKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
|
||||||
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES);
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT oesTadGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
|
|
@ -65,13 +65,7 @@ CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) {
|
||||||
REQUIRE_TRUE(product == alphaLen, 0, "PRELU OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
|
REQUIRE_TRUE(product == alphaLen, 0, "PRELU OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
|
||||||
// ***** end of validation ***** //
|
// ***** end of validation ***** //
|
||||||
|
|
||||||
if(alphaShape != expectedAlphaShape)
|
helpers::prelu(block.launchContext(), *input, alphaShape != expectedAlphaShape ? alpha->reshape(alpha->ordering(), expectedAlphaShape) : *alpha, *output);
|
||||||
alpha = alpha->reshape(alpha->ordering(), expectedAlphaShape);
|
|
||||||
|
|
||||||
helpers::prelu(block.launchContext(), *input, *alpha, *output);
|
|
||||||
|
|
||||||
if(alphaShape != expectedAlphaShape)
|
|
||||||
delete alpha;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -128,9 +122,10 @@ CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) {
|
||||||
REQUIRE_TRUE(product == alphaLen, 0, "PRELU_BP OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
|
REQUIRE_TRUE(product == alphaLen, 0, "PRELU_BP OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
|
||||||
// ***** end of validation ***** //
|
// ***** end of validation ***** //
|
||||||
|
|
||||||
|
|
||||||
if(alphaShape != expectedAlphaShape) {
|
if(alphaShape != expectedAlphaShape) {
|
||||||
alpha = alpha->reshape(alpha->ordering(), expectedAlphaShape);
|
alpha = new NDArray(alpha->reshape(alpha->ordering(), expectedAlphaShape));
|
||||||
dLdA = dLdA->reshape(dLdA->ordering(), expectedAlphaShape);
|
dLdA = new NDArray(dLdA->reshape(dLdA->ordering(), expectedAlphaShape));
|
||||||
}
|
}
|
||||||
|
|
||||||
helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA);
|
helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA);
|
||||||
|
|
|
@ -29,7 +29,6 @@ namespace nd4j {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
nd4j_printf("Comparing [%f] to [%f]\n", x->e<float>(0), y->e<float>(0));
|
|
||||||
if (x->e<float>(0) < y->e<float>(0))
|
if (x->e<float>(0) < y->e<float>(0))
|
||||||
return ND4J_STATUS_TRUE;
|
return ND4J_STATUS_TRUE;
|
||||||
else
|
else
|
||||||
|
|
|
@ -31,7 +31,7 @@ namespace nd4j {
|
||||||
auto condition = INPUT_VARIABLE(0);
|
auto condition = INPUT_VARIABLE(0);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
if (z->isEmpty())
|
if (z->isEmpty())
|
||||||
return ND4J_STATUS_OK;
|
return Status::OK();
|
||||||
|
|
||||||
if (block.width() == 3) {
|
if (block.width() == 3) {
|
||||||
auto x = INPUT_VARIABLE(1);
|
auto x = INPUT_VARIABLE(1);
|
||||||
|
@ -44,12 +44,10 @@ namespace nd4j {
|
||||||
// FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y
|
// FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y
|
||||||
for (int e = 0; e < condition->lengthOf(); e++) {
|
for (int e = 0; e < condition->lengthOf(); e++) {
|
||||||
if (y->isR()) {
|
if (y->isR()) {
|
||||||
auto r = !condition->e<bool>(e) ? y->e<double>(e)
|
auto r = !condition->e<bool>(e) ? y->e<double>(e) : x->e<double>(e);
|
||||||
: x->e<double>(e);
|
|
||||||
z->p(e, r);
|
z->p(e, r);
|
||||||
} else {
|
} else {
|
||||||
auto r = !condition->e<bool>(e) ? y->e<Nd4jLong>(e)
|
auto r = !condition->e<bool>(e) ? y->e<Nd4jLong>(e) : x->e<Nd4jLong>(e);
|
||||||
: x->e<Nd4jLong>(e);
|
|
||||||
z->p(e, r);
|
z->p(e, r);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -86,7 +84,7 @@ namespace nd4j {
|
||||||
|
|
||||||
helpers::_where(block.launchContext(), *condition, *output, block.workspace());
|
helpers::_where(block.launchContext(), *condition, *output, block.workspace());
|
||||||
}
|
}
|
||||||
return ND4J_STATUS_OK;
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(Where) {
|
DECLARE_SHAPE_FN(Where) {
|
||||||
|
|
|
@ -120,7 +120,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(where_np) {
|
DECLARE_SHAPE_FN(where_np) {
|
||||||
|
|
|
@ -81,11 +81,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
|
||||||
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
|
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
|
||||||
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
|
|
||||||
ConvolutionUtils::conv2d(block, inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
||||||
|
|
||||||
delete inputReshaped;
|
|
||||||
delete outputReshaped;
|
|
||||||
delete weightsReshaped;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -217,13 +213,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
|
||||||
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
|
|
||||||
ConvolutionUtils::conv2dBP(block, inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
||||||
|
|
||||||
delete inputReshaped;
|
|
||||||
delete gradIReshaped;
|
|
||||||
delete gradOReshaped;
|
|
||||||
delete weightsReshaped;
|
|
||||||
delete gradWReshaped;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,10 +151,10 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
|
||||||
|
|
||||||
std::vector<int> permutForOutput;
|
std::vector<int> permutForOutput;
|
||||||
|
|
||||||
if(!isNCDHW)
|
if (isNCDHW)
|
||||||
input = input->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
|
||||||
else
|
|
||||||
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
|
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
|
||||||
|
else
|
||||||
|
input = new NDArray(input->permute({0,4,1,2,3}));
|
||||||
|
|
||||||
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
||||||
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||||
|
@ -447,21 +447,23 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
||||||
std::vector<int> gradOaxesForDot;
|
std::vector<int> gradOaxesForDot;
|
||||||
|
|
||||||
if(!isNDHWC) {
|
if(!isNDHWC) {
|
||||||
input = input->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
|
||||||
gradI = gradI->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
|
||||||
gradOaxesForDot = {0,1,2,3}; // bS, oD, oH, oW
|
gradOaxesForDot = {0,1,2,3}; // bS, oD, oH, oW
|
||||||
|
input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
|
gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
}
|
}
|
||||||
else
|
else {
|
||||||
gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW
|
gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW
|
||||||
|
}
|
||||||
|
|
||||||
// ----- calculation of gradW and gradB ----- //
|
// ----- calculation of gradW and gradB ----- //
|
||||||
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
||||||
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||||
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
|
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
|
||||||
|
|
||||||
|
//----- calculation of gradO -----//
|
||||||
if(gradB) {
|
if(gradB) {
|
||||||
if(gradB->rankOf() == 2)
|
if(gradB->rankOf() == 2)
|
||||||
gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()});
|
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
||||||
gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot); // sum over bS oD oH oW
|
gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot); // sum over bS oD oH oW
|
||||||
if(gradB != OUTPUT_VARIABLE(2))
|
if(gradB != OUTPUT_VARIABLE(2))
|
||||||
delete gradB;
|
delete gradB;
|
||||||
|
|
|
@ -64,7 +64,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
if(!isNCHW)
|
if(!isNCHW)
|
||||||
output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
@ -211,8 +211,9 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
|
||||||
|
|
||||||
// -----prepare permutation arrays and axes for dot product ----- //
|
// -----prepare permutation arrays and axes for dot product ----- //
|
||||||
std::vector<int> inputAxesForDot;
|
std::vector<int> inputAxesForDot;
|
||||||
|
|
||||||
if(!isNCHW) {
|
if(!isNCHW) {
|
||||||
gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||||
inputAxesForDot = {0, 1, 2}; // bS, iH, iW
|
inputAxesForDot = {0, 1, 2}; // bS, iH, iW
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@ -228,7 +229,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
|
||||||
// ----- calculation of gradB ----- //
|
// ----- calculation of gradB ----- //
|
||||||
if(gradB) {
|
if(gradB) {
|
||||||
if(gradB->rankOf() == 2)
|
if(gradB->rankOf() == 2)
|
||||||
gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()});
|
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
||||||
gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3}); // sum over bS, oH, oW
|
gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3}); // sum over bS, oH, oW
|
||||||
if(gradB != OUTPUT_VARIABLE(2))
|
if(gradB != OUTPUT_VARIABLE(2))
|
||||||
delete gradB;
|
delete gradB;
|
||||||
|
@ -237,7 +238,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
|
||||||
if(!isNCHW)
|
if(!isNCHW)
|
||||||
delete gradO;
|
delete gradO;
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(deconv2d_bp) {
|
DECLARE_SHAPE_FN(deconv2d_bp) {
|
||||||
|
|
|
@ -39,9 +39,9 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
|
||||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||||
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
|
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
|
||||||
|
|
||||||
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
|
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) depth
|
||||||
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
|
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) height
|
||||||
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
|
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2)); // filter(kernel) width
|
||||||
int sD = INT_ARG(3); // strides depth
|
int sD = INT_ARG(3); // strides depth
|
||||||
int sH = INT_ARG(4); // strides height
|
int sH = INT_ARG(4); // strides height
|
||||||
int sW = INT_ARG(5); // strides width
|
int sW = INT_ARG(5); // strides width
|
||||||
|
@ -64,7 +64,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
if(!isNCDHW)
|
if(!isNCDHW)
|
||||||
output = output->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
@ -225,8 +225,9 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
|
||||||
|
|
||||||
// -----prepare permutation arrays and axes for dot product ----- //
|
// -----prepare permutation arrays and axes for dot product ----- //
|
||||||
std::vector<int> inputAxesForDot;
|
std::vector<int> inputAxesForDot;
|
||||||
|
|
||||||
if(!isNCDHW) {
|
if(!isNCDHW) {
|
||||||
gradO = gradO->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
|
||||||
inputAxesForDot = {0, 1, 2, 3}; // bS, iD, iH, iW
|
inputAxesForDot = {0, 1, 2, 3}; // bS, iD, iH, iW
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
@ -240,7 +241,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
|
||||||
// ----- calculation of gradB ----- //
|
// ----- calculation of gradB ----- //
|
||||||
if(gradB) {
|
if(gradB) {
|
||||||
if(gradB->rankOf() == 2)
|
if(gradB->rankOf() == 2)
|
||||||
gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()});
|
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
||||||
gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
|
gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
|
||||||
if(gradB != OUTPUT_VARIABLE(2))
|
if(gradB != OUTPUT_VARIABLE(2))
|
||||||
delete gradB;
|
delete gradB;
|
||||||
|
|
|
@ -71,7 +71,7 @@ namespace ops {
|
||||||
int pad_top = 0, pad_left = 0;
|
int pad_top = 0, pad_left = 0;
|
||||||
int out_rows = 0, out_cols = 0;
|
int out_rows = 0, out_cols = 0;
|
||||||
|
|
||||||
helpers::_dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
|
helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
|
||||||
|
|
||||||
|
|
||||||
REQUIRE_TRUE(out_rows > 0 && out_cols > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", out_rows, out_cols);
|
REQUIRE_TRUE(out_rows > 0 && out_cols > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", out_rows, out_cols);
|
||||||
|
@ -126,7 +126,7 @@ namespace ops {
|
||||||
int pad_top = 0, pad_left = 0;
|
int pad_top = 0, pad_left = 0;
|
||||||
int out_rows = 0, out_cols = 0;
|
int out_rows = 0, out_cols = 0;
|
||||||
|
|
||||||
helpers::_dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
|
helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
|
||||||
|
|
||||||
std::array<Nd4jLong, 4> shape = {{batch_size, out_rows, out_cols, depth}};
|
std::array<Nd4jLong, 4> shape = {{batch_size, out_rows, out_cols, depth}};
|
||||||
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data());
|
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data());
|
||||||
|
|
|
@ -59,9 +59,9 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
|
||||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
||||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
||||||
|
|
||||||
if (!isNCHW) {
|
if(!isNCHW) {
|
||||||
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
@ -71,9 +71,8 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||||
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
|
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
|
||||||
//output->printBuffer("output op");
|
|
||||||
|
|
||||||
if (!isNCHW) {
|
if(!isNCHW) {
|
||||||
delete input;
|
delete input;
|
||||||
delete output;
|
delete output;
|
||||||
}
|
}
|
||||||
|
@ -177,10 +176,11 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||||
|
|
||||||
|
|
||||||
if(!isNCHW) {
|
if(!isNCHW) {
|
||||||
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
|
@ -205,9 +205,6 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
|
||||||
delete gradI;
|
delete gradI;
|
||||||
delete gradO;
|
delete gradO;
|
||||||
}
|
}
|
||||||
// delete columns;
|
|
||||||
// delete columns2d;
|
|
||||||
// delete gradOVector;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
||||||
|
|
|
@ -61,8 +61,8 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
|
||||||
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
|
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||||
|
|
||||||
if(!isNCDHW) {
|
if(!isNCDHW) {
|
||||||
input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
output = output->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
|
@ -180,9 +180,9 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
|
||||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||||
|
|
||||||
if(!isNCDHW) {
|
if(!isNCDHW) {
|
||||||
input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
gradI = gradI->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
gradO = gradO->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
|
|
|
@ -59,9 +59,9 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
|
||||||
const int iH = isNCHW ? input->sizeAt(2) : input->sizeAt(1);
|
const int iH = isNCHW ? input->sizeAt(2) : input->sizeAt(1);
|
||||||
const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2);
|
const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2);
|
||||||
|
|
||||||
if (!isNCHW) {
|
if(!isNCHW) {
|
||||||
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||||
|
@ -72,7 +72,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor;
|
||||||
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
|
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
|
||||||
|
|
||||||
if (!isNCHW) {
|
if(!isNCHW) {
|
||||||
delete input;
|
delete input;
|
||||||
delete output;
|
delete output;
|
||||||
}
|
}
|
||||||
|
@ -175,9 +175,9 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
|
||||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||||
|
|
||||||
if(!isNCHW) {
|
if(!isNCHW) {
|
||||||
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
|
@ -203,9 +203,6 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
|
||||||
delete gradI;
|
delete gradI;
|
||||||
delete gradO;
|
delete gradO;
|
||||||
}
|
}
|
||||||
// delete columns;
|
|
||||||
// delete columns2d;
|
|
||||||
// delete gradOVector;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,8 +63,8 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
|
||||||
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
|
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
|
||||||
|
|
||||||
if(!isNCDHW) {
|
if(!isNCDHW) {
|
||||||
input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
output = output->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
|
@ -182,9 +182,9 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
|
||||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||||
|
|
||||||
if(!isNCDHW) {
|
if(!isNCDHW) {
|
||||||
input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
gradI = gradI->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||||
gradO = gradO->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
|
@ -211,9 +211,6 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
|
||||||
delete gradI;
|
delete gradI;
|
||||||
delete gradO;
|
delete gradO;
|
||||||
}
|
}
|
||||||
// delete columns;
|
|
||||||
// delete columns2d;
|
|
||||||
// delete gradOVector;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,9 +54,9 @@ namespace nd4j {
|
||||||
|
|
||||||
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW
|
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
if (!isNCHW) {
|
if(!isNCHW) {
|
||||||
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto inY = static_cast<int>(input->sizeAt(2));
|
const auto inY = static_cast<int>(input->sizeAt(2));
|
||||||
|
@ -70,7 +70,7 @@ namespace nd4j {
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||||
ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
|
ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
|
||||||
|
|
||||||
if (!isNCHW) {
|
if(!isNCHW) {
|
||||||
delete input;
|
delete input;
|
||||||
delete output;
|
delete output;
|
||||||
}
|
}
|
||||||
|
@ -175,9 +175,9 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
|
||||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||||
|
|
||||||
if(!isNCHW) {
|
if(!isNCHW) {
|
||||||
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||||
gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
||||||
}
|
}
|
||||||
|
|
||||||
// if(isSameMode) // SAME
|
// if(isSameMode) // SAME
|
||||||
|
@ -216,10 +216,6 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
|
||||||
delete gradI;
|
delete gradI;
|
||||||
delete gradO;
|
delete gradO;
|
||||||
}
|
}
|
||||||
// delete columns;
|
|
||||||
// delete columns2d;
|
|
||||||
// delete gradOVector;
|
|
||||||
// delete denomVec;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,7 +78,7 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
|
||||||
auto weightsBroad = weights;
|
auto weightsBroad = weights;
|
||||||
if(!weights->isScalar() && !weights->isSameShape(&E)) {
|
if(!weights->isScalar() && !weights->isSameShape(&E)) {
|
||||||
if(E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1)
|
if(E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1)
|
||||||
weightsBroad = weights->reshape(weights->ordering(), {weights->lengthOf()});
|
weightsBroad = new NDArray(weights->reshape(weights->ordering(), {weights->lengthOf()}));
|
||||||
else
|
else
|
||||||
weightsBroad = new NDArray(weights->tileToShape(E.getShapeInfo()));
|
weightsBroad = new NDArray(weights->tileToShape(E.getShapeInfo()));
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,7 +74,7 @@ namespace ops {
|
||||||
}
|
}
|
||||||
|
|
||||||
if(mask != nullptr){
|
if(mask != nullptr){
|
||||||
NDArray* reshapedMask;
|
NDArray reshapedMask;
|
||||||
if(weights->rankOf() == 4){
|
if(weights->rankOf() == 4){
|
||||||
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
|
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
|
||||||
}else{
|
}else{
|
||||||
|
@ -87,8 +87,7 @@ namespace ops {
|
||||||
// before going through the softmax, we effectively push all masked positions to zero after softmax.
|
// before going through the softmax, we effectively push all masked positions to zero after softmax.
|
||||||
//
|
//
|
||||||
// we are using 1e9 to mean effectively infinity
|
// we are using 1e9 to mean effectively infinity
|
||||||
*weights += (*reshapedMask - 1) * 1e9;
|
*weights += (reshapedMask - 1) * 1e9;
|
||||||
delete reshapedMask;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::ops::softmax softmax;
|
nd4j::ops::softmax softmax;
|
||||||
|
@ -175,14 +174,13 @@ namespace ops {
|
||||||
preSoftmax /= factor;
|
preSoftmax /= factor;
|
||||||
|
|
||||||
if(mask != nullptr){
|
if(mask != nullptr){
|
||||||
NDArray* reshapedMask;
|
NDArray reshapedMask;
|
||||||
if(preSoftmax.rankOf() == 4){
|
if(preSoftmax.rankOf() == 4){
|
||||||
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
|
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
|
||||||
}else{
|
}else{
|
||||||
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1});
|
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1});
|
||||||
}
|
}
|
||||||
preSoftmax += (*reshapedMask - 1) * 1e9;
|
preSoftmax += (reshapedMask - 1) * 1e9;
|
||||||
delete reshapedMask;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
NDArray weights('c', weightShape, values->dataType(), block.launchContext());
|
NDArray weights('c', weightShape, values->dataType(), block.launchContext());
|
||||||
|
|
|
@ -70,7 +70,7 @@ namespace nd4j {
|
||||||
float beta = T_ARG(2);
|
float beta = T_ARG(2);
|
||||||
int depth = INT_ARG(0);
|
int depth = INT_ARG(0);
|
||||||
|
|
||||||
helpers::lrnBP(*input, *gradO, *gradI, depth, bias, alpha, beta);
|
helpers::lrnBP(block, *input, *gradO, *gradI, depth, bias, alpha, beta);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,9 +98,9 @@ namespace ops {
|
||||||
auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
|
auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
|
||||||
|
|
||||||
// Apply Attention
|
// Apply Attention
|
||||||
NDArray attnResults('c', {projectedQueries->sizeAt(0), projectedValues->sizeAt(1), projectedValues->sizeAt(2), projectedQueries->sizeAt(3)}, projectedValues->dataType(), block.launchContext());
|
NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
|
||||||
nd4j::ops::dot_product_attention attention;
|
nd4j::ops::dot_product_attention attention;
|
||||||
attention.execute({projectedQueries, projectedKeys, projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {});
|
attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {});
|
||||||
|
|
||||||
// Project attention results
|
// Project attention results
|
||||||
attnResults.permutei({0, 3, 1, 2});
|
attnResults.permutei({0, 3, 1, 2});
|
||||||
|
@ -111,11 +111,9 @@ namespace ops {
|
||||||
mmul.execute({&attnResults, Wo},{&projRes}, {}, {}, {});
|
mmul.execute({&attnResults, Wo},{&projRes}, {}, {}, {});
|
||||||
projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize});
|
projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize});
|
||||||
projRes.permutei({0, 2, 1});
|
projRes.permutei({0, 2, 1});
|
||||||
output->assign(projRes);
|
|
||||||
|
|
||||||
delete projectedQueries;
|
// FIXME: bad for performance
|
||||||
delete projectedKeys;
|
output->assign(projRes);
|
||||||
delete projectedValues;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -227,9 +225,9 @@ namespace ops {
|
||||||
auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
|
auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
|
||||||
|
|
||||||
// Apply Attention
|
// Apply Attention
|
||||||
NDArray attnResults('c', {projectedQueries->sizeAt(0), projectedValues->sizeAt(1), projectedValues->sizeAt(2), projectedQueries->sizeAt(3)}, projectedValues->dataType(), block.launchContext());
|
NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
|
||||||
nd4j::ops::dot_product_attention attention;
|
nd4j::ops::dot_product_attention attention;
|
||||||
attention.execute({projectedQueries, projectedKeys, projectedValues, mask}, {&attnResults}, {}, {normalization, 0}, {});
|
attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults}, {}, {normalization, 0}, {});
|
||||||
|
|
||||||
// Project attention results
|
// Project attention results
|
||||||
attnResults.permutei({0, 3, 1, 2});
|
attnResults.permutei({0, 3, 1, 2});
|
||||||
|
@ -237,31 +235,25 @@ namespace ops {
|
||||||
|
|
||||||
// dLdWo
|
// dLdWo
|
||||||
auto epsPerm = eps->permute({0, 2, 1});
|
auto epsPerm = eps->permute({0, 2, 1});
|
||||||
auto epsPostReshape = epsPerm->reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
|
auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
|
||||||
nd4j::ops::matmul_bp matmulBp;
|
nd4j::ops::matmul_bp matmulBp;
|
||||||
NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext());
|
NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext());
|
||||||
matmulBp.execute({&attnResults, Wo, epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {});
|
matmulBp.execute({&attnResults, Wo, &epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {});
|
||||||
|
|
||||||
// dLdAttn
|
// dLdAttn
|
||||||
dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues->sizeAt(2)});
|
dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)});
|
||||||
dLdPreWo.permutei({0, 2, 3, 1});
|
dLdPreWo.permutei({0, 2, 3, 1});
|
||||||
|
|
||||||
nd4j::ops::dot_product_attention_bp attentionBp;
|
nd4j::ops::dot_product_attention_bp attentionBp;
|
||||||
NDArray dLdProjectedQueries(projectedQueries->shapeInfo(), false, block.launchContext());
|
NDArray dLdProjectedQueries(projectedQueries.shapeInfo(), false, block.launchContext());
|
||||||
NDArray dLdProjectedKeys(projectedKeys->shapeInfo(), false, block.launchContext());
|
NDArray dLdProjectedKeys(projectedKeys.shapeInfo(), false, block.launchContext());
|
||||||
NDArray dLdProjectedValues(projectedValues->shapeInfo(), false, block.launchContext());
|
NDArray dLdProjectedValues(projectedValues.shapeInfo(), false, block.launchContext());
|
||||||
attentionBp.execute({projectedQueries, projectedKeys, projectedValues, &dLdPreWo, mask},{&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, {normalization}, {});
|
attentionBp.execute({&projectedQueries, &projectedKeys, &projectedValues, &dLdPreWo, mask},{&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, {normalization}, {});
|
||||||
|
|
||||||
AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, dLdWq, block.launchContext());
|
AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, dLdWq, block.launchContext());
|
||||||
AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, block.launchContext());
|
AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, block.launchContext());
|
||||||
AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, dLdWv, block.launchContext());
|
AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, dLdWv, block.launchContext());
|
||||||
|
|
||||||
delete projectedQueries;
|
|
||||||
delete projectedKeys;
|
|
||||||
delete projectedValues;
|
|
||||||
delete epsPerm;
|
|
||||||
delete epsPostReshape;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) {
|
||||||
REQUIRE_TRUE(0.f <= x->e<float>(i) && x->e<float>(i) <= 1.f, 0, "BETAINC op: all elements of x array must be within [0, 1] range!");
|
REQUIRE_TRUE(0.f <= x->e<float>(i) && x->e<float>(i) <= 1.f, 0, "BETAINC op: all elements of x array must be within [0, 1] range!");
|
||||||
}
|
}
|
||||||
|
|
||||||
*output = helpers::betaInc(block.launchContext(), *a, *b, *x);
|
helpers::betaInc(block.launchContext(), *a, *b, *x, *output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,10 +48,7 @@ namespace nd4j {
|
||||||
//nd4j_debug("Reshaping to: [%i, %i]\n", -1, (int) bias->lengthOf());
|
//nd4j_debug("Reshaping to: [%i, %i]\n", -1, (int) bias->lengthOf());
|
||||||
auto tArr = input->reshape(input->ordering(), shape);
|
auto tArr = input->reshape(input->ordering(), shape);
|
||||||
auto zArr = z->reshape(z->ordering(), shape);
|
auto zArr = z->reshape(z->ordering(), shape);
|
||||||
tArr->addRowVector(bias, zArr);
|
tArr.addRowVector(bias, &zArr);
|
||||||
|
|
||||||
delete tArr;
|
|
||||||
delete zArr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
STORE_RESULT(*z);
|
STORE_RESULT(*z);
|
||||||
|
@ -87,13 +84,12 @@ namespace nd4j {
|
||||||
// cnn case
|
// cnn case
|
||||||
if (input->rankOf() == 4) {
|
if (input->rankOf() == 4) {
|
||||||
auto epsilonNext2d = epsilonNext->permute({1, 0, 2, 3});
|
auto epsilonNext2d = epsilonNext->permute({1, 0, 2, 3});
|
||||||
epsilonNext2d->reshapei('c', {(int) bias->lengthOf(), -1});
|
epsilonNext2d.reshapei('c', {(int) bias->lengthOf(), -1});
|
||||||
|
|
||||||
auto sum = epsilonNext2d->reduceAlongDimension(reduce::Sum, {1});
|
auto sum = epsilonNext2d.reduceAlongDimension(reduce::Sum, {1});
|
||||||
gradB->assign(sum);
|
gradB->assign(sum);
|
||||||
|
|
||||||
delete sum;
|
delete sum;
|
||||||
delete epsilonNext2d;
|
|
||||||
} else if (input->rankOf() == 2) {
|
} else if (input->rankOf() == 2) {
|
||||||
// regular fully-connected case
|
// regular fully-connected case
|
||||||
auto sum = epsilonNext->reduceAlongDimension(reduce::Sum, {0});
|
auto sum = epsilonNext->reduceAlongDimension(reduce::Sum, {0});
|
||||||
|
|
|
@ -0,0 +1,56 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_check_numerics)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
CUSTOM_OP_IMPL(check_numerics, 2, 1, true, 0, 0) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto message = INPUT_VARIABLE(1);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
auto allFinite = input->reduceNumber(reduce::BoolOps::IsFinite);
|
||||||
|
REQUIRE_TRUE(allFinite.e<bool>(0), 0, "CheckNumerics: %s", message->e<std::string>(0).c_str());
|
||||||
|
|
||||||
|
if (!block.isInplace())
|
||||||
|
output->assign(input);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(check_numerics) {
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0))));
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(check_numerics) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(1, nd4j::DataType::UTF8)
|
||||||
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -56,7 +56,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(crop_and_resize) {
|
DECLARE_SHAPE_FN(crop_and_resize) {
|
||||||
auto in = inputShape->at(0);
|
auto in = inputShape->at(1);
|
||||||
|
|
||||||
Nd4jLong outputShape[4];
|
Nd4jLong outputShape[4];
|
||||||
|
|
||||||
|
@ -77,8 +77,13 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
DECLARE_TYPES(crop_and_resize) {
|
DECLARE_TYPES(crop_and_resize) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
// ->setAllowedInputTypes(1, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(1, {FLOAT32}) // as TF
|
||||||
|
->setAllowedInputTypes(2, {ALL_INTS})
|
||||||
|
->setAllowedInputTypes(3, {ALL_INTS})
|
||||||
|
->setAllowedOutputTypes({FLOAT32}); // as TF
|
||||||
|
// ->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,9 +47,9 @@ namespace ops {
|
||||||
auto o = OUTPUT_VARIABLE(0);
|
auto o = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
if (a->lengthOf() == 3) {
|
if (a->lengthOf() == 3) {
|
||||||
helpers::_cross(block.launchContext(), a, b, o);
|
helpers::cross(block.launchContext(), a, b, o);
|
||||||
} else {
|
} else {
|
||||||
helpers::_crossBatched(block.launchContext(), a, b, o);
|
helpers::crossBatched(block.launchContext(), a, b, o);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue