Merge master to upstream (#7945)

* Shugeo strided slice zeros (#14)

* Modified strided_slice op to properly work with empty-like shapes.

* Fixed test for reduce_mean with empty-like input.

* [WIP] Last merge (#15)

* correct logsoftmax looss (#2)

* Small SameDiff listener fix (#4)

* Various fixes (#6)

* #7839 Fix for asXMatrix and tests

* #7866 EmbeddingSequenceLayer dtype fix + test

* #7856 SameDiff save/load stream methods

* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration

* EvaluationBinary 3d/4d

* More evaluation 3d/4d tests

* #7847 Evaluation empty checks

* Small test ifx

* #7848 Fix median edge case

* Improve DL4J samediff layer tests

* [WIP] FastText wrapper implemented (#8)

* FastText implemented

* Some fixes

* Fix shapes for wordsNearest

* Validation of input vectors

* Fixes

* Fixed test

* Thread tagged

* Some tweaks

* setContextClassLoader for DeallocatorServiceThread

* Numpy format tests (#1)

* Various fixes (#11)

* #7852 SameDiff gather fix

* #7892 SameDiff placeholder to constant conversion

* #7890 validate input rank for MLN/CG init methods

* Fix broken permute shape calculation

* Permute and gather fixes

* Tests

* #7850 LogSumExp fix + test

* Handful of test fixes

* Empty arrays with non-scalar shapes (#10)

* minor rearrangements for lambdas

* empty tensors with non-scalar shapes

* numpy empty tensors with non-scalar shapes

* few more empty tweaks

* Small fixes

* conv3d signature update

* micro fix in batchnorm mkldnn

* Import fixes

* Fix

* MKL-DNN update

* Small fill fix

* fill with empty input + test

* Fixes

* Small error improvement

* Fix

* one special test

* couple of fixes for lstm

* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone

* Fixes

* FP16

* Unsigned

* BFloat16

* Fill op - empty tweaks

* - couple of fixes for empty arrays construction
- stack updated

* strided slice fix

* one transform test

* provide method for reducing shapeInfo in case of input array is empty

* Fixed reduceAlongDimensions to use empty input properly.

* couple of broadcast tests

* couple of tests broadcast tests + tweak to make them pass

* add check of non-empty to methods producing sub-arrays

* Fixed reshapeC with zeros in shape.

* complete empty check in reduce_... legacy ops

* Concat and cumsum/prod

* Tweak to empty shape inference on import

* add empty check to the rest of reduce legacy ops

* one more test

* correct typo in evalReduceShapeInfoEmpty

* Added tests for reduce_* ops to tests with zero shapes.

* few more tests for empty reductions

* Fixed strided_slice op with empty case and tests.

* one more empty reduction test

* Fixed strided_slice test.

* add empty check to NDArray::reshapei

* infOrMax

* empty min/max with infinity tests

* made unstack working correctly with empty arrays

* few IndexReduce tests + tweaks for empty shapes

* add test for empty concat

* few tests fixed

* Validation fix for reductions on empty shapes

* Reverse fix

* Reduction shape calc fixes

* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs

* Range fix

* - NDArray constructor updated for scalars/empty arrays
- few tests fixed

* More fixes

* Empty creator fixes

* concat fix

* concat fix

* TF import tests: allow 'both all NaN' and 'both all inf' to pass

* Slice, zero fraction, and reshape fixes

* transpose, gather

* Zero fraction

* scalar cast fix

* Empty reduction axis support

* few more tests fixed

* Fixed input checks conforming with TF for concat op and tests.

* few tests fixed

* matmul scalar shape fix

* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.

* broadcast bool fix

* few more tests

* few more tests

* correct evalReduceShapeInfoEmpty

* argmax/argmin + tests

* one more empty edge case + one more test

* argmax/argmin/realdiv_bp tweaks

* empty reshape test + fix

* Helper fixes

* Small fixes

* Gather test fix

* Gather test fix

* Small fixes

* reduce scalar zero values

* scalar mean workaround

* Remove debug code

* along dim mean workaround

* one more test

* - equalsTo() tweak for empty arrays
- one more test

* broadcast tweaks

* [WIP] Fixing outstanding issues for NLP (#9)

* Avoid using not-inited objects

* Test fixed.

* Redundant method avoided for models like FastText

* KMeans++ implementation

* KMeans++ implementation

* Disable parallel execution

* KMeans++

* Tests

* Dev branch merge (#16)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Fix some issues on master (#17)

* Fix DataVec test issue

* Fix issue with dl4j SameDiff output layer

* Dtype fix for lambda layers

* #7912 BertIterator dtype fix (use float32 not global default)

* [WIP] Next set of CUDA stuff (#7)

New CUDA implementations and improvements

* bad file

* Dev branch master merge (#23)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Compatibility of deserialization (#18)

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* SameDiff: add activation gradient checking support for debugging (#19)

* SameDiff gradient checker: first pass on activation gradient checks

* Fixes + tests for activation gradient checking

* Javadoc

* [WIP] Some nd4j data type corrections (#20)

* Adjust data type

* Set correct Data type.

* Size of proper data type.

* fix averaged cpu load (#22)

* SameDiff ops, TF import and fixes (#24)

* CheckNumerics tests + fixes + misc fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fake quant

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* FakeQuantWithMinMaxArgs

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* CheckNumerics fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Exception tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for out of scope stack allocated var use

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ignores

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ignore for known failing test (already logged issue)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Merge upstream to fork (#25)

* Add thousand-separator commas to TotalParams (#7915)

* Add thousand-separator commas to TotalParams

The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.

* Add thousand-separator commas to MultiLayerNetwork

Corresponding change to MultiLayerNetwork

Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>

* Update contributing and issue/PR templates (#7934)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix link to AdaDelta paper (#7942)

Fix link to AdaDelta paper hosted on matthewzeiler.com

Signed-off-by: Jxtps

* Fixes, and ignores for known/logged failing issues (#7943)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff + DL4J/SameDiff: Multiple fixes (#28)

* #7919 HDF5 attribute buffer length fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7909 Arbiter constructor exception ux improvements

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7925 RNN output layer length checks

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7939 Add listener for validating inputs are not incorrectly modified

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7939 Integrate NonInplaceValidationListener into tests

* #7844 DL4J SameDiff fixes for variable minibatch size

* DL4J SameDiff fixes - ensure gradient for input placeholder is available

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweaks to ExternalErrorsFunction - use placeholders, make more robust

* Another fix

* More fixes

* More SameDiff/DL4J fixes

* Scope out scalar array creation in BaseScalarOp

* Remove debug code

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] Final dev branch merge (#29)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Compatibility of deserialization (#18)

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* SameDiff: add activation gradient checking support for debugging (#19)

* SameDiff gradient checker: first pass on activation gradient checks

* Fixes + tests for activation gradient checking

* Javadoc

* [WIP] Some nd4j data type corrections (#20)

* Adjust data type

* Set correct Data type.

* Size of proper data type.

* fix averaged cpu load (#22)

* [WIP] Multiple dataset iterators (#27)

* Splitting dataset into arbitrary number

* Fixes

* Multiple split of iterator

* Test

* Test

* Some fixes

* signature change

* one more tweak

Signed-off-by: raver119 <raver119@gmail.com>

* one more test for sequential use of DataSetIteratorSplitter

Signed-off-by: raver119 <raver119@gmail.com>

* Fixes

* Fixes

* one more test for Alexander

Signed-off-by: raver119 <raver119@gmail.com>

* Some fixes

* Some fixes

* one more test for Alexander

Signed-off-by: raver119 <raver119@gmail.com>

* minor test fix

Signed-off-by: raver119 <raver119@gmail.com>

* Some fixes

* Some fixes

* couple of assertions tweaked

Signed-off-by: raver119 <raver119@gmail.com>

* MDS splitter test :/

Signed-off-by: raver119 <raver119@gmail.com>

* Minor refactoring

* Multi dataset

* Some fixes

* More tests

* Small number of test fixes/improvements (failures on CI) (#31)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] More CUDA stuff (#26)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* LRN BP CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* less memory

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed bug with crop_and_resize op helper.

* get rid of unnecessary index-calculation dunction

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed sort with nth_element cuda-based helper.

* Refactored nth_element.

* Refactored nth_element op and tests.

* Modified usage of dim array with sortTad routine.

* Refactored main routine of helper for non_max_image_suppression op.

* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.

* fix vol2col cuda kernel

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* topK concept

Signed-off-by: raver119 <raver119@gmail.com>

* unsorted topK with scanWitdh of 1

Signed-off-by: raver119 <raver119@gmail.com>

* correct vol2col tests

* sorted/unsorted topK

Signed-off-by: raver119 <raver119@gmail.com>

* implementation and fixing col2im/col2vol

* Corrected usage flags with input/output with reverse op.

* dup is const now

Signed-off-by: raver119 <raver119@gmail.com>

* percentile op

Signed-off-by: raver119 <raver119@gmail.com>

* group tests for mapool2d

Signed-off-by: Yurii <yurii@skymind.io>

* special test for george

Signed-off-by: raver119 <raver119@gmail.com>

* less threads for sortTad

Signed-off-by: raver119 <raver119@gmail.com>

* provide conv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* remove auther in sort tad kernel code

Signed-off-by: Yurii <yurii@skymind.io>

* provide depthwise_conv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* - max_pooling_with_argmax
- null check for special use

Signed-off-by: raver119 <raver119@gmail.com>

* dts cuda

Signed-off-by: raver119 <raver119@gmail.com>

* provide sconv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* std cuda

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression op to conform TF implementation.

* Improved suppression helper.

* provide pooling3d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* minor lstm rearrangements

Signed-off-by: raver119 <raver119@gmail.com>

* more of minor lstm rearrangements

Signed-off-by: raver119 <raver119@gmail.com>

* (bi)dynamic_rnn

Signed-off-by: raver119 <raver119@gmail.com>

* templates init order

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression op.

* Added cuda kernel for non_max_suppression.

* CPU sort by key/value

Signed-off-by: raver119 <raver119@gmail.com>

* CPU sort TAD by key/value

Signed-off-by: raver119 <raver119@gmail.com>

* CPU sort TAD by key/value tests

Signed-off-by: raver119 <raver119@gmail.com>

* Eliminate compiler error with cuda implementation.

* - repaired gradCheck in cuda
- provide conv2d_bp for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* missed signature

Signed-off-by: raver119 <raver119@gmail.com>

* provide depthwise_conv2d_bp for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of lup helper with cuda kernel. Initial commit.

* further work on backprops for convolutions

Signed-off-by: Yurii <yurii@skymind.io>

* CUDA linear sort by key/val

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA tad sort by key/val

Signed-off-by: raver119 <raver119@gmail.com>

* start providing of backprop for pooling2d/3d

Signed-off-by: Yurii <yurii@skymind.io>

* Added atomicAdd for bool datatype.

* dynamic partition concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic partition concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic partition scalar CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* important comment

Signed-off-by: raver119 <raver119@gmail.com>

* fix pooling2d/3d backprop helpers

Signed-off-by: Yurii <yurii@skymind.io>

* Added non-linear test with dynamic_partition.

* Improved test for dynamic_partition.

* dynamic_partition TAD concept

Signed-off-by: raver119 <raver119@gmail.com>

* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix

Signed-off-by: raver119 <raver119@gmail.com>

* - rewrite cpu code for usampling2d/3d
- write cuda code for usampling2d/3d

Signed-off-by: Yurii <yurii@skymind.io>

* dynamic_stitch CUDA vector case

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic_stitch CUDA TAD case concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic_stitch CUDA TAD case impl

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for dynamic_stitch 3D-4D cases.

* minor tests tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed type check for dynamic stitch.

* min/max bp

Signed-off-by: raver119 <raver119@gmail.com>

* rewrite code for upsampling2d/3d cpu

Signed-off-by: Yurii <yurii@skymind.io>

* reduce min/max/norm_max bp

Signed-off-by: raver119 <raver119@gmail.com>

* lup implementation. Additional enhancements.

* provide code for upsamling2d/3d backprop

Signed-off-by: Yurii <yurii@skymind.io>

* weightedCrossEntropyWithLogits

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed template math atomicMul for 64bit ints.

* Refactored dynamic_partition_bp op.

* inverseBroadcast fix

Signed-off-by: raver119 <raver119@gmail.com>

* DynamicPartitionBP test datatype fixed.

* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA

Signed-off-by: raver119 <raver119@gmail.com>
master
Alex Black 2019-06-28 01:37:04 +10:00 committed by raver119
parent cae4fc9760
commit 1170827c18
331 changed files with 17959 additions and 7363 deletions

View File

@ -31,7 +31,7 @@ public class TaskCreatorProvider {
} }
return c.newInstance(); return c.newInstance();
} catch (Exception e){ } catch (Exception e){
throw new RuntimeException("Could not create new instance of task creator class: " + c, e); throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
} }
} }

View File

@ -83,7 +83,7 @@ public class DataSetIteratorFactoryProvider implements DataProvider {
(Class<? extends DataSetIteratorFactory>) Class.forName(value); (Class<? extends DataSetIteratorFactory>) Class.forName(value);
return clazz.newInstance(); return clazz.newInstance();
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
} }
} }
} }

View File

@ -79,7 +79,7 @@ public class DataSetIteratorFactoryProvider implements DataProvider {
(Class<? extends DataSetIteratorFactory>) Class.forName(value); (Class<? extends DataSetIteratorFactory>) Class.forName(value);
return clazz.newInstance(); return clazz.newInstance();
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
} }
} }
} }

View File

@ -54,7 +54,7 @@ public abstract class BaseNetScoreFunction implements ScoreFunction {
ds.configure(dataSourceProperties); ds.configure(dataSourceProperties);
} }
} catch (Exception e){ } catch (Exception e){
throw new RuntimeException(e); throw new RuntimeException("Error creating DataSource instance - missing no-arg constructor?", e);
} }
return score(model, ds.testData()); return score(model, ds.testData());
} }

View File

@ -188,10 +188,15 @@ public class ComputationGraphTaskCreator implements TaskCreator {
//For DataSetIterator: wraps in a MultiDataSetIterator, hence method can be used for both //For DataSetIterator: wraps in a MultiDataSetIterator, hence method can be used for both
MultiDataSetIterator iterator; MultiDataSetIterator iterator;
if(dataSource != null){ if(dataSource != null){
try {
DataSource dsInstance = dataSource.newInstance(); DataSource dsInstance = dataSource.newInstance();
if (dataSourceProperties != null) if (dataSourceProperties != null)
dsInstance.configure(dataSourceProperties); dsInstance.configure(dataSourceProperties);
iterator = ScoreUtil.getMultiIterator(dsInstance.trainData()); iterator = ScoreUtil.getMultiIterator(dsInstance.trainData());
} catch (Exception e){
throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName() +
" - no zero-arg constructor?",e);
}
} else { } else {
iterator = ScoreUtil.getMultiIterator(dataProvider.trainData(candidate.getDataParameters())); iterator = ScoreUtil.getMultiIterator(dataProvider.trainData(candidate.getDataParameters()));
} }

View File

@ -190,7 +190,8 @@ public class MultiLayerNetworkTaskCreator implements TaskCreator {
try{ try{
dsInstance = dataSource.newInstance(); dsInstance = dataSource.newInstance();
} catch (Exception e){ } catch (Exception e){
throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName()); throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName() +
" - no zero-arg constructor?",e);
} }
if(dataSourceProperties != null) if(dataSourceProperties != null)
dsInstance.configure(dataSourceProperties); dsInstance.configure(dataSourceProperties);

View File

@ -26,6 +26,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
@ -78,14 +79,14 @@ public class TestNDArrayWritableTransforms {
assertEquals(expColNames, tp.getFinalSchema().getColumnNames()); assertEquals(expColNames, tp.getFinalSchema().getColumnNames());
List<Writable> in = Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)), List<Writable> in = Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)),
new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0))); new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE)));
List<Writable> out = tp.execute(in); List<Writable> out = tp.execute(in);
List<Writable> exp = List<Writable> exp =
Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)), Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)),
new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0)), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE)),
new NDArrayWritable(Nd4j.linspace(0, 9, 10).addi(2.0))); new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE, 0, 10, 1).addi(2.0).reshape(1,10)));
assertEquals(exp, out); assertEquals(exp, out);
} }

View File

@ -20,9 +20,15 @@ import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator; import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals; import java.util.Collections;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
public class DataSetSplitterTests extends BaseDL4JTest { public class DataSetSplitterTests extends BaseDL4JTest {
@Test @Test
@ -144,4 +150,245 @@ public class DataSetSplitterTests extends BaseDL4JTest {
assertEquals(1000 * numEpochs, global); assertEquals(1000 * numEpochs, global);
} }
@Test
public void testSplitter_4() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, 1000, new double[]{0.5, 0.3, 0.2});
List<DataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
int iterNo = 0;
int perEpoch = 0;
for (val partIterator : iteratorList) {
int cnt = 0;
partIterator.reset();
while (partIterator.hasNext()) {
val data = partIterator.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
(float) perEpoch, data.getFloat(0), 1e-5);
//gcntTrain++;
global++;
cnt++;
++perEpoch;
}
++iterNo;
}
}
assertEquals(1000* numEpochs, global);
}
@Test
public void testSplitter_5() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, new int[]{900, 100});
List<DataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
int iterNo = 0;
int perEpoch = 0;
for (val partIterator : iteratorList) {
partIterator.reset();
while (partIterator.hasNext()) {
int cnt = 0;
val data = partIterator.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
(float) perEpoch, data.getFloat(0), 1e-5);
//gcntTrain++;
global++;
cnt++;
++perEpoch;
}
++iterNo;
}
}
assertEquals(1000 * numEpochs, global);
}
@Test
public void testSplitter_6() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new DataSetIteratorSplitter(back, new int[]{800, 100, 100});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0);
val testIter = splitter.getIterators().get(1);
val validationIter = splitter.getIterators().get(2);
// we're going to have multiple epochs
int numEpochs = 10;
for (int e = 0; e < numEpochs; e++) {
int globalIter = 0;
trainIter.reset();
testIter.reset();
validationIter.reset();
boolean trained = false;
while (trainIter.hasNext()) {
trained = true;
val ds = trainIter.next();
assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", trained);
assertEquals(800, globalIter);
// test set is used every epoch
boolean tested = false;
//testIter.reset();
while (testIter.hasNext()) {
tested = true;
val ds = testIter.next();
assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", tested);
assertEquals(900, globalIter);
// validation set is used every 5 epochs
if (e % 5 == 0) {
boolean validated = false;
//validationIter.reset();
while (validationIter.hasNext()) {
validated = true;
val ds = validationIter.next();
assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", validated);
}
// all 3 iterators have exactly 1000 elements combined
if (e % 5 == 0)
assertEquals(1000, globalIter);
else
assertEquals(900, globalIter);
trainIter.reset();
}
}
@Test
public void testUnorderedSplitter_1() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, new int[]{500, 500});
List<DataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
// Get data from second part, then rewind for the first one.
int cnt = 0;
int partNumber = 1;
while (iteratorList.get(partNumber).hasNext()) {
int farCnt = (1000 / 2) * (partNumber) + cnt;
val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data.getFloat(0), 1e-5);
cnt++;
global++;
}
iteratorList.get(partNumber).reset();
partNumber = 0;
cnt = 0;
while (iteratorList.get(0).hasNext()) {
val data = iteratorList.get(0).next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
global++;
}
}
}
@Test
public void testUnorderedSplitter_2() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, new int[]{2});
List<DataSetIterator> iteratorList = splitter.getIterators();
for (int partNumber = 0 ; partNumber < iteratorList.size(); ++partNumber) {
int cnt = 0;
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
cnt++;
}
}
}
@Test
public void testUnorderedSplitter_3() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, new int[]{10});
List<DataSetIterator> iteratorList = splitter.getIterators();
Random random = new Random();
int[] indexes = new int[iteratorList.size()];
for (int i = 0; i < indexes.length; ++i) {
indexes[i] = random.nextInt(iteratorList.size());
}
for (int partNumber : indexes) {
int cnt = 0;
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
cnt++;
}
}
}
@Test
public void testUnorderedSplitter_4() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new DataSetIteratorSplitter(back, new int[]{80, 10, 5});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0); // 0..79
val testIter = splitter.getIterators().get(1); // 80 ..89
val validationIter = splitter.getIterators().get(2); // 90..94
// we're skipping train/test and go for validation first. we're that crazy, right.
int valCnt = 0;
while (validationIter.hasNext()) {
val ds = validationIter.next();
assertNotNull(ds);
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5);
valCnt++;
}
assertEquals(5, valCnt);
}
} }

View File

@ -18,11 +18,17 @@ package org.deeplearning4j.datasets.iterator;
import lombok.val; import lombok.val;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator; import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import static org.junit.Assert.assertEquals; import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
/** /**
* *
@ -150,4 +156,309 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertEquals(1000 * numEpochs, global); assertEquals(1000 * numEpochs, global);
} }
@Test
public void testMultiSplitter_1() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0);
val testIter = splitter.getIterators().get(1);
val validationIter = splitter.getIterators().get(2);
// we're going to have multiple epochs
int numEpochs = 10;
for (int e = 0; e < numEpochs; e++) {
int globalIter = 0;
trainIter.reset();
testIter.reset();
validationIter.reset();
boolean trained = false;
while (trainIter.hasNext()) {
trained = true;
val ds = trainIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", trained);
assertEquals(800, globalIter);
// test set is used every epoch
boolean tested = false;
//testIter.reset();
while (testIter.hasNext()) {
tested = true;
val ds = testIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", tested);
assertEquals(900, globalIter);
// validation set is used every 5 epochs
if (e % 5 == 0) {
boolean validated = false;
//validationIter.reset();
while (validationIter.hasNext()) {
validated = true;
val ds = validationIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", validated);
}
// all 3 iterators have exactly 1000 elements combined
if (e % 5 == 0)
assertEquals(1000, globalIter);
else
assertEquals(900, globalIter);
trainIter.reset();
}
}
@Test
public void testSplitter_5() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{900, 100});
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
int iterNo = 0;
int perEpoch = 0;
for (val partIterator : iteratorList) {
partIterator.reset();
while (partIterator.hasNext()) {
int cnt = 0;
val data = partIterator.next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
(float) perEpoch, data[i].getFloat(0), 1e-5);
}
//gcntTrain++;
global++;
cnt++;
++perEpoch;
}
++iterNo;
}
}
assertEquals(1000 * numEpochs, global);
}
@Test
public void testSplitter_6() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0);
val testIter = splitter.getIterators().get(1);
val validationIter = splitter.getIterators().get(2);
// we're going to have multiple epochs
int numEpochs = 10;
for (int e = 0; e < numEpochs; e++) {
int globalIter = 0;
trainIter.reset();
testIter.reset();
validationIter.reset();
boolean trained = false;
while (trainIter.hasNext()) {
trained = true;
val ds = trainIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", trained);
assertEquals(800, globalIter);
// test set is used every epoch
boolean tested = false;
//testIter.reset();
while (testIter.hasNext()) {
tested = true;
val ds = testIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", tested);
assertEquals(900, globalIter);
// validation set is used every 5 epochs
if (e % 5 == 0) {
boolean validated = false;
//validationIter.reset();
while (validationIter.hasNext()) {
validated = true;
val ds = validationIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", validated);
}
// all 3 iterators have exactly 1000 elements combined
if (e % 5 == 0)
assertEquals(1000, globalIter);
else
assertEquals(900, globalIter);
trainIter.reset();
}
}
@Test
public void testUnorderedSplitter_1() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{500, 500});
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
// Get data from second part, then rewind for the first one.
int cnt = 0;
int partNumber = 1;
while (iteratorList.get(partNumber).hasNext()) {
int farCnt = (1000 / 2) * (partNumber) + cnt;
val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data[i].getFloat(0), 1e-5);
}
cnt++;
global++;
}
iteratorList.get(partNumber).reset();
partNumber = 0;
cnt = 0;
while (iteratorList.get(0).hasNext()) {
val data = iteratorList.get(0).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++,
data[i].getFloat(0), 1e-5);
}
global++;
}
}
}
@Test
public void testUnorderedSplitter_2() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{2});
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
for (int partNumber = 0 ; partNumber < iteratorList.size(); ++partNumber) {
int cnt = 0;
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5);
}
cnt++;
}
}
}
@Test
public void testUnorderedSplitter_3() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{10});
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
Random random = new Random();
int[] indexes = new int[iteratorList.size()];
for (int i = 0; i < indexes.length; ++i) {
indexes[i] = random.nextInt(iteratorList.size());
}
for (int partNumber : indexes) {
int cnt = 0;
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt),
data[i].getFloat(0), 1e-5);
}
cnt++;
}
}
}
@Test
public void testUnorderedSplitter_4() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{80, 10, 5});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0); // 0..79
val testIter = splitter.getIterators().get(1); // 80 ..89
val validationIter = splitter.getIterators().get(2); // 90..94
// we're skipping train/test and go for validation first. we're that crazy, right.
int valCnt = 0;
while (validationIter.hasNext()) {
val ds = validationIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90,
ds.getFeatures()[i].getFloat(0), 1e-5);
}
valCnt++;
}
assertEquals(5, valCnt);
}
} }

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.dropout.TestDropout;
import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@ -196,4 +197,43 @@ public class TestRnnLayers extends BaseDL4JTest {
} }
} }
@Test
public void testMismatchedInputLabelLength(){
for( int i=0; i<2; i++ ){
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
.list()
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
switch (i){
case 0:
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build());
break;
case 1:
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build());
break;
default:
throw new RuntimeException();
}
MultiLayerConfiguration conf = lb.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
try{
net.fit(in,l);
} catch (Throwable t){
String msg = t.getMessage();
assertTrue(msg, msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
}
}
}
} }

View File

@ -249,7 +249,6 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
} }
@Test @Test
@Ignore("AB 2019/05/31 - Failing on CI and locally - see issues 7820 and 7657")
public void testCorrectness1() { public void testCorrectness1() {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE); DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
Nd4j.getRandom().setSeed(123); Nd4j.getRandom().setSeed(123);
@ -270,30 +269,18 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
.useAdaGrad(false).build(); .useAdaGrad(false).build();
b.fit(data); b.fit(data);
System.out.println(b.getData());
/*double[] expectedData = new double[]{15.5392794313924, 19.25226403656672, -5.194955746137196, -31.787679714614757, 48.8674725273665, double[] expectedData = new double[]{ 63.8206, 80.4013, -19.4424, -140.4326, 198.7239,
24.92775755686273, -22.621939920239065, -29.790772278125395, 19.027362415188914, -16.013800175884274, 106.1148, -96.6273, -124.3634, 78.4174, -83.6621,
-27.454680593309185, 1.2929960811295493, -40.45000061571038, 61.23261682914338, 5.62278768938746, -121.8706, 3.0888, -172.8560, 255.1262, 20.7021,
-28.16665244970911, -20.05502814088798, 12.803274346870865, -24.877262522905497, 45.115883138175874, -120.7942, -78.1829, 56.6021, -112.3294, 185.4084,
21.597495694710616, 18.63254779638783, -4.029728632528419, -0.4596087279592638, -42.35340705500429, 88.5330, 78.0497, -18.8673, -11.0155, -175.1564,
-69.24727547461491, 40.94332685199673, -24.60866142208024, 17.689874972878723, -3.6779759693605314, -297.8463, 174.2511, -103.8793, 72.5455, -15.8498,
-30.91803590368529, 10.645452930824145, 36.58583235020565, -64.74975614289316, -39.364099390585956, -134.5235, 42.3300, 154.0391, -280.1010, -167.9765,
72.54886481127016, -35.30663155696714, 19.37116912936714, -7.790876543092118, 19.6586396288508, 306.9938, -150.9666, 83.4419, -36.0877, 83.9992,
58.1332709511154, -18.49217368496203, -3.5050200971182424, 5.662891294031322, 39.69533295638775, 245.1813, -81.5018, -14.8430, 16.1557, 166.8651,
-15.114610550011662, -32.42366951357609, 17.039297537056537, 42.25610885633673, -2.7013781552769904, -65.9247, -138.1783, 72.5444, 176.3088, -25.6732,
-16.338582630617925, 41.734027526336874, 20.941332646863426, -3.2145240561108244, -45.36033539684912};*/ -69.6843, 167.3360, 87.6238, -18.5874, -187.3806};
double[] expectedData = {40.93810899235225, 50.90183660191448, -14.298857560948981, -86.2012232604988, 129.51281793466023,
66.29136854264247, -61.650213611972326, -80.42836756633497, 50.28325210727952, -44.29008119040566,
-74.82748570869279, 2.0170536250746807, -109.21462846594635, 162.3973196127918, 14.000621153511705,
-76.30892822919527, -54.251704596942275, 33.99763310539589, -67.6307009607032, 119.50868525237786,
57.17786598853867, 49.1489174572297, -11.25663463504983, -2.38899196609398, -114.27194947404686,
-185.93832011474473, 108.9022579845252, -66.14099037301474, 47.13683038425694, -10.037893631405792,
-83.88458799629637, 26.985651418254996, 96.68139337135332, -174.2832443285551, -106.0999118697521,
193.02622700008175, -94.88003359113081, 51.39502524568139, -20.96021960048648, 52.32291574424741,
154.33973608321477, -50.90644802585217, -10.345744416395354, 13.721222143380892, 105.2111073677489,
-41.339268919407345, -87.73042354938127, 45.306865238870046, 112.53877133856602, -8.44454352074299,
-44.660828600669056, 110.72662022978719, 55.74660833987147, -9.613556053471232, -122.19953914048916};
INDArray expectedArray = Nd4j.createFromArray(expectedData).reshape(11,5); INDArray expectedArray = Nd4j.createFromArray(expectedData).reshape(11,5);
for (int i = 0; i < expectedArray.rows(); ++i) for (int i = 0; i < expectedArray.rows(); ++i)

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.util;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -30,7 +31,7 @@ public class TimeSeriesUtilsTest extends BaseDL4JTest {
@Test @Test
public void testMovingAverage() { public void testMovingAverage() {
INDArray a = Nd4j.arange(0, 20); INDArray a = Nd4j.arange(0, 20).castTo(DataType.DOUBLE);
INDArray result = Nd4j.create(new double[] {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f, INDArray result = Nd4j.create(new double[] {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f,
12.5f, 13.5f, 14.5f, 15.5f, 16.5f, 17.5f}); 12.5f, 13.5f, 14.5f, 15.5f, 16.5f, 17.5f});

View File

@ -24,6 +24,7 @@ import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
@ -42,14 +43,20 @@ public class DataSetIteratorSplitter {
protected DataSetIterator backedIterator; protected DataSetIterator backedIterator;
protected final long totalExamples; protected final long totalExamples;
protected final double ratio; protected final double ratio;
protected final double[] ratios;
protected final long numTrain; protected final long numTrain;
protected final long numTest; protected final long numTest;
protected final long numArbitrarySets;
protected final int[] splits;
protected AtomicLong counter = new AtomicLong(0); protected AtomicLong counter = new AtomicLong(0);
protected AtomicBoolean resetPending = new AtomicBoolean(false); protected AtomicBoolean resetPending = new AtomicBoolean(false);
protected DataSet firstTrain = null; protected DataSet firstTrain = null;
protected int partNumber = 0;
/** /**
* The only constructor * The only constructor
* *
@ -71,17 +78,94 @@ public class DataSetIteratorSplitter {
this.backedIterator = baseIterator; this.backedIterator = baseIterator;
this.totalExamples = totalBatches; this.totalExamples = totalBatches;
this.ratio = ratio; this.ratio = ratio;
this.ratios = null;
this.numTrain = (long) (totalExamples * ratio); this.numTrain = (long) (totalExamples * ratio);
this.numTest = totalExamples - numTrain; this.numTest = totalExamples - numTrain;
this.numArbitrarySets = 2;
this.splits = null;
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!"); log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
} }
public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalBatches, double[] ratios) {
for (double ratio : ratios) {
if (!(ratio > 0.0 && ratio < 1.0))
throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");
}
if (totalBatches < 0)
throw new ND4JIllegalStateException("totalExamples number should be positive value");
if (!baseIterator.resetSupported())
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
this.backedIterator = baseIterator;
this.totalExamples = totalBatches;
this.ratio = 0.0;
this.ratios = ratios;
this.numTrain = 0; //(long) (totalExamples * ratio);
this.numTest = 0; //totalExamples - numTrain;
this.numArbitrarySets = ratios.length;
this.splits = new int[this.ratios.length];
for (int i = 0; i < this.splits.length; ++i) {
this.splits[i] = (int)(totalExamples * ratios[i]);
}
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, int[] splits) {
/*if (!(simpleRatio > 0.0 && simpleRatio < 1.0))
throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");*/
int totalBatches = 0;
for (val v:splits)
totalBatches += v;
if (totalBatches < 0)
throw new ND4JIllegalStateException("totalExamples number should be positive value");
if (!baseIterator.resetSupported())
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
this.backedIterator = baseIterator;
this.totalExamples = totalBatches;
this.ratio = 0.0;
this.ratios = null;
this.numTrain = 0; //(long) (totalExamples * ratio);
this.numTest = 0; //totalExamples - numTrain;
this.splits = splits;
this.numArbitrarySets = splits.length;
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public List<DataSetIterator> getIterators() {
List<DataSetIterator> retVal = new ArrayList<>();
int partN = 0;
int bottom = 0;
for (final int split : splits) {
ScrollableDataSetIterator partIterator =
new ScrollableDataSetIterator(partN++, backedIterator, counter, resetPending, firstTrain,
new int[]{bottom,split});
bottom += split;
retVal.add(partIterator);
}
return retVal;
}
/** /**
* This method returns train iterator instance * This method returns train iterator instance
* *
* @return * @return
*/ */
@Deprecated
public DataSetIterator getTrainIterator() { public DataSetIterator getTrainIterator() {
return new DataSetIterator() { return new DataSetIterator() {
@Override @Override
@ -184,6 +268,7 @@ public class DataSetIteratorSplitter {
* *
* @return * @return
*/ */
@Deprecated
public DataSetIterator getTestIterator() { public DataSetIterator getTestIterator() {
return new DataSetIterator() { return new DataSetIterator() {
@Override @Override

View File

@ -21,9 +21,12 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
@ -43,6 +46,9 @@ public class MultiDataSetIteratorSplitter {
protected final double ratio; protected final double ratio;
protected final long numTrain; protected final long numTrain;
protected final long numTest; protected final long numTest;
protected final double[] ratios;
protected final long numArbitrarySets;
protected final int[] splits;
protected AtomicLong counter = new AtomicLong(0); protected AtomicLong counter = new AtomicLong(0);
@ -71,15 +77,87 @@ public class MultiDataSetIteratorSplitter {
this.ratio = ratio; this.ratio = ratio;
this.numTrain = (long) (totalExamples * ratio); this.numTrain = (long) (totalExamples * ratio);
this.numTest = totalExamples - numTrain; this.numTest = totalExamples - numTrain;
this.ratios = null;
this.numArbitrarySets = 0;
this.splits = null;
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!"); log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
} }
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double[] ratios) {
for (double ratio : ratios) {
if (!(ratio > 0.0 && ratio < 1.0))
throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");
}
if (totalBatches < 0)
throw new ND4JIllegalStateException("totalExamples number should be positive value");
if (!baseIterator.resetSupported())
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
this.backedIterator = baseIterator;
this.totalExamples = totalBatches;
this.ratio = 0.0;
this.numTrain = (long) (totalExamples * ratio);
this.numTest = totalExamples - numTrain;
this.ratios = null;
this.numArbitrarySets = ratios.length;
this.splits = new int[this.ratios.length];
for (int i = 0; i < this.splits.length; ++i) {
this.splits[i] = (int)(totalExamples * ratios[i]);
}
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, int[] splits) {
int totalBatches = 0;
for (val v:splits)
totalBatches += v;
if (totalBatches < 0)
throw new ND4JIllegalStateException("totalExamples number should be positive value");
if (!baseIterator.resetSupported())
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
this.backedIterator = baseIterator;
this.totalExamples = totalBatches;
this.ratio = 0.0;
this.numTrain = (long) (totalExamples * ratio);
this.numTest = totalExamples - numTrain;
this.ratios = null;
this.numArbitrarySets = splits.length;
this.splits = splits;
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public List<MultiDataSetIterator> getIterators() {
List<MultiDataSetIterator> retVal = new ArrayList<>();
int partN = 0;
int bottom = 0;
for (final int split : splits) {
ScrollableMultiDataSetIterator partIterator =
new ScrollableMultiDataSetIterator(partN++, backedIterator, counter, firstTrain,
new int[]{bottom,split});
bottom += split;
retVal.add(partIterator);
}
return retVal;
}
/** /**
* This method returns train iterator instance * This method returns train iterator instance
* *
* @return * @return
*/ */
@Deprecated
public MultiDataSetIterator getTrainIterator() { public MultiDataSetIterator getTrainIterator() {
return new MultiDataSetIterator() { return new MultiDataSetIterator() {
@Override @Override
@ -162,6 +240,7 @@ public class MultiDataSetIteratorSplitter {
* *
* @return * @return
*/ */
@Deprecated
public MultiDataSetIterator getTestIterator() { public MultiDataSetIterator getTestIterator() {
return new MultiDataSetIterator() { return new MultiDataSetIterator() {
@Override @Override

View File

@ -0,0 +1,158 @@
package org.deeplearning4j.datasets.iterator;
import lombok.val;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
public class ScrollableDataSetIterator implements DataSetIterator {
private int thisPart = 0;
private int top = 0;
private int bottom = 0;
protected DataSetIterator backedIterator;
protected AtomicLong counter = new AtomicLong(0);
protected AtomicBoolean resetPending = new AtomicBoolean(false);
protected DataSet firstTrain = null;
protected MultiDataSet firstMultiTrain = null;
private double ratio;
private long totalExamples;
private long itemsPerPart;
private long current;
public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
AtomicBoolean resetPending, DataSet firstTrain, double ratio,
int totalExamples) {
this.thisPart = num;
this.backedIterator = backedIterator;
this.counter = counter;
this.resetPending = resetPending;
this.firstTrain = firstTrain;
this.ratio = ratio;
this.totalExamples = totalExamples;
this.itemsPerPart = (long)(totalExamples * ratio);
this.current = 0;
}
public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
AtomicBoolean resetPending, DataSet firstTrain,
int[] itemsPerPart) {
this.thisPart = num;
this.bottom = itemsPerPart[0];
this.top = bottom + itemsPerPart[1];
this.itemsPerPart = top;
this.backedIterator = backedIterator;
this.counter = counter;
//this.resetPending = resetPending;
this.firstTrain = firstTrain;
//this.totalExamples = totalExamples;
this.current = 0;
}
@Override
public DataSet next(int i) {
throw new UnsupportedOperationException();
}
@Override
public List<String> getLabels() {
return backedIterator.getLabels();
}
@Override
public int inputColumns() {
return backedIterator.inputColumns();
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
@Override
public int totalOutcomes() {
return backedIterator.totalOutcomes();
}
@Override
public boolean resetSupported() {
return backedIterator.resetSupported();
}
@Override
public boolean asyncSupported() {
return backedIterator.asyncSupported();
}
@Override
public void reset() {
resetPending.set(true);
}
@Override
public int batch() {
return backedIterator.batch();
}
@Override
public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
backedIterator.setPreProcessor(dataSetPreProcessor);
}
@Override
public DataSetPreProcessor getPreProcessor() {
return backedIterator.getPreProcessor();
}
@Override
public boolean hasNext() {
if (resetPending.get()) {
if (resetSupported()) {
backedIterator.reset();
counter.set(0);
current = 0;
resetPending.set(false);
} else
throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
}
boolean state = false;
if (current >= top)
return false;
state = backedIterator.hasNext();
if (!state)
return false;
if (state && counter.get() < itemsPerPart)
return true;
else
return false;
}
@Override
public DataSet next() {
counter.incrementAndGet();
if ((current == 0) && (bottom != 0)) {
backedIterator.reset();
long cnt = current;
for (; cnt < bottom; ++cnt) {
if (backedIterator.hasNext())
backedIterator.next();
}
current = cnt+1;
}
else current++;
val p = backedIterator.next();
return p;
}
}

View File

@ -0,0 +1,121 @@
package org.deeplearning4j.datasets.iterator;
import lombok.val;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import javax.naming.OperationNotSupportedException;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
public class ScrollableMultiDataSetIterator implements MultiDataSetIterator {
private int thisPart = 0;
private int top = 0;
private int bottom = 0;
protected MultiDataSetIterator backedIterator;
protected AtomicLong counter = new AtomicLong(0);
protected AtomicBoolean resetPending = new AtomicBoolean(false);
protected DataSet firstTrain = null;
protected MultiDataSet firstMultiTrain = null;
private double ratio;
private long totalExamples;
private long itemsPerPart;
private long current;
public ScrollableMultiDataSetIterator(int num, MultiDataSetIterator backedIterator, AtomicLong counter,
MultiDataSet firstTrain, int[] itemsPerPart) {
this.thisPart = num;
this.bottom = itemsPerPart[0];
this.top = bottom + itemsPerPart[1];
this.itemsPerPart = top;
this.counter = counter;
//this.resetPending = resetPending;
this.firstTrain = null;
this.firstMultiTrain = firstTrain;
//this.totalExamples = totalExamples;
this.current = 0;
this.backedIterator = backedIterator;
this.resetPending = resetPending;
}
@Override
public boolean resetSupported() {
return backedIterator.resetSupported();
}
@Override
public boolean asyncSupported() {
return backedIterator.asyncSupported();
}
@Override
public void reset() {
resetPending.set(true);
}
@Override
public void setPreProcessor(MultiDataSetPreProcessor dataSetPreProcessor) {
backedIterator.setPreProcessor(dataSetPreProcessor);
}
@Override
public MultiDataSetPreProcessor getPreProcessor() {
throw new UnsupportedOperationException();
}
@Override
public boolean hasNext() {
if (resetPending.get()) {
if (resetSupported()) {
backedIterator.reset();
counter.set(0);
current = 0;
resetPending.set(false);
} else
throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
}
boolean state = false;
if (current >= top)
return false;
state = backedIterator.hasNext();
if (!state)
return false;
if (state && counter.get() < itemsPerPart)
return true;
else
return false;
}
@Override
public MultiDataSet next() {
counter.incrementAndGet();
if ((current == 0) && (bottom != 0)) {
backedIterator.reset();
long cnt = current;
for (; cnt < bottom; ++cnt) {
if (backedIterator.hasNext())
backedIterator.next();
}
current = cnt+1;
}
else current++;
val p = backedIterator.next();
return p;
}
@Override
public MultiDataSet next(int i) {
throw new UnsupportedOperationException();
}
}

View File

@ -47,6 +47,8 @@ import static org.bytedeco.hdf5.global.hdf5.*;
@Slf4j @Slf4j
public class Hdf5Archive implements Closeable { public class Hdf5Archive implements Closeable {
public static final int MAX_BUFFER_SIZE_BYTES = (int)Math.pow(2, 28); //256 MB
/** /**
* HDF5 library is not thread safe - possible to crash if multiple reads etc are performed concurrently * HDF5 library is not thread safe - possible to crash if multiple reads etc are performed concurrently
* in multiple threads. This object is used for locking read etc activity using synchronized blocks * in multiple threads. This object is used for locking read etc activity using synchronized blocks
@ -338,7 +340,7 @@ public class Hdf5Archive implements Closeable {
private String readAttributeAsJson(Attribute attribute) throws UnsupportedKerasConfigurationException { private String readAttributeAsJson(Attribute attribute) throws UnsupportedKerasConfigurationException {
synchronized (Hdf5Archive.LOCK_OBJECT) { synchronized (Hdf5Archive.LOCK_OBJECT) {
VarLenType vl = attribute.getVarLenType(); VarLenType vl = attribute.getVarLenType();
int bufferSizeMult = 1; int currBufferLength = 2048;
String s; String s;
/* TODO: find a less hacky way to do this. /* TODO: find a less hacky way to do this.
* Reading variable length strings (from attributes) is a giant * Reading variable length strings (from attributes) is a giant
@ -349,8 +351,8 @@ public class Hdf5Archive implements Closeable {
* buffer and repeat. * buffer and repeat.
*/ */
while (true) { while (true) {
byte[] attrBuffer = new byte[bufferSizeMult * 2000]; byte[] attrBuffer = new byte[currBufferLength];
BytePointer attrPointer = new BytePointer(attrBuffer); BytePointer attrPointer = new BytePointer(currBufferLength);
attribute.read(vl, attrPointer); attribute.read(vl, attrPointer);
attrPointer.get(attrBuffer); attrPointer.get(attrBuffer);
s = new String(attrBuffer); s = new String(attrBuffer);
@ -362,9 +364,11 @@ public class Hdf5Archive implements Closeable {
} catch (IOException e) { } catch (IOException e) {
//OK - we don't know how long the buffer needs to be, so we'll try again with larger buffer //OK - we don't know how long the buffer needs to be, so we'll try again with larger buffer
} }
bufferSizeMult *= 2;
if (bufferSizeMult > 1024) { if(currBufferLength == MAX_BUFFER_SIZE_BYTES){
throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute"); throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute: size exceeds " + currBufferLength + " bytes");
} else {
currBufferLength = (int)Math.min(MAX_BUFFER_SIZE_BYTES, currBufferLength * 4L);
} }
} }
vl.deallocate(); vl.deallocate();

View File

@ -21,6 +21,7 @@ import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.deeplearning4j.clustering.cluster.Cluster; import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet; import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.ClusterUtils; import org.deeplearning4j.clustering.cluster.ClusterUtils;
@ -62,12 +63,13 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
private ClusterSet clusterSet; private ClusterSet clusterSet;
private List<Point> initialPoints; private List<Point> initialPoints;
private transient ExecutorService exec; private transient ExecutorService exec;
private boolean useKmeansPlusPlus;
protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy) {
this.clusteringStrategy = clusteringStrategy; this.clusteringStrategy = clusteringStrategy;
this.exec = MultiThreadUtils.newExecutorService(); this.exec = MultiThreadUtils.newExecutorService();
this.useKmeansPlusPlus = useKmeansPlusPlus;
} }
/** /**
@ -75,8 +77,8 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
* @param clusteringStrategy * @param clusteringStrategy
* @return * @return
*/ */
public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy) { public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
return new BaseClusteringAlgorithm(clusteringStrategy); return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus);
} }
/** /**
@ -86,7 +88,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
*/ */
public ClusterSet applyTo(List<Point> points) { public ClusterSet applyTo(List<Point> points) {
resetState(points); resetState(points);
initClusters(); initClusters(useKmeansPlusPlus);
iterations(); iterations();
return clusterSet; return clusterSet;
} }
@ -130,7 +132,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
* Initialize the * Initialize the
* cluster centers at random * cluster centers at random
*/ */
protected void initClusters() { protected void initClusters(boolean kMeansPlusPlus) {
log.info("Generating initial clusters"); log.info("Generating initial clusters");
List<Point> points = new ArrayList<>(initialPoints); List<Point> points = new ArrayList<>(initialPoints);
@ -152,7 +154,10 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
//Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster //Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) { while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec); dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
double r = random.nextFloat() * dxs.maxNumber().doubleValue(); double summed = Nd4j.sum(dxs).getDouble(0);
double r = kMeansPlusPlus ? random.nextDouble() * summed:
random.nextFloat() * dxs.maxNumber().doubleValue();
for (int i = 0; i < dxs.length(); i++) { for (int i = 0; i < dxs.length(); i++) {
double distance = dxs.getDouble(i); double distance = dxs.getDouble(i);
Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " + Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
@ -170,6 +175,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
new IterationInfo(currentIteration, initialClusterSetInfo)); new IterationInfo(currentIteration, initialClusterSetInfo));
} }
protected void applyClusteringStrategy() { protected void applyClusteringStrategy() {
if (!isStrategyApplicableNow()) if (!isStrategyApplicableNow())
return; return;

View File

@ -79,8 +79,8 @@ public class ClusterUtils {
int nClusters = clusterSet.getClusterCount(); int nClusters = clusterSet.getClusterCount();
for (int i = 0; i < nClusters; i++) { for (int i = 0; i < nClusters; i++) {
final Cluster cluster = clusterSet.getClusters().get(i); final Cluster cluster = clusterSet.getClusters().get(i);
tasks.add(new Runnable() { //tasks.add(new Runnable() {
public void run() { // public void run() {
try { try {
final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
refreshClusterCenter(cluster, clusterInfo); refreshClusterCenter(cluster, clusterInfo);
@ -88,10 +88,10 @@ public class ClusterUtils {
} catch (Throwable t) { } catch (Throwable t) {
log.warn("Error refreshing cluster centers", t); log.warn("Error refreshing cluster centers", t);
} }
// }
//});
} }
}); //MultiThreadUtils.parallelTasks(tasks, executorService);
}
MultiThreadUtils.parallelTasks(tasks, executorService);
} }
public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) { public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
@ -146,28 +146,29 @@ public class ClusterUtils {
List<Runnable> tasks = new ArrayList<>(); List<Runnable> tasks = new ArrayList<>();
for (int i = 0; i < pointsCount; i++) { for (int i = 0; i < pointsCount; i++) {
final int i2 = i; final int i2 = i;
tasks.add(new Runnable() { //tasks.add(new Runnable() {
public void run() { // public void run() {
try { try {
Point point = points.get(i2); Point point = points.get(i2);
double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point) double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
: Math.pow(newCluster.getDistanceToCenter(point), 2); : Math.pow(newCluster.getDistanceToCenter(point), 2);
dxs.putScalar(i2, clusterSet.isInverse() ? dist : dist); dxs.putScalar(i2, /*clusterSet.isInverse() ? dist :*/ dist);
} catch (Throwable t) { } catch (Throwable t) {
log.warn("Error computing squared distance from nearest cluster", t); log.warn("Error computing squared distance from nearest cluster", t);
} }
} // }
}); //});
} }
MultiThreadUtils.parallelTasks(tasks, executorService); //MultiThreadUtils.parallelTasks(tasks, executorService);
for (int i = 0; i < pointsCount; i++) { for (int i = 0; i < pointsCount; i++) {
double previousMinDistance = previousDxs.getDouble(i); double previousMinDistance = previousDxs.getDouble(i);
if (clusterSet.isInverse()) { if (clusterSet.isInverse()) {
if (dxs.getDouble(i) < previousMinDistance) if (dxs.getDouble(i) < previousMinDistance) {
dxs.putScalar(i, previousMinDistance); dxs.putScalar(i, previousMinDistance);
}
} else if (dxs.getDouble(i) > previousMinDistance) } else if (dxs.getDouble(i) > previousMinDistance)
dxs.putScalar(i, previousMinDistance); dxs.putScalar(i, previousMinDistance);
} }
@ -175,6 +176,23 @@ public class ClusterUtils {
return dxs; return dxs;
} }
public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet,
final List<Point> points, INDArray previousDxs) {
final int pointsCount = points.size();
final INDArray dxs = Nd4j.create(pointsCount);
final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
Double sum = new Double(0);
for (int i = 0; i < pointsCount; i++) {
Point point = points.get(i);
double dist = Math.pow(newCluster.getDistanceToCenter(point), 2);
sum += dist;
dxs.putScalar(i, sum);
}
return dxs;
}
/** /**
* *
* @param clusterSet * @param clusterSet
@ -194,27 +212,27 @@ public class ClusterUtils {
List<Runnable> tasks = new ArrayList<>(); List<Runnable> tasks = new ArrayList<>();
for (int i = 0; i < clusterCount; i++) { for (int i = 0; i < clusterCount; i++) {
final Cluster cluster = clusterSet.getClusters().get(i); final Cluster cluster = clusterSet.getClusters().get(i);
tasks.add(new Runnable() { //tasks.add(new Runnable() {
public void run() { // public void run() {
try { try {
info.getClustersInfos().put(cluster.getId(), info.getClustersInfos().put(cluster.getId(),
computeClusterInfos(cluster, clusterSet.getDistanceFunction())); computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
} catch (Throwable t) { } catch (Throwable t) {
log.warn("Error computing cluster set info", t); log.warn("Error computing cluster set info", t);
} }
} //}
}); //});
} }
MultiThreadUtils.parallelTasks(tasks, executorService); //MultiThreadUtils.parallelTasks(tasks, executorService);
tasks = new ArrayList<>(); //tasks = new ArrayList<>();
for (int i = 0; i < clusterCount; i++) { for (int i = 0; i < clusterCount; i++) {
final int clusterIdx = i; final int clusterIdx = i;
final Cluster fromCluster = clusterSet.getClusters().get(i); final Cluster fromCluster = clusterSet.getClusters().get(i);
tasks.add(new Runnable() { //tasks.add(new Runnable() {
public void run() { //public void run() {
try { try {
for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) { for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
Cluster toCluster = clusterSet.getClusters().get(k); Cluster toCluster = clusterSet.getClusters().get(k);
@ -230,12 +248,12 @@ public class ClusterUtils {
} catch (Throwable t) { } catch (Throwable t) {
log.warn("Error computing distances", t); log.warn("Error computing distances", t);
} }
} // }
}); //});
} }
MultiThreadUtils.parallelTasks(tasks, executorService); //MultiThreadUtils.parallelTasks(tasks, executorService);
return info; return info;
} }

View File

@ -37,8 +37,8 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
* *
* @param clusteringStrategy * @param clusteringStrategy
*/ */
protected KMeansClustering(ClusteringStrategy clusteringStrategy) { protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) {
super(clusteringStrategy); super(clusteringStrategy, useKMeansPlusPlus);
} }
/** /**
@ -50,11 +50,11 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
* @return * @return
*/ */
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction,
boolean inverse) { boolean inverse, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = ClusteringStrategy clusteringStrategy =
FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse); FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse);
clusteringStrategy.endWhenIterationCountEquals(maxIterationCount); clusteringStrategy.endWhenIterationCountEquals(maxIterationCount);
return new KMeansClustering(clusteringStrategy); return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
} }
/** /**
@ -66,10 +66,10 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
* @return * @return
*/ */
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
boolean inverse, boolean allowEmptyClusters) { boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse) ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse)
.endWhenDistributionVariationRateLessThan(minDistributionVariationRate); .endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
return new KMeansClustering(clusteringStrategy); return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
} }
@ -81,8 +81,8 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
* @param distanceFunction the distance function to use for grouping * @param distanceFunction the distance function to use for grouping
* @return * @return
*/ */
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction) { public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) {
return setup(clusterCount, maxIterationCount, distanceFunction, false); return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus);
} }
/** /**
@ -94,17 +94,17 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
* @return * @return
*/ */
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
boolean allowEmptyClusters) { boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false); ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate); clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
return new KMeansClustering(clusteringStrategy); return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
} }
public static KMeansClustering setup(int clusterCount, Distance distanceFunction, public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
boolean allowEmptyClusters) { boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false); ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE); clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
return new KMeansClustering(clusteringStrategy); return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
} }
} }

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.clustering.kmeans; package org.deeplearning4j.clustering.kmeans;
import lombok.val;
import org.apache.commons.lang3.time.StopWatch; import org.apache.commons.lang3.time.StopWatch;
import org.deeplearning4j.clustering.BaseDL4JTest; import org.deeplearning4j.clustering.BaseDL4JTest;
import org.deeplearning4j.clustering.algorithm.Distance; import org.deeplearning4j.clustering.algorithm.Distance;
@ -28,36 +29,40 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.*;
import static org.junit.Assert.fail;
/** /**
* Created by agibsonccc on 7/2/17. * Created by agibsonccc on 7/2/17.
*/ */
public class KMeansTest extends BaseDL4JTest { public class KMeansTest extends BaseDL4JTest {
private boolean[] useKMeansPlusPlus = {true, false};
@Test @Test
public void testKMeans() { public void testKMeans() {
Nd4j.getRandom().setSeed(7); Nd4j.getRandom().setSeed(7);
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN); for (boolean mode : useKMeansPlusPlus) {
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN, mode);
List<Point> points = Point.toPoints(Nd4j.randn(5, 5)); List<Point> points = Point.toPoints(Nd4j.randn(5, 5));
ClusterSet clusterSet = kMeansClustering.applyTo(points); ClusterSet clusterSet = kMeansClustering.applyTo(points);
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0)); PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
System.out.println(pointClassification); System.out.println(pointClassification);
} }
}
@Test @Test
public void testKmeansCosine() { public void testKmeansCosine() {
Nd4j.getRandom().setSeed(7); Nd4j.getRandom().setSeed(7);
int numClusters = 5; int numClusters = 5;
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, true); for (boolean mode : useKMeansPlusPlus) {
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode);
List<Point> points = Point.toPoints(Nd4j.rand(5, 300)); List<Point> points = Point.toPoints(Nd4j.rand(5, 300));
ClusterSet clusterSet = kMeansClustering.applyTo(points); ClusterSet clusterSet = kMeansClustering.applyTo(points);
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0)); PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN); KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode);
ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points); ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points);
PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0)); PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0));
System.out.println("Cosine " + pointClassification); System.out.println("Cosine " + pointClassification);
@ -66,6 +71,7 @@ public class KMeansTest extends BaseDL4JTest {
assertEquals(pointClassification.getCluster().getPoints().get(0), assertEquals(pointClassification.getCluster().getPoints().get(0),
pointClassificationEuclidean.getCluster().getPoints().get(0)); pointClassificationEuclidean.getCluster().getPoints().get(0));
} }
}
@Ignore @Ignore
@Test @Test
@ -73,9 +79,10 @@ public class KMeansTest extends BaseDL4JTest {
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
Nd4j.getRandom().setSeed(7); Nd4j.getRandom().setSeed(7);
int numClusters = 20; int numClusters = 20;
for (boolean mode : useKMeansPlusPlus) {
StopWatch watch = new StopWatch(); StopWatch watch = new StopWatch();
watch.start(); watch.start();
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, true); KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode);
List<Point> points = Point.toPoints(Nd4j.linspace(0, 5000 * 300, 5000 * 300).reshape(5000, 300)); List<Point> points = Point.toPoints(Nd4j.linspace(0, 5000 * 300, 5000 * 300).reshape(5000, 300));
ClusterSet clusterSet = kMeansClustering.applyTo(points); ClusterSet clusterSet = kMeansClustering.applyTo(points);
@ -90,6 +97,7 @@ public class KMeansTest extends BaseDL4JTest {
watch.stop(); watch.stop();
System.out.println("Elapsed for search: " + watch); System.out.println("Elapsed for search: " + watch);
} }
}
@Test @Test
@Ignore @Ignore
@ -97,9 +105,10 @@ public class KMeansTest extends BaseDL4JTest {
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
Nd4j.getRandom().setSeed(7); Nd4j.getRandom().setSeed(7);
int numClusters = 20; int numClusters = 20;
for (boolean mode : useKMeansPlusPlus) {
StopWatch watch = new StopWatch(); StopWatch watch = new StopWatch();
watch.start(); watch.start();
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false); KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false, mode);
List<Point> points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300)); List<Point> points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
@ -117,7 +126,7 @@ public class KMeansTest extends BaseDL4JTest {
watch.reset(); watch.reset();
watch.start(); watch.start();
kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false); kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false, mode);
points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300)); points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
@ -133,6 +142,7 @@ public class KMeansTest extends BaseDL4JTest {
watch.stop(); watch.stop();
System.out.println("Elapsed for search: " + watch); System.out.println("Elapsed for search: " + watch);
} }
}
@Test @Test
public void testCorrectness() { public void testCorrectness() {
@ -141,7 +151,8 @@ public class KMeansTest extends BaseDL4JTest {
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
Nd4j.getRandom().setSeed(7); Nd4j.getRandom().setSeed(7);
int numClusters = 3; int numClusters = 3;
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, true); for (boolean mode : useKMeansPlusPlus) {
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode);
double[] data = new double[]{ double[] data = new double[]{
15, 16, 15, 16,
16, 18.5, 16, 18.5,
@ -181,6 +192,7 @@ public class KMeansTest extends BaseDL4JTest {
for (int i = 0; i < clusters.size(); ++i) for (int i = 0; i < clusters.size(); ++i)
System.out.println("Choice: " + clusters.get(i).getCenter().getArray()); System.out.println("Choice: " + clusters.get(i).getCenter().getArray());
} }
}
/*assertEquals(Nd4j.createFromArray(new double[]{75.9348, 74.1990}), /*assertEquals(Nd4j.createFromArray(new double[]{75.9348, 74.1990}),
pointClassification.getCluster().getCenter().getArray());*/ pointClassification.getCluster().getCenter().getArray());*/
@ -233,4 +245,39 @@ public class KMeansTest extends BaseDL4JTest {
System.out.println(); System.out.println();
} }
} }
@Test
public void testInitClusters() {
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
Nd4j.getRandom().setSeed(7);
{
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 1, Distance.EUCLIDEAN, true);
double[][] dataArray = {{1000000.0, 2.8E7, 5.5E7, 8.2E7}, {2.8E7, 5.5E7, 8.2E7, 1.09E8}, {5.5E7, 8.2E7, 1.09E8, 1.36E8},
{8.2E7, 1.09E8, 1.36E8, 1.63E8}, {1.09E8, 1.36E8, 1.63E8, 1.9E8}, {1.36E8, 1.63E8, 1.9E8, 2.17E8},
{1.63E8, 1.9E8, 2.17E8, 2.44E8}, {1.9E8, 2.17E8, 2.44E8, 2.71E8}, {2.17E8, 2.44E8, 2.71E8, 2.98E8},
{2.44E8, 2.71E8, 2.98E8, 3.25E8}, {2.71E8, 2.98E8, 3.25E8, 3.52E8}, {2.98E8, 3.25E8, 3.52E8, 3.79E8},
{3.25E8, 3.52E8, 3.79E8, 4.06E8}, {3.52E8, 3.79E8, 4.06E8, 4.33E8}, {3.79E8, 4.06E8, 4.33E8, 4.6E8},
{4.06E8, 4.33E8, 4.6E8, 4.87E8}, {4.33E8, 4.6E8, 4.87E8, 5.14E8}, {4.6E8, 4.87E8, 5.14E8, 5.41E8},
{4.87E8, 5.14E8, 5.41E8, 5.68E8}, {5.14E8, 5.41E8, 5.68E8, 5.95E8}, {5.41E8, 5.68E8, 5.95E8, 6.22E8},
{5.68E8, 5.95E8, 6.22E8, 6.49E8}, {5.95E8, 6.22E8, 6.49E8, 6.76E8}, {6.22E8, 6.49E8, 6.76E8, 7.03E8},
{6.49E8, 6.76E8, 7.03E8, 7.3E8}, {6.76E8, 7.03E8, 7.3E8, 7.57E8}, {7.03E8, 7.3E8, 7.57E8, 7.84E8}};
INDArray data = Nd4j.createFromArray(dataArray);
List<Point> points = Point.toPoints(data);
ClusterSet clusterSet = kMeansClustering.applyTo(points);
double[] centroid1 = {2.44e8, 2.71e8, 2.98e8, 3.25e8};
double[] centroid2 = {5.14e8, 5.41e8, 5.68e8, 5.95e8};
double[] centroid3 = {1.63e8, 1.9e8, 2.17e8, 2.44e8};
double[] centroid4 = {6.76e8, 7.03e8, 7.3e8, 7.57e8};
double[] centroid5 = {4.06e8, 4.33e8, 4.6e8, 4.87e8};
assertArrayEquals(centroid1, clusterSet.getClusters().get(0).getCenter().getArray().toDoubleVector(), 1e-4);
assertArrayEquals(centroid2, clusterSet.getClusters().get(1).getCenter().getArray().toDoubleVector(), 1e-4);
assertArrayEquals(centroid3, clusterSet.getClusters().get(2).getCenter().getArray().toDoubleVector(), 1e-4);
assertArrayEquals(centroid4, clusterSet.getClusters().get(3).getCenter().getArray().toDoubleVector(), 1e-4);
assertArrayEquals(centroid5, clusterSet.getClusters().get(4).getCenter().getArray().toDoubleVector(), 1e-4);
}
}
} }

View File

@ -23,6 +23,8 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.RandomUtils; import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
import org.junit.Rule; import org.junit.Rule;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.io.ClassPathResource;
@ -857,4 +859,34 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
} }
} }
@Test
public void testBackwardsCompatibleWord2Vec() {
File model_v3 = Resources.asFile("deeplearning4j-nlp/model_beta3.zip");
File model_v4 = Resources.asFile("deeplearning4j-nlp/model_beta4.zip");
Word2Vec word2Vec1 = WordVectorSerializer.readWord2VecModel(model_v3, true);
Word2Vec word2Vec2 = WordVectorSerializer.readWord2VecModel(model_v4, true);
try {
assertEquals(word2Vec1.toJson(), word2Vec2.toJson());
} catch (Exception e) {
fail(e.getMessage());
}
}
@Test
public void testBackwardsCompatibleSequenceVectors() {
File model_v3 = Resources.asFile("deeplearning4j-nlp/seqv_beta3.csv");
File model_v4 = Resources.asFile("deeplearning4j-nlp/seqv_beta4.csv");
try {
SequenceVectors vectors1 = WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), model_v3);
SequenceVectors vectors2 = WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), model_v4);
assertEquals(vectors1.vocab().numWords(), vectors2.vocab().numWords());
for (int i = 0; i < vectors1.vocab().numWords(); ++i) {
assertEquals(vectors1.vocab().words().toArray()[i], vectors2.vocab().words().toArray()[i]);
}
} catch (Exception e) {
fail(e.getMessage());
}
}
} }

View File

@ -249,7 +249,7 @@ public class BertIterator implements MultiDataSetIterator {
} else { } else {
throw new RuntimeException(); throw new RuntimeException();
} }
l[0] = Nd4j.create(Nd4j.defaultFloatingPointType(), mbPadded, numClasses); l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
for( int i=0; i<mb; i++ ){ for( int i=0; i<mb; i++ ){
l[0].putScalar(i, classLabels[i], 1.0); l[0].putScalar(i, classLabels[i], 1.0);
} }
@ -277,9 +277,9 @@ public class BertIterator implements MultiDataSetIterator {
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){ if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength); labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){ } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
labelArr = Nd4j.create(Nd4j.defaultFloatingPointType(), mbPadded, vocabSize, outLength); labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){ } else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
labelArr = Nd4j.create(Nd4j.defaultFloatingPointType(), outLength, mbPadded, vocabSize); labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
} else { } else {
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat); throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
} }

View File

@ -201,7 +201,7 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
List<String> tokens = new ArrayList<>(); List<String> tokens = new ArrayList<>();
while (t.hasMoreTokens()) { while (t.hasMoreTokens()) {
String token = t.nextToken(); String token = t.nextToken();
if (!wordVectors.hasWord(token)) { if (!wordVectors.outOfVocabularySupported() && !wordVectors.hasWord(token)) {
switch (unknownWordHandling) { switch (unknownWordHandling) {
case RemoveWord: case RemoveWord:
continue; continue;

View File

@ -1312,10 +1312,12 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
int rest = batchSequences.size() % batchSize; int rest = batchSequences.size() % batchSize;
int chunks = ((batchSequences.size() >= batchSize) ? batchSequences.size() / batchSize : 0) + ((rest > 0)? 1 : 0); int chunks = ((batchSequences.size() >= batchSize) ? batchSequences.size() / batchSize : 0) + ((rest > 0)? 1 : 0);
for (int j = 0; j < chunks; ++j) { for (int j = 0; j < chunks; ++j) {
if (trainElementsVectors) {
if (elementsLearningAlgorithm instanceof SkipGram) if (elementsLearningAlgorithm instanceof SkipGram)
((SkipGram) elementsLearningAlgorithm).iterateSample(batchSequences.get(j)); ((SkipGram) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
else if (elementsLearningAlgorithm instanceof CBOW) else if (elementsLearningAlgorithm instanceof CBOW)
((CBOW) elementsLearningAlgorithm).iterateSample(batchSequences.get(j)); ((CBOW) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
}
if (trainSequenceVectors) { if (trainSequenceVectors) {
if (sequenceLearningAlgorithm instanceof DBOW) if (sequenceLearningAlgorithm instanceof DBOW)

View File

@ -32,7 +32,7 @@ import java.io.Serializable;
* *
* @author Adam Gibson * @author Adam Gibson
*/ */
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", defaultImpl = VocabWord.class)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE, @JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
setterVisibility = JsonAutoDetect.Visibility.NONE) setterVisibility = JsonAutoDetect.Visibility.NONE)
public class VocabWord extends SequenceElement implements Serializable { public class VocabWord extends SequenceElement implements Serializable {

View File

@ -224,6 +224,7 @@ public class TestBertIterator extends BaseDL4JTest {
@Test(timeout = 20000L) @Test(timeout = 20000L)
public void testMinibatchPadding() throws Exception { public void testMinibatchPadding() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
String toTokenize1 = "I saw a girl with a telescope."; String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.api; package org.deeplearning4j.nn.api;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.Regularization;
@ -73,4 +74,6 @@ public interface TrainingConfig {
*/ */
double getGradientNormalizationThreshold(); double getGradientNormalizationThreshold();
void setDataType(DataType dataType);
} }

View File

@ -93,4 +93,9 @@ public abstract class GraphVertex implements Cloneable, Serializable {
*/ */
public abstract MemoryReport getMemoryReport(InputType... inputTypes); public abstract MemoryReport getMemoryReport(InputType... inputTypes);
public void setDataType(DataType dataType) {
//No-op for most layers
}
} }

View File

@ -146,4 +146,9 @@ public class LayerVertex extends GraphVertex {
//TODO preprocessor memory //TODO preprocessor memory
return layerConf.getLayer().getMemoryReport(it); return layerConf.getLayer().getMemoryReport(it);
} }
@Override
public void setDataType(DataType dataType){
layerConf.getLayer().setDataType(dataType);
}
} }

View File

@ -223,6 +223,11 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable {
"Not supported: all layers with parameters should override this method"); "Not supported: all layers with parameters should override this method");
} }
@Override
public void setDataType(DataType dataType) {
//No-op for most layers
}
/** /**
* This is a report of the estimated memory consumption for the given layer * This is a report of the estimated memory consumption for the given layer
* *

View File

@ -96,7 +96,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
if (!map.containsKey(inputNum)) { if (!map.containsKey(inputNum)) {
//Lazily define extra input variable as required //Lazily define extra input variable as required
SDVariable var = sameDiff.var("var_" + inputNum, 1); //TODO is this shape safe? SDVariable var = sameDiff.var("var_" + inputNum, dataType, -1); //TODO is this shape safe?
map.put(inputNum, var); map.put(inputNum, var);
} }

View File

@ -62,6 +62,7 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
protected IUpdater biasUpdater; protected IUpdater biasUpdater;
protected GradientNormalization gradientNormalization; protected GradientNormalization gradientNormalization;
protected double gradientNormalizationThreshold = Double.NaN; protected double gradientNormalizationThreshold = Double.NaN;
protected DataType dataType;
/** /**
* Define the vertex * Define the vertex
@ -234,4 +235,9 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
public double getGradientNormalizationThreshold() { public double getGradientNormalizationThreshold() {
return gradientNormalizationThreshold; return gradientNormalizationThreshold;
} }
@Override
public void setDataType(DataType dataType) {
this.dataType = dataType;
}
} }

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.misc;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.GradientNormalization;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.Regularization;
@ -63,4 +64,9 @@ public class DummyConfig implements TrainingConfig {
public double getGradientNormalizationThreshold() { public double getGradientNormalizationThreshold() {
return 1.0; return 1.0;
} }
@Override
public void setDataType(DataType dataType) {
}
} }

View File

@ -512,6 +512,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
for(; i<topologicalOrder.length; i++ ){ for(; i<topologicalOrder.length; i++ ){
String name = indices.getIdxToName().get(i); String name = indices.getIdxToName().get(i);
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name); org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
n.setDataType(netDtype);
numParamsForVertex[i] = n.numParams(true); numParamsForVertex[i] = n.numParams(true);
numParams += numParamsForVertex[i]; numParams += numParamsForVertex[i];
} }

View File

@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.util.TimeSeriesUtils; import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
@ -35,6 +36,7 @@ import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.util.Arrays;
import java.util.List; import java.util.List;
/** /**
@ -60,10 +62,16 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
assertInputSet(true); assertInputSet(true);
if (input.rank() != 3) if (input.rank() != 3)
throw new UnsupportedOperationException( throw new UnsupportedOperationException(
"Input is not rank 3. Got input with rank " + input.rank() + " " + layerId()); "Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " +
input.rank() + " with shape " + Arrays.toString(input.shape()) + " for layer " + layerId());
if (labels == null) if (labels == null)
throw new IllegalStateException("Labels are not set (null)"); throw new IllegalStateException("Labels are not set (null)");
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM); INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.BP_WORKING_MEM); INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.BP_WORKING_MEM);
INDArray maskReshaped; INDArray maskReshaped;

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.util.TimeSeriesUtils; import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -57,8 +58,13 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." + "Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId()); " Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
} }
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);
INDArray inputTemp = input; INDArray inputTemp = input;
this.input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM); this.input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
Pair<Gradient, INDArray> gradAndEpsilonNext = super.backpropGradient(epsilon, workspaceMgr); //Also applies dropout Pair<Gradient, INDArray> gradAndEpsilonNext = super.backpropGradient(epsilon, workspaceMgr); //Also applies dropout
this.input = inputTemp; this.input = inputTemp;
INDArray epsilon2d = gradAndEpsilonNext.getSecond(); INDArray epsilon2d = gradAndEpsilonNext.getSecond();

View File

@ -39,9 +39,7 @@ import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import java.util.Arrays; import java.util.*;
import java.util.LinkedHashMap;
import java.util.Map;
/** /**
* Implementation of a SameDiff graph vertex. * Implementation of a SameDiff graph vertex.
@ -96,12 +94,11 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
@Override @Override
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) { public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
if(sameDiff == null){ if(sameDiff == null){
doInit(); doInit();
} }
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
// sameDiff.clearExecutionCache();
config.validateInput(inputs); config.validateInput(inputs);
for(int i=0; i<inputs.length; i++ ){ for(int i=0; i<inputs.length; i++ ){
String name = config.getVertexParams().getInputs().get(i); String name = config.getVertexParams().getInputs().get(i);
@ -121,6 +118,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
} }
Map<String,INDArray> out = sameDiff.exec(null, outputKey); Map<String,INDArray> out = sameDiff.exec(null, outputKey);
INDArray result = out.get(outputKey); INDArray result = out.get(outputKey);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result); return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
} }
} }
@ -131,27 +132,42 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
INDArray[] dLdIns; INDArray[] dLdIns;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
// sameDiff.clearExecutionCache(); if(sameDiff == null){
doInit();
}
if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
List<String> inputs = config.getVertexParams().getInputs();
String[] inArr = inputs.toArray(new String[inputs.size()]);
sameDiff.createGradFunction(inArr);
}
config.validateInput(inputs); config.validateInput(inputs);
//Set inputs Map<String,INDArray> phMap = new HashMap<>();
for(int i=0; i<inputs.length; i++ ){ List<String> inputs = config.getVertexParams().getInputs();
String name = config.getVertexParams().getInputs().get(i); int i=0;
for(String s : inputs){
phMap.put(s, this.inputs[i++]);
}
if(maskArrays != null){
for( int j=0; j<maskArrays.length; j++ ){
String name = inputs.get(j);
final String maskName = name + "_mask"; final String maskName = name + "_mask";
sameDiff.associateArrayWithVariable(inputs[i].dup(), sameDiff.getVariable(name)); if(maskArrays[j] != null) {
if(maskArrays != null && maskArrays[i] != null) { sameDiff.associateArrayWithVariable(maskArrays[j].dup(), maskName);
sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName);
}else{
sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName);
} }
} }
fn.updateVariable(outputVar.getVarName(), epsilon.dup()); }
String epsName = fn.getGradPlaceholderName();
phMap.put(epsName, epsilon);
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once! //TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s); sameDiff.associateArrayWithVariable(paramTable.get(s), s);
} }
sameDiff.execBackwards(null); sameDiff.execBackwards(phMap);
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr(); INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s); INDArray dl4jGrad = gradTable.get(s);
@ -159,10 +175,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
g.gradientForVariable().put(s, dl4jGrad); g.gradientForVariable().put(s, dl4jGrad);
} }
dLdIns = new INDArray[inputs.length]; dLdIns = new INDArray[inputs.size()];
for(int i=0; i<inputs.length; i++ ){ for(int j=0; j<inputs.size(); j++ ){
String name = config.getVertexParams().getInputs().get(i); String name = inputs.get(j);
dLdIns[i] = sameDiff.grad(name).getArr(); dLdIns[j] = sameDiff.grad(name).getArr();
} }
} }
@ -171,6 +187,9 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]); dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
} }
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return new Pair<>(g, dLdIns); return new Pair<>(g, dLdIns);
} }

View File

@ -35,6 +35,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.*; import java.util.*;
@ -78,25 +79,32 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
@Override @Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false); assertInputSet(false);
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
if(sameDiff == null){ if(sameDiff == null){
doInit(); doInit();
} }
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input); bl.validateInput(input);
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
if(maskArray != null){ if(maskArray != null){
sameDiff.associateArrayWithVariable(maskArray, sameDiff.getVariable(MASK_KEY)); phMap.put(MASK_KEY, maskArray);
}else{
sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY));
} }
for(String s : paramTable.keySet() ) { for(String s : paramTable.keySet() ) {
sameDiff.associateArrayWithVariable(paramTable.get(s), s); sameDiff.associateArrayWithVariable(paramTable.get(s), s);
} }
Map<String,INDArray> out = sameDiff.exec(null, outputKey); Map<String,INDArray> out = sameDiff.exec(phMap, outputKey);
INDArray result = out.get(outputKey); INDArray result = out.get(outputKey);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result); return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
} }
} }
@ -110,24 +118,36 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
INDArray dLdIn; INDArray dLdIn;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
// sameDiff.clearExecutionCache(); if(sameDiff == null){
doInit();
}
if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
sameDiff.createGradFunction(INPUT_KEY);
}
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input); bl.validateInput(input);
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
if(maskArray != null){
sameDiff.associateArrayWithVariable(maskArray, sameDiff.getVariable(MASK_KEY));
}else{
sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY));
}
fn.updateVariable(outputVar.getVarName(), epsilon.dup());
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once! //TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s); sameDiff.associateArrayWithVariable(paramTable.get(s), s);
} }
sameDiff.execBackwards(Collections.<String, INDArray>emptyMap()); Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
phMap.put(fn.getGradPlaceholderName(), epsilon);
if(maskArray != null){
phMap.put(MASK_KEY, maskArray);
}
List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName());
for(String s : paramTable.keySet()){
requiredGrads.add(sameDiff.grad(s).getVarName());
}
sameDiff.execBackwards(phMap, requiredGrads);
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr(); INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s); INDArray dl4jGrad = gradTable.get(s);
@ -138,6 +158,11 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
dLdIn = sameDiff.grad(INPUT_KEY).getArr(); dLdIn = sameDiff.grad(INPUT_KEY).getArr();
} }
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
System.out.println(dLdIn);
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
} }
@ -225,8 +250,9 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
sameDiff = SameDiff.create(); sameDiff = SameDiff.create();
Map<String, INDArray> p = paramTable(); Map<String, INDArray> p = paramTable();
val inputShape = input.shape().clone(); long[] inputShape = input.shape().clone();
SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape); inputShape[0] = -1;
SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape);
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes(); Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
Map<String, SDVariable> params = new LinkedHashMap<>(); Map<String, SDVariable> params = new LinkedHashMap<>();
for (String s : paramShapes.keySet()) { for (String s : paramShapes.keySet()) {
@ -235,7 +261,8 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
params.put(s, v); params.put(s, v);
} }
SDVariable mask = sameDiff.constant(MASK_KEY, SameDiffGraphVertex.createMask(dataType, inputShape)); long[] maskShape = ArrayUtil.nTimes((long)inputShape.length, -1);
SDVariable mask = sameDiff.placeHolder(MASK_KEY, dataType, maskShape);
SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, params, mask); SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, params, mask);
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null"); Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");

View File

@ -87,35 +87,43 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
private INDArray activateHelper(boolean activations, LayerWorkspaceMgr workspaceMgr){ private INDArray activateHelper(boolean activations, LayerWorkspaceMgr workspaceMgr){
assertInputSet(false); assertInputSet(false);
//Check where the output occors. If it's a simple loss layer (no params) this could //Check where the output occurs. If it's a simple loss layer (no params) this could
// just be the input! // just be the input!
if(activations && INPUT_KEY.equals(layerConf().activationsVertexName())){ if(activations && INPUT_KEY.equals(layerConf().activationsVertexName())){
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input); return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
} }
//TODO optimize
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
if(sameDiff == null){ if(sameDiff == null){
doInit(); doInit();
} }
//TODO optimize
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
if(layerConf().labelsRequired() && labels != null) {
sameDiff.associateArrayWithVariable(labels.dup(), sameDiff.getVariable(LABELS_KEY));
}
for(String s : paramTable.keySet() ) { for(String s : paramTable.keySet() ) {
sameDiff.associateArrayWithVariable(paramTable.get(s), s); sameDiff.associateArrayWithVariable(paramTable.get(s), s);
} }
INDArray score = sameDiff.execAndEndResult(); Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
if(!activations && layerConf().labelsRequired() && labels != null) {
phMap.put(LABELS_KEY, labels);
}
String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
INDArray out = sameDiff.execSingle(phMap, s);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
if(activations) { if(activations) {
INDArray result = sameDiff.getArrForVarName(layerConf().activationsVertexName()); Preconditions.checkNotNull(out, "Activations (result) array for variable \"%s\" was " +
Preconditions.checkNotNull(result, "Activations (result) array for variable \"%s\" was " +
"null - error during execution or this variable (as defined by method activationsVertexName()) " + "null - error during execution or this variable (as defined by method activationsVertexName()) " +
"does not exist", layerConf().activationsVertexName()); "does not exist", layerConf().activationsVertexName());
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result); return workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
} else { } else {
return score; return out;
} }
} }
} }
@ -127,23 +135,26 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
Preconditions.checkState(!layerConf().labelsRequired() || labels != null, "Cannot execute backprop: Labels are not set. " + Preconditions.checkState(!layerConf().labelsRequired() || labels != null, "Cannot execute backprop: Labels are not set. " +
"If labels are not required for this SameDiff output layer, override SameDiffOutputLayer.labelsRequired()" + "If labels are not required for this SameDiff output layer, override SameDiffOutputLayer.labelsRequired()" +
" to return false instead"); " to return false instead");
Gradient g = new DefaultGradient();
INDArray dLdIn;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
if(sameDiff == null){ if(sameDiff == null){
//Usually doInit will be called in forward pass; not necessarily the case in output layers //Usually doInit will be called in forward pass; not necessarily the case in output layers
// (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph) // (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
doInit(); doInit();
} }
if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
sameDiff.createGradFunction(INPUT_KEY);
}
Gradient g = new DefaultGradient(); INDArray castInput = input.castTo(dataType);
INDArray dLdIn;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
INDArray castInput = input.castTo(Nd4j.defaultFloatingPointType());
if(castInput.isAttached()) if(castInput.isAttached())
castInput = castInput.dup(); castInput = castInput.dup();
sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY)); sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY));
if(layerConf().labelsRequired()) { if(layerConf().labelsRequired()) {
INDArray castLabels = labels.castTo(Nd4j.defaultFloatingPointType()); INDArray castLabels = labels.castTo(dataType);
if(castLabels.isAttached()) if(castLabels.isAttached())
castLabels = castLabels.dup(); castLabels = castLabels.dup();
sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY)); sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY));
@ -154,7 +165,17 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.associateArrayWithVariable(paramTable.get(s), s); sameDiff.associateArrayWithVariable(paramTable.get(s), s);
} }
sameDiff.execBackwards(Collections.<String, INDArray>emptyMap()); List<String> gradVarNames = new ArrayList<>();
for(String s : paramTable.keySet()){
gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName());
}
gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName());
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
phMap.put(LABELS_KEY, labels);
sameDiff.execBackwards(phMap, gradVarNames);
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr(); INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s); INDArray dl4jGrad = gradTable.get(s);
@ -165,6 +186,10 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
dLdIn = sameDiff.grad(INPUT_KEY).getArr(); dLdIn = sameDiff.grad(INPUT_KEY).getArr();
} }
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
} }
@ -252,18 +277,20 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff = SameDiff.create(); sameDiff = SameDiff.create();
Map<String, INDArray> p = paramTable(); Map<String, INDArray> p = paramTable();
val inputShape = input.shape().clone(); long[] inputShape = input.shape().clone();
SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape); inputShape[0] = -1;
SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape);
SDVariable labelVar = null; SDVariable labelVar = null;
if(layerConf().labelsRequired()){ if(layerConf().labelsRequired()){
long[] labelShape = labels == null ? new long[]{1} : labels.shape().clone(); long[] labelShape = labels == null ? new long[]{-1, -1} : labels.shape().clone();
labelVar = sameDiff.var(LABELS_KEY, dataType, labelShape); labelShape[0] = -1;
labelVar = sameDiff.placeHolder(LABELS_KEY, dataType, labelShape);
} }
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes(); Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
Map<String, SDVariable> params = new LinkedHashMap<>(); Map<String, SDVariable> params = new LinkedHashMap<>();
for (String s : paramShapes.keySet()) { for (String s : paramShapes.keySet()) {
val ps = paramShapes.get(s); val ps = paramShapes.get(s);
SDVariable v = sameDiff.var(s, ps); SDVariable v = sameDiff.var(s, dataType, ps);
params.put(s, v); params.put(s, v);
} }
SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, labelVar, params); SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, labelVar, params);

View File

@ -660,6 +660,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
val nParamsPerLayer = new long[nLayers]; val nParamsPerLayer = new long[nLayers];
for (int i = 0; i < nLayers; i++) { for (int i = 0; i < nLayers; i++) {
NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i);
conf.getLayer().setDataType(netDtype);
nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf); nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf);
paramLength += nParamsPerLayer[i]; paramLength += nParamsPerLayer[i];
} }

View File

@ -152,7 +152,7 @@ public class HardwareMetric implements Serializable {
return builder.logicalProcessorCount(processor.getLogicalProcessorCount()) return builder.logicalProcessorCount(processor.getLogicalProcessorCount())
.physicalProcessorCount(processor.getPhysicalProcessorCount()) .physicalProcessorCount(processor.getPhysicalProcessorCount())
.name(name) .name(name)
.averagedCpuLoad((long) processor.getSystemCpuLoad() * 100) .averagedCpuLoad((long)(processor.getSystemCpuLoad() * 100))
.ioWaitTime(iowait).gpuMetrics(gpuMetric) .ioWaitTime(iowait).gpuMetrics(gpuMetric)
.hostName(networkParams.getHostName()).diskInfo(diskInfoMap) .hostName(networkParams.getHostName()).diskInfo(diskInfoMap)
.currentMemoryUse(globalMemory.getTotal() - globalMemory.getAvailable()) .currentMemoryUse(globalMemory.getTotal() - globalMemory.getAvailable())

View File

@ -48,8 +48,6 @@ if(WIN32)
SET(CMAKE_NINJA_FORCE_RESPONSE_FILE 1 CACHE INTERNAL "") SET(CMAKE_NINJA_FORCE_RESPONSE_FILE 1 CACHE INTERNAL "")
endif() endif()
if ("${LIBND4J_ALL_OPS}") if ("${LIBND4J_ALL_OPS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true")
else() else()
@ -234,21 +232,21 @@ if(CUDA_BLAS)
endif() endif()
endif() endif()
if (NOT BUILD_TESTS)
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h) file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/*.cpp ../include/execution/*.h) file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h) file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h) file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h) file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/impl/*.cpp ../include/memory/cuda/*.cu ../include/memory/*.h)
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h) file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.cu ../include/graph/*.h)
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp) file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp) file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cuda/*.cu ../include/ops/declarable/helpers/impl/*.cpp)
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h) file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h) file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/cuda/*.cu ../include/helpers/*.h)
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h) file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu) file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
if (NOT BUILD_TESTS)
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA} CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
@ -258,20 +256,6 @@ if(CUDA_BLAS)
else() else()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true") set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/impl/*.cpp ../include/memory/cuda/*.cu ../include/memory/*.h)
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.cu ../include/graph/*.h)
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cuda/*.cu)
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA} CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h ../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
@ -308,7 +292,7 @@ elseif(CPU_BLAS)
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h) file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h) file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp) file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp) file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp ../include/ops/declarable/helpers/impl/*.cpp)
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h) file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h) file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h) file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h)

View File

@ -372,8 +372,8 @@ namespace nd4j {
/** /**
* if _bufferD==nullptr return _buffer, else return _bufferD * if _bufferD==nullptr return _buffer, else return _bufferD
*/ */
FORCEINLINE void* specialBuffer(); void* specialBuffer();
FORCEINLINE void* getSpecialBuffer() const; void* getSpecialBuffer() const;
/** /**
* returns device buffer if compilation is for cuda case, otherwise returns host buffer * returns device buffer if compilation is for cuda case, otherwise returns host buffer
@ -429,16 +429,16 @@ namespace nd4j {
/** /**
* permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array * permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array
*/ */
NDArray* permute(const std::initializer_list<int>& dimensions) const; NDArray permute(const std::initializer_list<int>& dimensions) const;
NDArray* permute(const std::vector<int>& dimensions) const; NDArray permute(const std::vector<int>& dimensions) const;
NDArray* permute(const int* dimensions, const int rank) const; NDArray permute(const int* dimensions, const int rank) const;
void permute(const int* dimensions, const int rank, NDArray& target) const; void permute(const int* dimensions, const int rank, NDArray& target) const;
void permute(const std::vector<int>& dimensions, NDArray& target) const; void permute(const std::vector<int>& dimensions, NDArray& target) const;
NDArray* permute(const std::initializer_list<Nd4jLong>& dimensions) const; NDArray permute(const std::initializer_list<Nd4jLong>& dimensions) const;
NDArray* permute(const std::vector<Nd4jLong>& dimensions) const; NDArray permute(const std::vector<Nd4jLong>& dimensions) const;
NDArray* permute(const Nd4jLong* dimensions, const int rank) const; NDArray permute(const Nd4jLong* dimensions, const int rank) const;
void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const; void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const;
void permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const; void permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const;
@ -508,7 +508,7 @@ namespace nd4j {
/** /**
* returns new copy of this array, optionally in different order * returns new copy of this array, optionally in different order
*/ */
NDArray *dup(const char newOrder = 'a'); NDArray *dup(const char newOrder = 'a') const;
/** /**
* returns sum of all elements of array * returns sum of all elements of array
@ -687,7 +687,7 @@ namespace nd4j {
void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;
#if defined(__CUDABLAS__) && defined(BUILD_TESTS) #if defined(__CUDABLAS__) //&& defined(BUILD_TESTS)
template <typename Lambda> template <typename Lambda>
FORCEINLINE void applyLambda(Lambda func, NDArray* target = nullptr); FORCEINLINE void applyLambda(Lambda func, NDArray* target = nullptr);
@ -790,8 +790,7 @@ namespace nd4j {
/** /**
* apply transpose operation to the copy of this array, that is this array remains unaffected * apply transpose operation to the copy of this array, that is this array remains unaffected
*/ */
NDArray* transpose() const; NDArray transpose() const;
NDArray transp() const;
/** /**
* perform transpose operation and store result in target, this array remains unaffected * perform transpose operation and store result in target, this array remains unaffected
@ -915,7 +914,7 @@ namespace nd4j {
* *
* if permute have been applied before or there are weird strides, then new buffer is allocated for new array * if permute have been applied before or there are weird strides, then new buffer is allocated for new array
*/ */
NDArray* reshape(const char order, const std::vector<Nd4jLong>& shape) const; NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const;
/** /**
* calculate strides and set given order * calculate strides and set given order
@ -2093,15 +2092,6 @@ Nd4jLong* NDArray::shapeInfo() {
return _shapeInfo; return _shapeInfo;
} }
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::specialShapeInfo() { Nd4jLong* NDArray::specialShapeInfo() {
if (_shapeInfoD == nullptr) if (_shapeInfoD == nullptr)
@ -2110,14 +2100,6 @@ Nd4jLong* NDArray::specialShapeInfo() {
return _shapeInfoD; return _shapeInfoD;
} }
////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
Nd4jLong NDArray::getBufferOffset() const { Nd4jLong NDArray::getBufferOffset() const {
return _offset; return _offset;
@ -2137,7 +2119,7 @@ Nd4jLong* NDArray::getSpecialShapeInfo() const{
} }
#if defined(__CUDACC__) && defined(BUILD_TESTS) #if defined(__CUDACC__) //&& defined(BUILD_TESTS)
// for CUDA we need stil stuff inline // for CUDA we need stil stuff inline
#include "cuda/NDArrayLambda.hpp" #include "cuda/NDArrayLambda.hpp"
#endif #endif

View File

@ -39,9 +39,9 @@ NDArray* NDArray::asT() const{
auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext()); auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
auto l = this->lengthOf(); auto l = this->lengthOf();
prepareSpecialUse({result}, {this}); NDArray::prepareSpecialUse({result}, {this});
NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result->getBuffer(), result->getShapeInfo(), result->getSpecialBuffer(), result->getSpecialShapeInfo(), nullptr, nullptr, nullptr); NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result->getBuffer(), result->getShapeInfo(), result->getSpecialBuffer(), result->getSpecialShapeInfo(), nullptr, nullptr, nullptr);
registerSpecialUse({result}, {this}); NDArray::registerSpecialUse({result}, {this});
return result; return result;
} }
@ -583,117 +583,130 @@ void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCop
void NDArray::assign(const double value) { void NDArray::assign(const double value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const float value) { void NDArray::assign(const float value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const float16 value) { void NDArray::assign(const float16 value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(value, this->getContext()); auto temp = NDArrayFactory::create(value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const bfloat16& value) { void NDArray::assign(const bfloat16& value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(value, this->getContext()); auto temp = NDArrayFactory::create(value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const Nd4jLong value) { void NDArray::assign(const Nd4jLong value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(value, this->getContext()); auto temp = NDArrayFactory::create(value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const int value) { void NDArray::assign(const int value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const int16_t value) { void NDArray::assign(const int16_t value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const uint8_t value) { void NDArray::assign(const uint8_t value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const uint16_t value) { void NDArray::assign(const uint16_t value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const uint32_t value) { void NDArray::assign(const uint32_t value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const uint64_t value) { void NDArray::assign(const uint64_t value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const int8_t value) { void NDArray::assign(const int8_t value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::assign(const bool value) { void NDArray::assign(const bool value) {
// just fire scalar // just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext()); auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp}); NDArray::registerSpecialUse({this}, {&temp});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -716,9 +729,9 @@ NDArray NDArray::varianceNumber(nd4j::variance::Ops op, bool biasCorrected) {
NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext());
prepareSpecialUse({&res}, {this}); NDArray::prepareSpecialUse({&res}, {this});
NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected); NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected);
registerSpecialUse({&res}, {this}); NDArray::registerSpecialUse({&res}, {this});
return res; return res;
} }
@ -918,9 +931,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::FloatOps op, void *extraParams) cons
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType())); auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()));
NDArray result(shape, true, this->getContext()); NDArray result(shape, true, this->getContext());
prepareSpecialUse({&result}, {this}); NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
registerSpecialUse({&result}, {this}); NDArray::registerSpecialUse({&result}, {this});
return result; return result;
} }
@ -932,9 +945,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::SameOps op, void *extraParams) const
NDArray result(dataType(), getContext()); NDArray result(dataType(), getContext());
prepareSpecialUse({&result}, {this}); NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
registerSpecialUse({&result}, {this}); NDArray::registerSpecialUse({&result}, {this});
return result; return result;
} }
@ -947,9 +960,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::BoolOps op, void *extraParams) const
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL); auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL);
NDArray result(shape, true, this->getContext()); NDArray result(shape, true, this->getContext());
prepareSpecialUse({&result}, {this}); NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
registerSpecialUse({&result}, {this}); NDArray::registerSpecialUse({&result}, {this});
return result; return result;
} }
@ -962,9 +975,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::LongOps op, void *extraParams) const
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64); auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64);
NDArray result(shape, true, this->getContext()); NDArray result(shape, true, this->getContext());
prepareSpecialUse({&result}, {this}); NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
registerSpecialUse({&result}, {this}); NDArray::registerSpecialUse({&result}, {this});
return result; return result;
} }
@ -976,9 +989,9 @@ void NDArray::reduceNumber(nd4j::reduce::FloatOps op, NDArray& target, void *ext
if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType())) if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType()))
throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!"); throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!");
prepareSpecialUse({&target}, {this}); NDArray::prepareSpecialUse({&target}, {this});
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
registerSpecialUse({&target}, {this}); NDArray::registerSpecialUse({&target}, {this});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -989,9 +1002,9 @@ void NDArray::reduceNumber(nd4j::reduce::SameOps op, NDArray& target, void *extr
if(!target.isScalar() || target.dataType() != dataType()) if(!target.isScalar() || target.dataType() != dataType())
throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!"); throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!");
prepareSpecialUse({&target}, {this}); NDArray::prepareSpecialUse({&target}, {this});
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo()); NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
registerSpecialUse({&target}, {this}); NDArray::registerSpecialUse({&target}, {this});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -1002,9 +1015,9 @@ void NDArray::reduceNumber(nd4j::reduce::BoolOps op, NDArray& target, void *extr
if(!target.isScalar() || target.dataType() != DataType::BOOL) if(!target.isScalar() || target.dataType() != DataType::BOOL)
throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!"); throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!");
prepareSpecialUse({&target}, {this}); NDArray::prepareSpecialUse({&target}, {this});
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo()); NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
registerSpecialUse({&target}, {this}); NDArray::registerSpecialUse({&target}, {this});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -1015,9 +1028,9 @@ void NDArray::reduceNumber(nd4j::reduce::LongOps op, NDArray& target, void *extr
if(!target.isScalar() || target.dataType() != DataType::INT64) if(!target.isScalar() || target.dataType() != DataType::INT64)
throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!"); throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!");
prepareSpecialUse({&target}, {this}); NDArray::prepareSpecialUse({&target}, {this});
NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo()); NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
registerSpecialUse({&target}, {this}); NDArray::registerSpecialUse({&target}, {this});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -1027,9 +1040,9 @@ NDArray NDArray::indexReduceNumber(nd4j::indexreduce::Ops op, ExtraArguments *ex
auto res = NDArrayFactory::create<Nd4jLong>(0); auto res = NDArrayFactory::create<Nd4jLong>(0);
NDArray::prepareSpecialUse({&res}, {this}); NDArray::NDArray::prepareSpecialUse({&res}, {this});
NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams == nullptr ? nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams == nullptr ? nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo());
NDArray::registerSpecialUse({&res}, {this}); NDArray::NDArray::registerSpecialUse({&res}, {this});
return res; return res;
} }
@ -1240,17 +1253,10 @@ BUILD_SINGLE_TEMPLATE(template void* NDArray::templatedPointerShift, (const Nd4j
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// method makes copy of this array and applies to the copy transpose operation, this array remains unaffected // method makes copy of this array and applies to the copy transpose operation, this array remains unaffected
NDArray* NDArray::transpose() const { NDArray NDArray::transpose() const {
auto newArr = new NDArray(getBuffer(), getSpecialBuffer(), getShapeInfo(), getContext(), false, false); NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
newArr->transposei();
return newArr;
}
////////////////////////////////////////////////////////////////////////
NDArray NDArray::transp() const {
NDArray newArr(getBuffer(), getShapeInfo(), getContext(), false);
newArr.transposei(); newArr.transposei();
return newArr; return newArr;
} }
@ -1360,10 +1366,10 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// create new array with corresponding order and shape, new array will point to the same _buffer as this array // create new array with corresponding order and shape, new array will point to the same _buffer as this array
NDArray* NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const { NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const {
auto newArr = new NDArray(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext()); NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
newArr->reshapei(order, shape); newArr.reshapei(order, shape);
return newArr; return newArr;
} }
@ -1420,43 +1426,43 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray* NDArray::permute(const int* dimensions, const int rank) const { NDArray NDArray::permute(const int* dimensions, const int rank) const {
// evaluate shapeInfo for output (permuted) array ret // evaluate shapeInfo for output (permuted) array ret
auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace());
auto ret = new NDArray(_buffer, ShapeDescriptor(shapeInfoPermuted), getContext(), getBufferOffset()); NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), getBufferOffset());
ret->_isView = true; ret._isView = true;
return ret; return ret;
} }
///////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////
NDArray* NDArray::permute(const Nd4jLong* dimensions, const int rank) const { NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const {
int tempDims[MAX_RANK]; int tempDims[MAX_RANK];
shape::convertT<Nd4jLong, int>(const_cast<Nd4jLong *>(dimensions), tempDims, rank); shape::convertT<Nd4jLong, int>(const_cast<Nd4jLong *>(dimensions), tempDims, rank);
return permute(tempDims, rank); return permute(tempDims, rank);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray* NDArray::permute(const std::vector<int>& dimensions) const { NDArray NDArray::permute(const std::vector<int>& dimensions) const {
auto data = dimensions.data(); auto data = dimensions.data();
auto size = dimensions.size(); auto size = dimensions.size();
return permute(data, size); return permute(data, size);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray* NDArray::permute(const std::vector<Nd4jLong>& dimensions) const { NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const {
return permute(dimensions.data(), dimensions.size()); return permute(dimensions.data(), dimensions.size());
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray* NDArray::permute(const std::initializer_list<int>& dimensions) const { NDArray NDArray::permute(const std::initializer_list<int>& dimensions) const {
std::vector<int> vec(dimensions); std::vector<int> vec(dimensions);
return permute(vec); return permute(vec);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray* NDArray::permute(const std::initializer_list<Nd4jLong>& dimensions) const { NDArray NDArray::permute(const std::initializer_list<Nd4jLong>& dimensions) const {
std::vector<Nd4jLong> vec(dimensions); std::vector<Nd4jLong> vec(dimensions);
return permute(vec); return permute(vec);
} }
@ -1528,10 +1534,9 @@ bool NDArray::isUnitary() {
throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !"); throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !");
auto tr = this->transpose(); auto tr = this->transpose();
auto trMul = MmulHelper::mmul(this, tr, nullptr, 1.f, 0.f); auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f);
bool result = trMul->isIdentityMatrix(); bool result = trMul->isIdentityMatrix();
delete tr;
delete trMul; delete trMul;
return result; return result;
@ -1777,11 +1782,11 @@ NDArray NDArray::operator*(const T& scalar) const {
auto tmp = NDArrayFactory::create(dataType(), scalar, getContext()); auto tmp = NDArrayFactory::create(dataType(), scalar, getContext());
NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT<T>()), false, getContext()); NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT<T>()), false, getContext());
NDArray::prepareSpecialUse({&result}, {this, &tmp}); NDArray::prepareSpecialUse({&result}, {this, &tmp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&result}, {this, &tmp}); NDArray::registerSpecialUse({&result}, {this, &tmp});
return result; return result;
} }
template NDArray NDArray::operator*(const double& scalar) const; template NDArray NDArray::operator*(const double& scalar) const;
@ -1811,6 +1816,7 @@ NDArray NDArray::operator/(const T& scalar) const {
NDArray::prepareSpecialUse({&result}, {this, &tmp}); NDArray::prepareSpecialUse({&result}, {this, &tmp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&result}, {this, &tmp}); NDArray::registerSpecialUse({&result}, {this, &tmp});
return result; return result;
} }
template NDArray NDArray::operator/(const double& scalar) const; template NDArray NDArray::operator/(const double& scalar) const;
@ -2050,14 +2056,14 @@ void NDArray::operator+=(const NDArray& other) {
throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType()); throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());
if (!this->isScalar() && other.isScalar()) { if (!this->isScalar() && other.isScalar()) {
prepareSpecialUse({this}, {this, &other}); NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {this, &other}); NDArray::registerSpecialUse({this}, {this, &other});
} }
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
prepareSpecialUse({this}, {this, &other}); NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {this, &other}); NDArray::registerSpecialUse({this}, {this, &other});
} }
else{ else{
Nd4jLong *bShape = nullptr; Nd4jLong *bShape = nullptr;
@ -2084,14 +2090,14 @@ void NDArray::operator-=(const NDArray& other) {
throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType()); throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());
if (!this->isScalar() && other.isScalar()) { if (!this->isScalar() && other.isScalar()) {
prepareSpecialUse({this}, {this, &other}); NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {this, &other}); NDArray::registerSpecialUse({this}, {this, &other});
} }
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
prepareSpecialUse({this}, {this, &other}); NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {this, &other}); NDArray::registerSpecialUse({this}, {this, &other});
} }
else{ else{
Nd4jLong *bShape = nullptr; Nd4jLong *bShape = nullptr;
@ -2117,14 +2123,14 @@ void NDArray::operator*=(const NDArray& other) {
throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType()); throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());
if (!this->isScalar() && other.isScalar()) { if (!this->isScalar() && other.isScalar()) {
prepareSpecialUse({this}, {this, &other}); NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {this, &other}); NDArray::registerSpecialUse({this}, {this, &other});
} }
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
prepareSpecialUse({this}, {this, &other}); NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {this, &other}); NDArray::registerSpecialUse({this}, {this, &other});
} }
else{ else{
Nd4jLong *bShape = nullptr; Nd4jLong *bShape = nullptr;
@ -2154,14 +2160,14 @@ void NDArray::operator/=(const NDArray& other) {
} }
if (!this->isScalar() && other.isScalar()) { if (!this->isScalar() && other.isScalar()) {
prepareSpecialUse({this}, {this, &other}); NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {this, &other}); NDArray::registerSpecialUse({this}, {this, &other});
} }
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
prepareSpecialUse({this}, {this, &other}); NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {this, &other}); NDArray::registerSpecialUse({this}, {this, &other});
} }
else{ else{
Nd4jLong *bShape = nullptr; Nd4jLong *bShape = nullptr;
@ -2264,9 +2270,9 @@ NDArray NDArray::operator-(const NDArray& other) const {
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
prepareSpecialUse({&result}, {this, &other}); NDArray::prepareSpecialUse({&result}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
registerSpecialUse({&result}, {this, &other}); NDArray::registerSpecialUse({&result}, {this, &other});
return result; return result;
} }
@ -2285,9 +2291,9 @@ NDArray NDArray::operator*(const NDArray& other) const {
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, this->getContext()); NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, this->getContext());
prepareSpecialUse({&result}, {this, &other}); NDArray::prepareSpecialUse({&result}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
registerSpecialUse({&result}, {this, &other}); NDArray::registerSpecialUse({&result}, {this, &other});
return result; return result;
} }
@ -2308,9 +2314,9 @@ NDArray NDArray::operator/(const NDArray& other) const {
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
prepareSpecialUse({&result}, {this, &other}); NDArray::prepareSpecialUse({&result}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
registerSpecialUse({&result}, {this, &other}); NDArray::registerSpecialUse({&result}, {this, &other});
return result; return result;
} }
@ -2326,9 +2332,9 @@ NDArray NDArray::operator-() const {
NDArray result(getShapeInfo(), false, getContext()); NDArray result(getShapeInfo(), false, getContext());
prepareSpecialUse({&result}, {this}); NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformSame(getContext(), nd4j::transform::Neg, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr); NativeOpExecutioner::execTransformSame(getContext(), nd4j::transform::Neg, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
registerSpecialUse({&result}, {this}); NDArray::registerSpecialUse({&result}, {this});
return result; return result;
} }
@ -2631,7 +2637,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector<int>& di
if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
NDArray::prepareSpecialUse({result}, {this, other}); NDArray::prepareSpecialUse({result}, {this, other});
NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
registerSpecialUse({result}, {this, other}); NDArray::registerSpecialUse({result}, {this, other});
return; return;
} }
@ -2688,7 +2694,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
NDArray::prepareSpecialUse({result}, {this, other}); NDArray::prepareSpecialUse({result}, {this, other});
NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
registerSpecialUse({result}, {this, other}); NDArray::registerSpecialUse({result}, {this, other});
return; return;
} }
@ -2896,7 +2902,7 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
Nd4jLong *shapeInfoNew; Nd4jLong *shapeInfoNew;
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
bool canReshape = shape::reshapeC(this->rankOf(), this->_shapeInfo, shape.size(), shape.data(), shapeInfoNew); bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
// we can do this only if there was no permute applied, or there are no weird strides // we can do this only if there was no permute applied, or there are no weird strides
if (canReshape) { if (canReshape) {
@ -2948,11 +2954,9 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray* othe
if (target->dataType() != this->dataType() && target->dataType() != other->dataType()) if (target->dataType() != this->dataType() && target->dataType() != other->dataType())
throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !"); throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !");
prepareSpecialUse({target}, {this, other}); NDArray::prepareSpecialUse({target}, {this, other});
NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr); NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
NDArray::registerSpecialUse({target}, {this, other});
registerSpecialUse({target}, {this, other});
if (extraParams != nullptr) if (extraParams != nullptr)
synchronize("NDArray::applyPairwiseTransform"); synchronize("NDArray::applyPairwiseTransform");
@ -2969,9 +2973,9 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *
if (dataType() != other->dataType()) if (dataType() != other->dataType())
throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !"); throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !");
prepareSpecialUse({target}, {this, other}); NDArray::prepareSpecialUse({target}, {this, other});
NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr); NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
registerSpecialUse({target}, {this, other}); NDArray::registerSpecialUse({target}, {this, other});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -3070,22 +3074,23 @@ void NDArray::assign(const NDArray& other) {
if (other.isScalar()) { if (other.isScalar()) {
if(this->isScalar()) { if(this->isScalar()) {
preparePrimaryUse({this}, {&other}); NDArray::preparePrimaryUse({this}, {&other});
BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES);
registerPrimaryUse({this}, {&other}); NDArray::registerPrimaryUse({this}, {&other});
this->syncToDevice();
} }
else { else {
if (dataType() != other.dataType()) { if (dataType() != other.dataType()) {
auto tmp = other.cast(dataType()); auto tmp = other.cast(dataType());
prepareSpecialUse({this}, {tmp}); NDArray::prepareSpecialUse({this}, {tmp});
NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {}); NDArray::registerSpecialUse({this}, {});
delete tmp; delete tmp;
} }
else { else {
prepareSpecialUse({this}, {&other}); NDArray::prepareSpecialUse({this}, {&other});
NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&other}); NDArray::registerSpecialUse({this}, {&other});
} }
} }
} }
@ -3101,16 +3106,16 @@ void NDArray::assign(const NDArray& other) {
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1) if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT()); copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
else { else {
prepareSpecialUse({this}, {&other}); NDArray::prepareSpecialUse({this}, {&other});
NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr); NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr);
registerSpecialUse({this}, {&other}); NDArray::registerSpecialUse({this}, {&other});
} }
} }
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
// This method returns new copy of this NDArray, optionally in different order // This method returns new copy of this NDArray, optionally in different order
NDArray* NDArray::dup(const char newOrder) { NDArray* NDArray::dup(const char newOrder) const {
if (isEmpty()) if (isEmpty())
return NDArrayFactory::empty_(dataType(), getContext()); return NDArrayFactory::empty_(dataType(), getContext());
@ -3170,7 +3175,7 @@ std::string NDArray::e(const Nd4jLong i) const {
if (!isS()) if (!isS())
throw std::runtime_error("Can't get std::string out of non-string array"); throw std::runtime_error("Can't get std::string out of non-string array");
preparePrimaryUse({}, {this}); NDArray::preparePrimaryUse({}, {this});
// getting "virtual" offset. it's not real though,since it doesn't take lengths into account // getting "virtual" offset. it's not real though,since it doesn't take lengths into account
auto offset = getOffset(i); auto offset = getOffset(i);
@ -3208,8 +3213,8 @@ T NDArray::e(const Nd4jLong i) const {
const auto rp = getOffset(i); const auto rp = getOffset(i);
preparePrimaryUse({}, {this}); NDArray::preparePrimaryUse({}, {this});
registerPrimaryUse({}, {this}); NDArray::registerPrimaryUse({}, {this});
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), rp), LIBND4J_TYPES); BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), rp), LIBND4J_TYPES);
} }
@ -3226,8 +3231,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j) const {
const Nd4jLong coords[2] = {i, j}; const Nd4jLong coords[2] = {i, j};
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
preparePrimaryUse({}, {this}); NDArray::preparePrimaryUse({}, {this});
registerPrimaryUse({}, {this}); NDArray::registerPrimaryUse({}, {this});
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES); BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
@ -3246,8 +3251,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
const Nd4jLong coords[3] = {i, j, k}; const Nd4jLong coords[3] = {i, j, k};
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
preparePrimaryUse({}, {this}); NDArray::preparePrimaryUse({}, {this});
registerPrimaryUse({}, {this}); NDArray::registerPrimaryUse({}, {this});
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES); BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
@ -3266,8 +3271,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLon
const Nd4jLong coords[4] = {i, j, k, l}; const Nd4jLong coords[4] = {i, j, k, l};
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
preparePrimaryUse({}, {this}); NDArray::preparePrimaryUse({}, {this});
registerPrimaryUse({}, {this}); NDArray::registerPrimaryUse({}, {this});
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES); BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
@ -3300,9 +3305,9 @@ void NDArray::applyTransform(nd4j::transform::FloatOps op, NDArray *target, Extr
if (!target->isR()) if (!target->isR())
throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types"); throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types");
prepareSpecialUse({target}, {this}); NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
registerSpecialUse({target}, {this}); NDArray::registerSpecialUse({target}, {this});
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -3314,9 +3319,9 @@ void NDArray::applyTransform(nd4j::transform::AnyOps op, NDArray *target, ExtraA
if (target == nullptr) if (target == nullptr)
target = this; target = this;
prepareSpecialUse({target}, {this}); NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
registerSpecialUse({target}, {this}); NDArray::registerSpecialUse({target}, {this});
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -3331,9 +3336,9 @@ void NDArray::applyTransform(nd4j::transform::SameOps op, NDArray *target, Extra
if (target->dataType() != dataType()) if (target->dataType() != dataType())
throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array"); throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array");
prepareSpecialUse({target}, {this}); NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
registerSpecialUse({target}, {this}); NDArray::registerSpecialUse({target}, {this});
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -3347,9 +3352,9 @@ void NDArray::applyTransform(nd4j::transform::StrictOps op, NDArray *target, Ext
if (!this->isR() || !target->isR() || (this->dataType() != target->dataType())) if (!this->isR() || !target->isR() || (this->dataType() != target->dataType()))
throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !"); throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !");
registerSpecialUse({target}, {this}); NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
prepareSpecialUse({target}, {this}); NDArray::registerSpecialUse({target}, {this});
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -3363,9 +3368,9 @@ void NDArray::applyTransform(nd4j::transform::BoolOps op, NDArray *target, Extra
if (!target->isB()) if (!target->isB())
throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types"); throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types");
prepareSpecialUse({target}, {this}); NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr); NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
registerSpecialUse({target}, {this}); NDArray::registerSpecialUse({target}, {this});
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -3375,9 +3380,9 @@ NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) cons
NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext()); NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext());
registerSpecialUse({&result}, {this}); NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); NativeOpExecutioner::execTransformFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
prepareSpecialUse({&result}, {this}); NDArray::registerSpecialUse({&result}, {this});
return result; return result;
} }
@ -3389,9 +3394,9 @@ NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const
NDArray result(getShapeInfo(), false, getContext()); NDArray result(getShapeInfo(), false, getContext());
prepareSpecialUse({&result}, {this}); NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); NativeOpExecutioner::execTransformSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
registerSpecialUse({&result}, {this}); NDArray::registerSpecialUse({&result}, {this});
return result; return result;
} }
@ -3403,9 +3408,9 @@ NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) con
NDArray result(getShapeInfo(), false, getContext()); NDArray result(getShapeInfo(), false, getContext());
prepareSpecialUse({&result}, {this}); NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformStrict(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); NativeOpExecutioner::execTransformStrict(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
registerSpecialUse({&result}, {this}); NDArray::registerSpecialUse({&result}, {this});
return result; return result;
} }
@ -3417,9 +3422,9 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const
NDArray result(ordering(), getShapeAsVector(), nd4j::DataType::BOOL, getContext()); NDArray result(ordering(), getShapeAsVector(), nd4j::DataType::BOOL, getContext());
prepareSpecialUse({&result}, {this}); NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); NativeOpExecutioner::execTransformBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
registerSpecialUse({&result}, {this}); NDArray::registerSpecialUse({&result}, {this});
return result; return result;
} }
@ -3435,9 +3440,9 @@ void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArra
if(target->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar->getShapeInfo()) && !(target->dataType() == dataType() || target->dataType() == scalar->dataType())) if(target->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar->getShapeInfo()) && !(target->dataType() == dataType() || target->dataType() == scalar->dataType()))
throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!"); throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!");
prepareSpecialUse({target}, {this, scalar}); NDArray::prepareSpecialUse({target}, {this, scalar});
NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr); NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
registerSpecialUse({target}, {this, scalar}); NDArray::registerSpecialUse({target}, {this, scalar});
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -3471,10 +3476,9 @@ void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, ND
throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!"); throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!");
} }
prepareSpecialUse({target}, {this, scalar}); NDArray::prepareSpecialUse({target}, {this, scalar});
NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr); NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
NDArray::registerSpecialUse({target}, {this, scalar});
registerSpecialUse({target}, {this, scalar});
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -3557,7 +3561,7 @@ NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, cons
NDArray::prepareSpecialUse({result}, {this, other}); NDArray::prepareSpecialUse({result}, {this, other});
NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo()); NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo());
registerSpecialUse({result}, {this, other}); NDArray::registerSpecialUse({result}, {this, other});
return result; return result;
} }
@ -3635,9 +3639,9 @@ NDArray* NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray *other, c
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
prepareSpecialUse({result}, {this, other}); NDArray::prepareSpecialUse({result}, {this, other});
NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
registerSpecialUse({result}, {this, other}); NDArray::registerSpecialUse({result}, {this, other});
return result; return result;
} }
@ -3780,9 +3784,9 @@ void NDArray::p(const Nd4jLong i, const T value) {
auto rp = getOffset(i); auto rp = getOffset(i);
const void *pV = reinterpret_cast<const void*>(const_cast<T *>(&value)); const void *pV = reinterpret_cast<const void*>(const_cast<T *>(&value));
preparePrimaryUse({this}, {}, true); NDArray::preparePrimaryUse({this}, {}, true);
BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->getBuffer(), rp, pV), LIBND4J_TYPES); BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->getBuffer(), rp, pV), LIBND4J_TYPES);
registerPrimaryUse({this}, {}); NDArray::registerPrimaryUse({this}, {});
} }
template void NDArray::p(const Nd4jLong i, const double value); template void NDArray::p(const Nd4jLong i, const double value);
@ -3811,9 +3815,9 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) {
Nd4jLong coords[2] = {i, j}; Nd4jLong coords[2] = {i, j};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
preparePrimaryUse({this}, {}, true); NDArray::preparePrimaryUse({this}, {}, true);
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
registerPrimaryUse({this}, {}); NDArray::registerPrimaryUse({this}, {});
} }
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value); template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value);
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value); template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value);
@ -3837,13 +3841,13 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T va
if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2])
throw std::invalid_argument("NDArray:pe(i,j,k, value): one of input indexes is out of array length or rank!=3 !"); throw std::invalid_argument("NDArray:pe(i,j,k, value): one of input indexes is out of array length or rank!=3 !");
preparePrimaryUse({this}, {}, true); NDArray::preparePrimaryUse({this}, {}, true);
void *p = reinterpret_cast<void *>(const_cast<T *>(&value)); void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
Nd4jLong coords[3] = {i, j, k}; Nd4jLong coords[3] = {i, j, k};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
registerPrimaryUse({this}, {}); NDArray::registerPrimaryUse({this}, {});
} }
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value); template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value);
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value); template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value);
@ -3870,9 +3874,9 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4j
Nd4jLong coords[4] = {i, j, k, l}; Nd4jLong coords[4] = {i, j, k, l};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf()); auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
preparePrimaryUse({this}, {}, true); NDArray::preparePrimaryUse({this}, {}, true);
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES); BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
registerPrimaryUse({this}, {}); NDArray::registerPrimaryUse({this}, {});
} }
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value); template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value);
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value); template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value);
@ -3896,10 +3900,10 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
if (i >= _length) if (i >= _length)
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
preparePrimaryUse({this}, {&scalar}, true); NDArray::preparePrimaryUse({this}, {&scalar}, true);
auto rp = getOffset(i); auto rp = getOffset(i);
BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (getBuffer(), rp, scalar.dataType(), scalar.getBuffer()), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (getBuffer(), rp, scalar.dataType(), scalar.getBuffer()), LIBND4J_TYPES);
registerPrimaryUse({this}, {&scalar}); NDArray::registerPrimaryUse({this}, {&scalar});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -4195,7 +4199,7 @@ ResultSet* NDArray::allTensorsAlongDimension(const std::vector<int> &dimensions)
auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_shapeInfo, const_cast<int*>(dimensions.data()), dimensions.size()); auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_shapeInfo, const_cast<int*>(dimensions.data()), dimensions.size());
auto numTads = lengthOf() / shape::length(pack.primaryShapeInfo()); auto numTads = pack.numberOfTads();
for (int idx = 0; idx < numTads; idx++ ) { for (int idx = 0; idx < numTads; idx++ ) {
auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset()); auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset());

View File

@ -1578,6 +1578,20 @@ public:
void *dx, Nd4jLong *dxShapeInfo, void *dx, Nd4jLong *dxShapeInfo,
bool descending); bool descending);
void sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending);
void sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending);
void sortTad(Nd4jPointer *extraPointers, void sortTad(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo, void *dx, Nd4jLong *dxShapeInfo,
@ -1587,6 +1601,24 @@ public:
Nd4jLong *tadOffsets, Nd4jLong *tadOffsets,
bool descending); bool descending);
void sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending);
void sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending);
// special sort impl for sorting out COO indices and values // special sort impl for sorting out COO indices and values
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank); void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank);

View File

@ -208,6 +208,23 @@ void* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
return nullptr; return nullptr;
} }
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// change an array by repeating it the number of times given by reps. // change an array by repeating it the number of times given by reps.
NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const { NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {

View File

@ -27,6 +27,52 @@
namespace nd4j { namespace nd4j {
////////////////////////////////////////////////////////////////////////
template <>
NDArray NDArrayFactory::create<bool>(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context) {
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
ShapeDescriptor descriptor(nd4j::DataType::BOOL, order, shape);
if (descriptor.arrLength() != data.size()) {
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
}
bool* hostBuffer = nullptr;
ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool);
std::copy(data.begin(), data.end(), hostBuffer);
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(hostBuffer, data.size() * sizeof(bool), nd4j::DataType::BOOL, true, context->getWorkspace());
NDArray result(buffer, descriptor, context);
return result;
}
////////////////////////////////////////////////////////////////////////
template <typename T>
NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context) {
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
ShapeDescriptor descriptor(DataTypeUtils::fromT<T>(), order, shape);
if (descriptor.arrLength() != data.size()) {
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
}
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(data.data(), DataTypeUtils::fromT<T>(), descriptor.arrLength() * sizeof(T), context->getWorkspace());
NDArray result(buffer, descriptor, context);
return result;
}
NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) { NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
std::string s(str); std::string s(str);
@ -227,10 +273,13 @@ template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<float16> &data, nd4j::LaunchContext * context); template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<float16> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bfloat16> &data, nd4j::LaunchContext * context); template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bfloat16> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int> &data, nd4j::LaunchContext * context); template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<unsigned int> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<unsigned long> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<Nd4jLong> &data, nd4j::LaunchContext * context); template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<Nd4jLong> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int8_t> &data, nd4j::LaunchContext * context); template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int8_t> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t> &data, nd4j::LaunchContext * context); template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int16_t> &data, nd4j::LaunchContext * context); template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int16_t> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint16_t> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context); template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context);
@ -391,6 +440,7 @@ template NDArray NDArrayFactory::create(const std::vector<bfloat16> &values, nd4
template NDArray NDArrayFactory::create(const std::vector<Nd4jLong> &values, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const std::vector<Nd4jLong> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<int> &values, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const std::vector<int> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<int16_t> &values, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const std::vector<int16_t> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<uint16_t> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<int8_t> &values, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const std::vector<int8_t> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<uint8_t> &values, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const std::vector<uint8_t> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::LaunchContext * context);
@ -452,53 +502,6 @@ template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::L
return new NDArray(order, shape, dataType, context); return new NDArray(order, shape, dataType, context);
} }
////////////////////////////////////////////////////////////////////////
template <typename T>
NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context) {
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
ShapeDescriptor descriptor(DataTypeUtils::fromT<T>(), order, shape);
if (descriptor.arrLength() != data.size()) {
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
}
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(data.data(), DataTypeUtils::fromT<T>(), descriptor.arrLength() * sizeof(T), context->getWorkspace());
NDArray result(buffer, descriptor, context);
return result;
}
////////////////////////////////////////////////////////////////////////
template <>
NDArray NDArrayFactory::create<bool>(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context) {
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
ShapeDescriptor descriptor(nd4j::DataType::BOOL, order, shape);
if (descriptor.arrLength() != data.size()) {
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
}
bool* hostBuffer = nullptr;
ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool);
std::copy(data.begin(), data.end(), hostBuffer);
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(hostBuffer, data.size() * sizeof(bool), nd4j::DataType::BOOL, true, context->getWorkspace());
NDArray result(buffer, descriptor, context);
return result;
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context) { NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context) {

View File

@ -2736,6 +2736,60 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
return reinterpret_cast<Nd4jPointer>(shapeBuffer); return reinterpret_cast<Nd4jPointer>(shapeBuffer);
} }
void NativeOps::sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending) {
auto xType = ArrayOptions::dataType(xShapeInfo);
auto yType = ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
void NativeOps::sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending) {
auto xType = ArrayOptions::dataType(xShapeInfo);
auto yType = ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
void NativeOps::sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
auto xType = ArrayOptions::dataType(xShapeInfo);
auto yType = ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
void NativeOps::sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
auto xType = ArrayOptions::dataType(xShapeInfo);
auto yType = ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);

View File

@ -192,8 +192,8 @@ void NDArray::setIdentity() {
if (isS()) if (isS())
throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!"); throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!");
if (rankOf() != 2) // if (rankOf() != 2)
throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given."); // throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given.");
const int threadsPerBlock = MAX_NUM_THREADS / 4; const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (lengthOf() + threadsPerBlock - 1) / threadsPerBlock; const int blocksPerGrid = (lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
@ -234,22 +234,27 @@ void NDArray::synchronize(const char* msg) const {
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) { void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
for (const auto& a : readList) for (const auto& a : readList)
if(a != nullptr)
a->syncToDevice(); a->syncToDevice();
for (const auto& a : writeList) { for (const auto& a : writeList) {
if (a != nullptr) {
a->getDataBuffer()->allocateSpecial(); a->getDataBuffer()->allocateSpecial();
if (synchronizeWritables) if (synchronizeWritables)
a->syncToDevice(); a->syncToDevice();
} }
} }
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) { void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
for (const auto& p : readList) for (const auto& p : readList)
if(p != nullptr)
p->tickReadDevice(); p->tickReadDevice();
for (const auto& p : writeList) for (const auto& p : writeList)
if (p != nullptr)
p->tickWriteDevice(); p->tickWriteDevice();
} }
@ -257,22 +262,27 @@ void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& wr
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) { void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
for (const auto& a : readList) for (const auto& a : readList)
if(a != nullptr)
a->syncToHost(); a->syncToHost();
for (const auto& a : writeList) { for (const auto& a : writeList) {
if (a != nullptr) {
a->getDataBuffer()->allocatePrimary(); a->getDataBuffer()->allocatePrimary();
if (synchronizeWritables) if (synchronizeWritables)
a->syncToHost(); a->syncToHost();
} }
} }
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) { void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
for (const auto& p : readList) for (const auto& p : readList)
if(p != nullptr)
p->tickReadHost(); p->tickReadHost();
for (const auto& p : writeList) for (const auto& p : writeList)
if (p != nullptr)
p->tickWriteHost(); p->tickWriteHost();
} }
@ -427,9 +437,26 @@ void NDArray::repeat(int dimension, NDArray& target) const {
NDArray::registerSpecialUse({&target}, {this}); NDArray::registerSpecialUse({&target}, {this});
} }
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const {\ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const {
if(_length == 0) if(_length == 0)
{ printf("NDArray::printActualBuffer: array length is zero !\n"); return; } { printf("NDArray::printActualBuffer: array length is zero !\n"); return; }
@ -477,7 +504,7 @@ template void NDArray::printCurrentBuffer<double>(const bool host, const char* m
#if defined(__CUDACC__) && !defined(BUILD_TESTS) #if defined(__CUDACC__) && !defined(BUILD_TESTS)
#include <cpu/NDArrayLambda.hpp> //#include <cpu/NDArrayLambda.hpp>
#endif #endif

View File

@ -2321,6 +2321,163 @@ void NativeOps::sort(Nd4jPointer *extraPointers,
} }
void NativeOps::sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending) {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto xLength = shape::length(xShapeInfo);
auto xEWS = shape::elementWiseStride(xShapeInfo);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
// check if xLength is a power of 2, and use bitonic sort, if that's the case
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
int numBlocks = xLength / numThreads;
if (xLength % numThreads > 0 || numBlocks == 0)
numBlocks++;
dim3 launchDims(numBlocks, numThreads, 32768);
for (int k = 2; k <= xLength; k = 2*k) {
for (int j = k >> 1; j > 0; j = j >> 1) {
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
} else {
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
int numBlocks = xLength / numThreads;
if (xLength % numThreads > 0 || numBlocks == 0)
numBlocks++;
numBlocks = nd4j::math::nd4j_min<int>(512, numBlocks);
dim3 launchDims(numBlocks, numThreads, 32768);
int max = 2, dg = 0;
while (max < xLength) {
max <<= 1;
dg++;
}
max <<= 1;
for (int window = 2; window < max; window<<=1) {
int n = window;
int rev = 0;
do{
int half = n >> 1;
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
n>>=1;
rev = 1;
} while(n > 1);
}
}
}
void NativeOps::sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending) {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto xLength = shape::length(xShapeInfo);
auto xEWS = shape::elementWiseStride(xShapeInfo);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
// check if xLength is a power of 2, and use bitonic sort, if that's the case
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
int numBlocks = xLength / numThreads;
if (xLength % numThreads > 0 || numBlocks == 0)
numBlocks++;
dim3 launchDims(numBlocks, numThreads, 32768);
for (int k = 2; k <= xLength; k = 2*k) {
for (int j = k >> 1; j > 0; j = j >> 1) {
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
} else {
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
int numBlocks = xLength / numThreads;
if (xLength % numThreads > 0 || numBlocks == 0)
numBlocks++;
numBlocks = nd4j::math::nd4j_min<int>(512, numBlocks);
dim3 launchDims(numBlocks, numThreads, 32768);
int max = 2, dg = 0;
while (max < xLength) {
max <<= 1;
dg++;
}
max <<= 1;
for (int window = 2; window < max; window<<=1) {
int n = window;
int rev = 0;
do{
int half = n >> 1;
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
n>>=1;
rev = 1;
} while(n > 1);
}
}
}
void NativeOps::sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "sortTadKey(...) failed");
}
void NativeOps::sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
auto xType = nd4j::ArrayOptions::dataType(yShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(xShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed");
}
void NativeOps::sortTad(Nd4jPointer *extraPointers, void NativeOps::sortTad(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo, void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong *dXShapeInfo,
@ -2331,15 +2488,13 @@ void NativeOps::sortTad(Nd4jPointer *extraPointers,
bool descending) { bool descending) {
// to be implemented // to be implemented
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]); auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
dim3 launchDims((int) tadPack.numberOfTads(), 512, 33768);
dim3 launchDims(tadPack.numberOfTads(), 1024, 33768);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, (launchDims, stream, dX, dXShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "sortTadFloat(...) failed"); nd4j::DebugHelper::checkErrorCode(stream, "sortTad(...) failed");
} }
void NativeOps::sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) { void NativeOps::sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) {

View File

@ -38,11 +38,11 @@ namespace nd4j {
ConstantDataBuffer() = default; ConstantDataBuffer() = default;
~ConstantDataBuffer() = default; ~ConstantDataBuffer() = default;
Nd4jLong sizeOf(); Nd4jLong sizeOf() const;
Nd4jLong length(); Nd4jLong length() const;
Nd4jPointer primary(); Nd4jPointer primary() const;
Nd4jPointer special(); Nd4jPointer special() const;
ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default; ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default;
ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default; ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default;

View File

@ -261,6 +261,8 @@ DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
allocateBuffers(); allocateBuffers();
copyBufferFrom(other); copyBufferFrom(other);
return *this;
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -285,6 +287,8 @@ DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
other._primaryBuffer = other._specialBuffer = nullptr; other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false); other.setAllocFlags(false, false);
other._lenInBytes = 0; other._lenInBytes = 0;
return *this;
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////

View File

@ -335,6 +335,8 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
return std::string("INT8"); return std::string("INT8");
case INT16: case INT16:
return std::string("INT16"); return std::string("INT16");
case UINT16:
return std::string("UINT16");
case INT32: case INT32:
return std::string("INT32"); return std::string("INT32");
case INT64: case INT64:
@ -375,7 +377,7 @@ FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo,
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
// returns the difference between 1.0 and the next representable value of the given floating-point type // returns the difference between 1.0 and the next representable value of the given floating-point type
template <typename T> template <typename T>
FORCEINLINE T DataTypeUtils::eps() { FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
if (std::is_same<T, double>::value) if (std::is_same<T, double>::value)
return std::numeric_limits<double>::epsilon(); return std::numeric_limits<double>::epsilon();
else if (std::is_same<T, float>::value) else if (std::is_same<T, float>::value)

View File

@ -26,6 +26,7 @@
#include <vector> #include <vector>
#include <array/DataType.h> #include <array/DataType.h>
#include <pointercast.h> #include <pointercast.h>
#include <stdlib.h>
namespace nd4j { namespace nd4j {
class ND4J_EXPORT ExtraArguments { class ND4J_EXPORT ExtraArguments {

View File

@ -35,21 +35,21 @@ namespace nd4j {
TadPack() = default; TadPack() = default;
~TadPack() = default; ~TadPack() = default;
Nd4jLong* primaryShapeInfo(); Nd4jLong* primaryShapeInfo() const;
Nd4jLong* primaryOffsets(); Nd4jLong* primaryOffsets() const;
Nd4jLong* specialShapeInfo(); Nd4jLong* specialShapeInfo() const;
Nd4jLong* specialOffsets(); Nd4jLong* specialOffsets() const;
Nd4jLong numberOfTads(); Nd4jLong numberOfTads() const;
int shapeInfoLength(); int shapeInfoLength() const;
/** /**
* These methods return either primary or special pointers depending on platform binaries were compiled for * These methods return either primary or special pointers depending on platform binaries were compiled for
* @return * @return
*/ */
Nd4jLong *platformShapeInfo(); Nd4jLong *platformShapeInfo() const;
Nd4jLong *platformOffsets(); Nd4jLong *platformOffsets() const;
}; };
} }

View File

@ -28,19 +28,19 @@ namespace nd4j {
_sizeOf = sizeOf; _sizeOf = sizeOf;
} }
Nd4jPointer ConstantDataBuffer::primary() { Nd4jPointer ConstantDataBuffer::primary() const {
return _primaryBuffer; return _primaryBuffer;
} }
Nd4jPointer ConstantDataBuffer::special() { Nd4jPointer ConstantDataBuffer::special() const {
return _specialBuffer; return _specialBuffer;
} }
Nd4jLong ConstantDataBuffer::sizeOf() { Nd4jLong ConstantDataBuffer::sizeOf() const {
return _sizeOf; return _sizeOf;
} }
Nd4jLong ConstantDataBuffer::length() { Nd4jLong ConstantDataBuffer::length() const {
return _length; return _length;
} }

View File

@ -54,7 +54,7 @@ namespace nd4j {
NDArray* NDArrayList::readRaw(int idx) { NDArray* NDArrayList::readRaw(int idx) {
if (_chunks.count(idx) < 1) { if (_chunks.count(idx) < 1) {
nd4j_printf("Non-existent chunk requested: [%i]\n", idx); nd4j_printf("Non-existent chunk requested: [%i]\n", idx);
throw std::runtime_error("Bad index"); throw std::invalid_argument("Bad index");
} }
return _chunks[idx]; return _chunks[idx];
@ -120,7 +120,7 @@ namespace nd4j {
// storing reference // storing reference
_chunks[idx] = array; _chunks[idx] = array;
return ND4J_STATUS_OK; return Status::OK();
} }
std::vector<Nd4jLong>& NDArrayList::shape() { std::vector<Nd4jLong>& NDArrayList::shape() {
@ -152,8 +152,10 @@ namespace nd4j {
std::vector<bool> bargs; std::vector<bool> bargs;
int numElements = _elements.load(); int numElements = _elements.load();
for (int e = 0; e < numElements; e++) for (int e = 0; e < numElements; e++) {
_chunks[e]->syncToDevice();
inputs.emplace_back(_chunks[e]); inputs.emplace_back(_chunks[e]);
}
iargs.push_back(_axis); iargs.push_back(_axis);

View File

@ -29,34 +29,34 @@ namespace nd4j {
_numTads = numTads; _numTads = numTads;
} }
Nd4jLong* TadPack::primaryShapeInfo() { Nd4jLong* TadPack::primaryShapeInfo() const {
return reinterpret_cast<Nd4jLong *>(_tadShape.primary()); return reinterpret_cast<Nd4jLong *>(_tadShape.primary());
} }
Nd4jLong* TadPack::primaryOffsets() { Nd4jLong* TadPack::primaryOffsets() const {
return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary()); return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary());
} }
Nd4jLong* TadPack::specialShapeInfo() { Nd4jLong* TadPack::specialShapeInfo() const {
return reinterpret_cast<Nd4jLong *>(_tadShape.special()); return reinterpret_cast<Nd4jLong *>(_tadShape.special());
} }
Nd4jLong* TadPack::specialOffsets() { Nd4jLong* TadPack::specialOffsets() const {
return reinterpret_cast<Nd4jLong *>(_tadOffsets.special()); return reinterpret_cast<Nd4jLong *>(_tadOffsets.special());
} }
Nd4jLong TadPack::numberOfTads() { Nd4jLong TadPack::numberOfTads() const {
return _numTads; return _numTads;
} }
Nd4jLong* TadPack::platformShapeInfo() { Nd4jLong* TadPack::platformShapeInfo() const {
return nd4j::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo(); return nd4j::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo();
} }
Nd4jLong* TadPack::platformOffsets() { Nd4jLong* TadPack::platformOffsets() const {
return nd4j::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets(); return nd4j::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets();
} }
int TadPack::shapeInfoLength() { int TadPack::shapeInfoLength() const {
return (int) shape::shapeInfoLength(primaryShapeInfo()); return (int) shape::shapeInfoLength(primaryShapeInfo());
} }
} }

View File

@ -27,7 +27,7 @@ namespace nd4j {
class AttentionHelper { class AttentionHelper {
public: public:
static nd4j::NDArray* multiHeadProject(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static nd4j::NDArray multiHeadProject(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static void multiHeadProjectBp(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, const nd4j::NDArray* eps, nd4j::NDArray* dLdInput, nd4j::NDArray* dLdProjectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static void multiHeadProjectBp(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, const nd4j::NDArray* eps, nd4j::NDArray* dLdInput, nd4j::NDArray* dLdProjectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
}; };
} }

View File

@ -69,10 +69,10 @@ namespace nd4j {
} }
void executeOnce() override { void executeOnce() override {
auto xT = (_tA ? _x->transpose() : _x); auto xT = (_tA ? _x->transpose() : *_x);
auto yT = (_tB ? _y->transpose() : _y); auto yT = (_tB ? _y->transpose() : *_y);
MmulHelper::mmul(xT, yT, _z, _alpha, _beta); MmulHelper::mmul(&xT, &yT, _z, _alpha, _beta);
} }
std::string axis() override { std::string axis() override {

View File

@ -133,9 +133,9 @@ void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff
// if(matrix.rankOf() != 2) // if(matrix.rankOf() != 2)
// throw "ops::helpers::Householder::mulLeft method: input array must be 2D matrix !"; // throw "ops::helpers::Householder::mulLeft method: input array must be 2D matrix !";
if(matrix.sizeAt(0) == 1) if(matrix.sizeAt(0) == 1) {
matrix *= (T) 1.f - coeff; matrix *= (T) 1.f - coeff;
}
else if(coeff != (T)0.f) { else if(coeff != (T)0.f) {
auto bottomPart = new NDArray(matrix({1,matrix.sizeAt(0), 0,0}, true)); auto bottomPart = new NDArray(matrix({1,matrix.sizeAt(0), 0,0}, true));
@ -145,13 +145,11 @@ void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff
auto column = tail; auto column = tail;
auto row = tail.transpose(); auto row = tail.transpose();
auto resultingRow = mmul(*row, bottomPartCopy); auto resultingRow = mmul(row, bottomPartCopy);
auto fistRow = matrix({0,1, 0,0}, true); auto fistRow = matrix({0,1, 0,0}, true);
resultingRow += fistRow; resultingRow += fistRow;
fistRow -= resultingRow * coeff; fistRow -= resultingRow * coeff;
*bottomPart -= mmul(column, resultingRow) * coeff; *bottomPart -= mmul(column, resultingRow) * coeff;
delete row;
} }
else { else {
@ -161,9 +159,7 @@ void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff
auto fistRow = matrix({0,1, 0,0}, true); auto fistRow = matrix({0,1, 0,0}, true);
resultingRow += fistRow; resultingRow += fistRow;
fistRow -= resultingRow * coeff; fistRow -= resultingRow * coeff;
*bottomPart -= mmul(*column, resultingRow) * coeff; *bottomPart -= mmul(column, resultingRow) * coeff;
delete column;
} }
delete bottomPart; delete bottomPart;
} }
@ -193,21 +189,16 @@ void Householder<T>::mulRight(NDArray& matrix, const NDArray& tail, const T coef
auto resultingCol = mmul(rightPartCopy, column); auto resultingCol = mmul(rightPartCopy, column);
resultingCol += *fistCol; resultingCol += *fistCol;
*fistCol -= resultingCol * coeff; *fistCol -= resultingCol * coeff;
*rightPart -= mmul(resultingCol, *row) * coeff; *rightPart -= mmul(resultingCol, row) * coeff;
delete row;
} }
else { else {
auto row = tail; auto row = tail;
auto column = tail.transpose(); auto column = tail.transpose();
auto resultingCol = mmul(rightPartCopy, *column); auto resultingCol = mmul(rightPartCopy, column);
resultingCol += *fistCol; resultingCol += *fistCol;
*fistCol -= resultingCol * coeff; *fistCol -= resultingCol * coeff;
*rightPart -= mmul(resultingCol, row) * coeff; *rightPart -= mmul(resultingCol, row) * coeff;
delete column;
} }
delete rightPart; delete rightPart;
delete fistCol; delete fistCol;

View File

@ -157,8 +157,7 @@ bool JacobiSVD<T>::isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem) {
if(_calcU) { if(_calcU) {
auto temp2 = rotation.transpose(); auto temp2 = rotation.transpose();
mulRotationOnRight(p, q, _u, *temp2); mulRotationOnRight(p, q, _u, temp2);
delete temp2;
} }
} }
@ -251,9 +250,7 @@ void JacobiSVD<T>::svd2x2(const NDArray& block, int p, int q, NDArray& left, NDA
m.p<T>(1, 1, _z); m.p<T>(1, 1, _z);
auto temp = right.transpose(); auto temp = right.transpose();
left.assign(mmul(rotation, *temp)); left.assign(mmul(rotation, temp));
delete temp;
} }
@ -289,7 +286,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
else if(_rows < _cols) { else if(_rows < _cols) {
auto matrixT = matrix.transpose(); auto matrixT = matrix.transpose();
HHcolPivQR qr(*matrixT / scale); HHcolPivQR qr(matrixT / scale);
_m.assign(qr._qr({0,_rows, 0,_rows})); _m.assign(qr._qr({0,_rows, 0,_rows}));
_m.fillAsTriangular<T>(0., 0, 0, 'l'); _m.fillAsTriangular<T>(0., 0, 0, 'l');
_m.transposei(); _m.transposei();
@ -305,8 +302,6 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
if(_calcU) if(_calcU)
_u.assign(qr._permut); _u.assign(qr._permut);
delete matrixT;
} }
else { else {
@ -352,8 +347,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
if(_calcU) { if(_calcU) {
auto temp = rotLeft.transpose(); auto temp = rotLeft.transpose();
mulRotationOnRight(p, q, _u, *temp); mulRotationOnRight(p, q, _u, temp);
delete temp;
} }
mulRotationOnRight(p, q, _m, rotRight); mulRotationOnRight(p, q, _m, rotRight);

View File

@ -920,7 +920,7 @@ void SVD<T>::evalData(const NDArray& matrix) {
auto temp1 = biDiag._HHbidiag.transpose(); auto temp1 = biDiag._HHbidiag.transpose();
auto temp2 = _m({0,_diagSize, 0,0}, true); auto temp2 = _m({0,_diagSize, 0,0}, true);
temp2.assign(temp1); temp2.assign(temp1);
delete temp1;
auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true); auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true);
temp3.assign(0.); temp3.assign(0.);

View File

@ -184,9 +184,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
if(pC->ordering() != 'f') { if(pC->ordering() != 'f') {
auto temp = pA; auto temp = pA;
pA = pB ->permute({1,0}); pA = new NDArray(pB ->permute({1,0}));
pB = temp->permute({1,0}); pB = new NDArray(temp->permute({1,0}));
pC = pC ->permute({1,0}); pC = new NDArray(pC ->permute({1,0}));
toDelete.push_back(pA); toDelete.push_back(pA);
toDelete.push_back(pB); toDelete.push_back(pB);
toDelete.push_back(pC); toDelete.push_back(pC);
@ -251,7 +251,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
blocksPerGrid.y = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.y); // rows blocksPerGrid.y = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.y); // rows
} }
BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES)
} }
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status); if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);
@ -339,7 +340,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
threadsPerBlock.x = 512; threadsPerBlock.x = 512;
blocksPerGrid.x = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.x); // rows blocksPerGrid.x = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.x); // rows
} }
BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES)
} }
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status); if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
@ -396,7 +398,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
NDArray::prepareSpecialUse({Z}, {X, Y}); NDArray::prepareSpecialUse({Z}, {X, Y});
BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES)
auto cudaResult = cudaStreamSynchronize(*stream); auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult); if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
@ -406,8 +409,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
return Z; return Z;
} }
BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES); //BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
} }

View File

@ -28,33 +28,27 @@
namespace nd4j { namespace nd4j {
nd4j::NDArray * nd4j::NDArray AttentionHelper::multiHeadProject(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix, nd4j::LaunchContext * context) {
AttentionHelper::multiHeadProject(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix, nd4j::LaunchContext * context) {
auto miniBatchSize = input->sizeAt(0); auto miniBatchSize = input->sizeAt(0);
auto seqLength = input->sizeAt(2); auto seqLength = input->sizeAt(2);
auto numHeads = projectionMatrix->sizeAt(0); auto numHeads = projectionMatrix->sizeAt(0);
auto projectedSize = projectionMatrix->sizeAt(1); auto projectedSize = projectionMatrix->sizeAt(1);
auto inputPerm = input->permute({1, 0, 2}); auto inputPerm = input->permute({1, 0, 2});
auto inputPrep = inputPerm->reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
NDArray* projected = new NDArray('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);
nd4j::ops::matmul mmul; nd4j::ops::matmul mmul;
mmul.execute({projectionPrep, inputPrep}, {projected}, {}, {}, {}); mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
projected->reshapei({numHeads, projectedSize, miniBatchSize, seqLength}); projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
projected->permutei({2, 0, 1, 3}); projected.permutei({2, 0, 1, 3});
delete inputPerm;
delete inputPrep;
delete projectionPrep;
return projected; return projected;
} }
void void AttentionHelper::multiHeadProjectBp(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix,
AttentionHelper::multiHeadProjectBp(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix,
const nd4j::NDArray *eps, nd4j::NDArray *dLdInput, const nd4j::NDArray *eps, nd4j::NDArray *dLdInput,
nd4j::NDArray *dLdProjectionMatrix, nd4j::LaunchContext * context) { nd4j::NDArray *dLdProjectionMatrix, nd4j::LaunchContext * context) {
auto miniBatchSize = input->sizeAt(0); auto miniBatchSize = input->sizeAt(0);
@ -63,16 +57,16 @@ namespace nd4j {
auto projectedSize = projectionMatrix->sizeAt(1); auto projectedSize = projectionMatrix->sizeAt(1);
auto epsPerm = eps->permute({1, 2, 0, 3}); auto epsPerm = eps->permute({1, 2, 0, 3});
auto epsReshaped = epsPerm->reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength}); auto epsReshaped = epsPerm.reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength});
auto inputPerm = input->permute({1, 0, 2}); auto inputPerm = input->permute({1, 0, 2});
auto inputPrep = inputPerm->reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
nd4j::ops::matmul_bp mmulBp; nd4j::ops::matmul_bp mmulBp;
NDArray dLdProjectionPrep(projectionPrep->shapeInfo(), false, context); NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context);
NDArray dLdInputPrep(inputPrep->shapeInfo(), false, context); NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context);
mmulBp.execute({projectionPrep, inputPrep, epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {}); mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});
dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
dLdProjectionMatrix->assign(dLdProjectionPrep); dLdProjectionMatrix->assign(dLdProjectionPrep);
@ -80,12 +74,6 @@ namespace nd4j {
dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength}); dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength});
dLdInputPrep.permutei({1, 0, 2}); dLdInputPrep.permutei({1, 0, 2});
dLdInput->assign(dLdInputPrep); dLdInput->assign(dLdInputPrep);
delete inputPerm;
delete inputPrep;
delete epsPerm;
delete epsReshaped;
delete projectionPrep;
} }
} }

View File

@ -53,7 +53,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) { const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) {
const int numInArrsFF = argsHolderFF.getNumInArrs(); // also numInArrsFF = number of output arrays in opBP const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs(); const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs(); const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();
@ -65,6 +65,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF; ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
NDArray tmpScalar(nd4j::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0 NDArray tmpScalar(nd4j::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
for(int i = 0; i < numInArrsFF; ++i) { // loop through input array for(int i = 0; i < numInArrsFF; ++i) { // loop through input array
if(!whatArrsToCheck.empty() && static_cast<bool>(whatArrsToCheck[i]) == false) if(!whatArrsToCheck.empty() && static_cast<bool>(whatArrsToCheck[i]) == false)
@ -75,39 +76,39 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
for(Nd4jLong j = idxStart; j < idxEnd; ++j) { // loop through all elements for current array for(Nd4jLong j = idxStart; j < idxEnd; ++j) { // loop through all elements for current array
double& elem = inArrsFF[i]->t<double>(j); const double orig = inArrsFF[i]->e<double>(j);
const double orig = elem;
// add epsilon, feed forward // add epsilon, feed forward
elem = orig + EPSILON; inArrsFF[i]->p<double>(j, orig + EPSILON);
ResultSet* outArrsFF = opFF.execute(argsHolderFF); ResultSet* outArrsFF = opFF.execute(argsHolderFF);
int numOutArrs = outArrsFF->size(); int numOutArrs = outArrsFF->size();
double scorePlus = 0.; double scorePlus = 0.;
for(int k = 0; k < numOutArrs; ++k) { // loop through output array
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM) if(loss == SUM)
NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), reduce::Sum, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo()); outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
else else
NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), reduce::Mean, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo()); outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0); scorePlus += tmpScalar.e<double>(0);
} }
delete outArrsFF; delete outArrsFF;
// subtract epsilon, feed forward // subtract epsilon, feed forward
elem = orig - EPSILON; inArrsFF[i]->p<double>(j, orig - EPSILON);
outArrsFF = opFF.execute(argsHolderFF); outArrsFF = opFF.execute(argsHolderFF);
double scoreMinus = 0.; double scoreMinus = 0.;
for(int k = 0; k < numOutArrs; ++k) { // loop through output array for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM) if(loss == SUM)
NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), reduce::Sum, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo()); outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
else else
NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), reduce::Mean, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo()); outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0); scoreMinus += tmpScalar.e<double>(0);
} }
delete outArrsFF; delete outArrsFF;
// restore initial element value // restore initial element value
elem = orig; inArrsFF[i]->p<double>(j, orig);
// calculate numerical gradient // calculate numerical gradient
const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON); const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON);

View File

@ -43,22 +43,19 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
NDArray* aPR = a->permute(permutAt); NDArray aPR = a->permute(permutAt);
NDArray* bPR = b->permute(permutBt); NDArray bPR = b->permute(permutBt);
// check whether reshape is necessary // check whether reshape is necessary
if(!aPR->isSameShape(shapeAt)) if(!aPR.isSameShape(shapeAt))
aPR->reshapei( shapeAt); aPR.reshapei( shapeAt);
if(!bPR->isSameShape(shapeBt)) if(!bPR.isSameShape(shapeBt))
bPR->reshapei( shapeBt); bPR.reshapei( shapeBt);
NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0); NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
c->reshapei(outShape); c->reshapei(outShape);
delete aPR;
delete bPR;
return c; return c;
} }
@ -74,21 +71,21 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
// check whether permutation is required // check whether permutation is required
if(!permutForC.empty()) if(!permutForC.empty())
cP = c->permute(permutForC); cP = new NDArray(c->permute(permutForC));
auto aPR = a->permute(permutAt); auto aPR = a->permute(permutAt);
auto bPR = b->permute(permutBt); auto bPR = b->permute(permutBt);
// check whether reshape is necessary // check whether reshape is necessary
if(!aPR->isSameShape(shapeAt)) if(!aPR.isSameShape(shapeAt))
aPR->reshapei(shapeAt); aPR.reshapei(shapeAt);
if(!bPR->isSameShape(shapeBt)) if(!bPR.isSameShape(shapeBt))
bPR->reshapei(shapeBt); bPR.reshapei(shapeBt);
if(!cP->isSameShape({aPR->sizeAt(0), bPR->sizeAt(1)})) if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
cPR = cP->reshape(cP->ordering(), {aPR->sizeAt(0), bPR->sizeAt(1)}); cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
mmul(aPR, bPR, cPR, 1.0, 0.0); mmul(&aPR, &bPR, cPR, 1.0, 0.0);
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer() if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
cP->assign(cPR); cP->assign(cPR);
@ -97,40 +94,42 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
delete cPR; delete cPR;
if(cP != c) if(cP != c)
delete cP; delete cP;
delete aPR;
delete bPR;
} }
#ifndef __JAVACPP_HACK__ #ifndef __JAVACPP_HACK__
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB, const std::vector<std::vector<Nd4jLong>>& modifC) { void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB, const std::vector<std::vector<Nd4jLong>>& modifC) {
NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b)); NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
std::string whatToDoWithA, whatToDoWithB, whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" - permutation+reshaping; "rp" - reshaping/permutation, and so on; if another string is produced - throw exception std::string whatToDoWithA, whatToDoWithB, whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" - permutation+reshaping; "rp" - reshaping/permutation, and so on; if another string is produced - throw exception
for(const auto& arr : modifA) for(const auto& arr : modifA)
whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
for(const auto& arr : modifB) for(const auto& arr : modifB)
whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r"; whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";
for(const auto& arr : modifC) for(const auto& arr : modifC)
whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithC + "p" : whatToDoWithC + "r"; whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithC + "p" : whatToDoWithC + "r";
// first step for a array // first step for a array
if(!whatToDoWithA.empty()) if(!whatToDoWithA.empty())
aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]); aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0]));
// first step for b array // first step for b array
if(!whatToDoWithB.empty()) if(!whatToDoWithB.empty())
bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]); bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0]));
// rest steps for a array // rest steps for a array
for(int i = 1; i < whatToDoWithA.size(); ++i) for(int i = 1; i < whatToDoWithA.size(); ++i)
if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]); if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
// rest steps for b array // rest steps for b array
for(int i = 1; i < whatToDoWithB.size(); ++i) for(int i = 1; i < whatToDoWithB.size(); ++i)
if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]); if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]);
// now work with c array // now work with c array
std::vector<NDArray*> cArrs = {c}; std::vector<NDArray*> cArrs = {c};
if(!whatToDoWithC.empty()) { if(!whatToDoWithC.empty()) {
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c); cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
for(int i = 0; i < cArrs.size()-1; ++i) for(int i = 0; i < cArrs.size()-1; ++i)
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? cArrs[i]->permute(modifC[i]) : cArrs[i]->reshape(c->ordering(), modifC[i]); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
} }
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0); mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
@ -152,18 +151,21 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB) { NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB) {
NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b)); NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
std::string whatToDoWithA, whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping only; "pr" - permutation+reshaping; "rp" - reshaping/permutation; another string - throw exception std::string whatToDoWithA, whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping only; "pr" - permutation+reshaping; "rp" - reshaping/permutation; another string - throw exception
for(const auto& arr : modifA) for(const auto& arr : modifA)
whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
for(const auto& arr : modifB) for(const auto& arr : modifB)
whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r"; whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";
// first step for a array // first step for a array
if(!whatToDoWithA.empty()) if(!whatToDoWithA.empty())
aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]); aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0]));
// first step for b array // first step for b array
if(!whatToDoWithB.empty()) if(!whatToDoWithB.empty())
bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]); bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0]));
// rest steps for a array // rest steps for a array
for(int i = 1; i < whatToDoWithA.size(); ++i) for(int i = 1; i < whatToDoWithA.size(); ++i)
if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]); if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
@ -293,17 +295,17 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
permut[rank-1] = rank - 2; permut[rank-1] = rank - 2;
if(transX) if(transX)
xT = x->permute(permut); xT = new NDArray(x->permute(permut));
if(transY) if(transY)
yT = y->permute(permut); yT = new NDArray(y->permute(permut));
} }
if(xRank <= 2 && yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector (2Dx1D), matrix-matrix (2Dx2D) product cases if(xRank <= 2 && yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector (2Dx1D), matrix-matrix (2Dx2D) product cases
if(xRank == 1 && yRank == 2) { // reduce vector-matrix to matrix-matrix case if(xRank == 1 && yRank == 2) { // reduce vector-matrix to matrix-matrix case
xT = x->reshape(x->ordering(), {1, x->lengthOf()}); // please note x is not transposed in this case (since xRank=1) xT = new NDArray(x->reshape(x->ordering(), {1, x->lengthOf()})); // please note x is not transposed in this case (since xRank=1)
zT = z->reshape(z->ordering(), {1, z->lengthOf()}); zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
} }
mmul(xT, yT, zT, 1., 0.); mmul(xT, yT, zT, 1., 0.);

View File

@ -473,19 +473,9 @@ bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool
// FIXME: get rid of memcpy here // FIXME: get rid of memcpy here
memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank)); memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank));
for (int i = 0; i < minRank; ++i) for (int i = 0; i < minRank; ++i)
if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0)
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i]; tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
// nullify zero axis
for (int e = 0; e < maxRank; e++)
if (maxShapeInfo[e+1] == 0)
tmpShapeInfo[e+1] = 0;
int delta = maxRank - minRank;
for (int e = minRank - 1; e >= 0; e--)
if (minShapeInfo[e + 1] == 0)
tmpShapeInfo[e + 1 + delta] = 0;
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo)); ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
if (shape::isEmpty(max) || shape::isEmpty(min)) { if (shape::isEmpty(max) || shape::isEmpty(min)) {

View File

@ -40,7 +40,7 @@ namespace nd4j {
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __host__
#endif #endif
void Logger::printv(const char *format, std::vector<int>& vec) { void Logger::printv(const char *format, const std::vector<int>& vec) {
printf("%s: {", format); printf("%s: {", format);
for(int e = 0; e < vec.size(); e++) { for(int e = 0; e < vec.size(); e++) {
auto v = vec[e]; auto v = vec[e];
@ -55,7 +55,7 @@ namespace nd4j {
#ifdef __CUDACC__ #ifdef __CUDACC__
__host__ __host__
#endif #endif
void Logger::printv(const char *format, std::vector<Nd4jLong>& vec) { void Logger::printv(const char *format, const std::vector<Nd4jLong>& vec) {
printf("%s: {", format); printf("%s: {", format);
for(int e = 0; e < vec.size(); e++) { for(int e = 0; e < vec.size(); e++) {
auto v = vec[e]; auto v = vec[e];

View File

@ -55,8 +55,8 @@ namespace nd4j {
static void _CUDA_H info(const char *format, ...); static void _CUDA_H info(const char *format, ...);
static void _CUDA_H printv(const char *format, std::vector<int>& vec); static void _CUDA_H printv(const char *format, const std::vector<int>& vec);
static void _CUDA_H printv(const char *format, std::vector<Nd4jLong>& vec); static void _CUDA_H printv(const char *format, const std::vector<Nd4jLong>& vec);
}; };
} }

View File

@ -1023,23 +1023,6 @@ namespace shape {
*/ */
ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false); ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
/**
* insert dimension at shape[axis] position
* 1) for example: for given rank = 3, shape = {2,4,5}, axis = 1, dimension = 10 result is -> shape = {2,10,4,5}
* 2) for example: for given rank = 3, shape = {2,4,5}, axis = 3, dimension = 10 result is -> shape = {2,4,5,10}
* so be careful and provide shape buffer with enough (at least rank+1) length
* axis should be within [0, rank] range
*/
ND4J_EXPORT _CUDA_HD void insertDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis, const Nd4jLong dimension);
/**
* erase dimension at shape[axis] position
* 1) for example: for given rank = 3, shape = {2,4,5}, axis = 1, result is -> shape = {2,5}
* 2) for example: for given rank = 3, shape = {2,4,5}, axis = 2, result is -> shape = {2,4}
* axis should be within [0, rank-1] range
*/
ND4J_EXPORT _CUDA_HD void eraseDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis);
@ -4932,21 +4915,6 @@ INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffs
} }
} }
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD void insertDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis, const Nd4jLong dimension) {
for (int i = rank; i > axis; --i)
shape[i] = shape[i - 1];
shape[axis] = dimension;
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD void eraseDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis) {
for (int i = axis; i < rank - 1; ++i)
shape[i] = shape[i + 1];
}
} }

View File

@ -244,8 +244,9 @@ namespace functions {
auto xi = x + threadOffset; auto xi = x + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++) for (Nd4jLong i = 0; i < ulen; i++) {
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams); local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
}
PRAGMA_OMP_CRITICAL PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams); startingVal = OpType::update(startingVal, local, extraParams);

View File

@ -122,7 +122,7 @@ namespace functions {
tadLength = shape::length(tadOnlyShapeInfo); tadLength = shape::length(tadOnlyShapeInfo);
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
numTads = shape::length(xShapeInfo) / tadLength; numTads = shape::length(yShapeInfo) / tadLength;
xEWS = shape::elementWiseStride(xShapeInfo); xEWS = shape::elementWiseStride(xShapeInfo);
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
} }

View File

@ -21,12 +21,165 @@
#include <ops/specials_cuda.h> #include <ops/specials_cuda.h>
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int half = window>>1;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0) {
xLength = shape::length(xShapeInfo);
}
__syncthreads();
//for (int i = 0; i < length; i+= window)
/*
if window == 4;
iterations will be: 0; 4; 8; 12; 16; 20
if gridDim = 3;
on first iteration we'll have: 0; 4; 8;
on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
*/
int firstPosition;
int firstStep;
int secondPosition;
int secondStep;
int WARP_SIZE = 32;
int numWarps = (gridDim.x * blockDim.x) / 32;
int warpId = tid / WARP_SIZE;
int warpIdx = tid % WARP_SIZE;
if (half >= 128) {
firstPosition = blockIdx.x * window;
firstStep = gridDim.x * window;
secondPosition = threadIdx.x;
secondStep = blockDim.x;
} else if (half >= 32) {
firstPosition = warpId * window;
firstStep = numWarps * window;
secondPosition = warpIdx;
secondStep = WARP_SIZE;
} else {
firstPosition = tid * window;
firstStep = blockDim.x * gridDim.x * window;
secondPosition = 0;
secondStep = 1;
}
for (int i = firstPosition; i < length; i += firstStep) {
for (int j = secondPosition; j < half; j += secondStep) {
int it = (reverse) ? i + j + half : i + window - j - 1;
int ij = i+j;
if (it < length && ij < length ) {
int posIT = shape::getIndexOffset(it, yShapeInfo, xLength);
int posIJ = shape::getIndexOffset(ij, yShapeInfo, xLength);
Y v0 = y[posIJ];
Y v1 = y[posIT];
if(!descending == (v0 > v1)) {
y[posIJ] = v1;
y[posIT] = v0;
X xtemp = x[posIJ];
x[posIJ] = x[posIT];
x[posIT] = xtemp;
}
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int half = window>>1;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0) {
xLength = shape::length(xShapeInfo);
}
__syncthreads();
//for (int i = 0; i < length; i+= window)
/*
if window == 4;
iterations will be: 0; 4; 8; 12; 16; 20
if gridDim = 3;
on first iteration we'll have: 0; 4; 8;
on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
*/
int firstPosition;
int firstStep;
int secondPosition;
int secondStep;
int WARP_SIZE = 32;
int numWarps = (gridDim.x * blockDim.x) / 32;
int warpId = tid / WARP_SIZE;
int warpIdx = tid % WARP_SIZE;
if (half >= 128) {
firstPosition = blockIdx.x * window;
firstStep = gridDim.x * window;
secondPosition = threadIdx.x;
secondStep = blockDim.x;
} else if (half >= 32) {
firstPosition = warpId * window;
firstStep = numWarps * window;
secondPosition = warpIdx;
secondStep = WARP_SIZE;
} else {
firstPosition = tid * window;
firstStep = blockDim.x * gridDim.x * window;
secondPosition = 0;
secondStep = 1;
}
for (int i = firstPosition; i < length; i += firstStep) {
for (int j = secondPosition; j < half; j += secondStep) {
int it = (reverse) ? i + j + half : i + window - j - 1;
int ij = i+j;
if (it < length && ij < length ) {
int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);
X v0 = x[posIJ];
X v1 = x[posIT];
if(!descending == (v0 > v1)) {
x[posIJ] = v1;
x[posIT] = v0;
Y ytemp = y[posIJ];
y[posIJ] = y[posIT];
y[posIT] = ytemp;
}
}
}
}
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__device__ __global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
auto x = static_cast<T*>(vx); auto x = static_cast<T*>(vx);
int tid = threadIdx.x + blockDim.x * blockIdx.x; int tid = threadIdx.x + blockDim.x * blockIdx.x;
@ -85,8 +238,8 @@ void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int
int it = (reverse) ? i + j + half : i + window - j - 1; int it = (reverse) ? i + j + half : i + window - j - 1;
int ij = i+j; int ij = i+j;
if (it < length && ij < length ) { if (it < length && ij < length ) {
int posIT = getDevicePosition(xShapeInfo,it, xLength); int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
int posIJ = getDevicePosition(xShapeInfo, ij, xLength); int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);
shmem[threadIdx.x] = x[posIJ]; shmem[threadIdx.x] = x[posIJ];
shmem[threadIdx.x + blockDim.x] = x[posIT]; shmem[threadIdx.x + blockDim.x] = x[posIT];
@ -100,18 +253,22 @@ void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int
} }
} }
//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
bitonicArbitraryStepKernel<T>(vx, xShapeInfo, window, length, reverse, descending);
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) { __host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
execBitonicArbitraryStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, window, length, reverse, descending); execBitonicArbitraryStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, window, length, reverse, descending);
nd4j::DebugHelper::checkErrorCode(stream, "bitonicArbitrary(...) failed");
} }
template <typename X, typename Y>
__host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
bitonicArbitraryStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}
template <typename X, typename Y>
__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
bitonicArbitraryStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);

View File

@ -21,9 +21,119 @@
#include <ops/specials_cuda.h> #include <ops/specials_cuda.h>
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicSortStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
unsigned int i, ixj; /* Sorting partners: i and ixj */
i = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0)
xLength = shape::length(xShapeInfo);
__syncthreads();
if (i >= length)
return;
ixj = i^j;
/* The threads with the lowest ids sort the array. */
if ((ixj)>i) {
int posI = shape::getIndexOffset(i, yShapeInfo, xLength);
int posIXJ = shape::getIndexOffset(ixj, yShapeInfo, xLength);
if ((i&k)==0) {
/* Sort ascending */
if (!descending == (y[posI]>y[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
} else if ((i&k)!=0) {
/* Sort descending */
if (!descending == (y[posI]<y[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicSortStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
unsigned int i, ixj; /* Sorting partners: i and ixj */
i = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0)
xLength = shape::length(xShapeInfo);
__syncthreads();
if (i >= length)
return;
ixj = i^j;
/* The threads with the lowest ids sort the array. */
if ((ixj)>i) {
int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);
if ((i&k)==0) {
/* Sort ascending */
if (!descending == (x[posI]>x[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
} else if ((i&k)!=0) {
/* Sort descending */
if (!descending == (x[posI]<x[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
}
}
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) { __global__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
auto x = static_cast<T*>(vx); auto x = static_cast<T*>(vx);
@ -44,8 +154,8 @@ __device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int
/* The threads with the lowest ids sort the array. */ /* The threads with the lowest ids sort the array. */
if ((ixj)>i) { if ((ixj)>i) {
int posI = getDevicePosition(xShapeInfo, i, xLength); int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
int posIXJ = getDevicePosition(xShapeInfo, ixj, xLength); int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);
if ((i&k)==0) { if ((i&k)==0) {
/* Sort ascending */ /* Sort ascending */
@ -69,16 +179,23 @@ __device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__global__ void execBitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) { __host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
bitonicSortStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, j, k, length, descending);
bitonicSortStepKernel<T>(vx, xShapeInfo, j, k, length, descending);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename T> template <typename X, typename Y>
__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) { __host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
bitonicSortStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
execBitonicSortStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, j, k, length, descending);
nd4j::DebugHelper::checkErrorCode(stream, "bitonicSortStep(...) failed");
} }
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
bitonicSortStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
}
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);

View File

@ -16,15 +16,86 @@
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// @author Yurii Shyrma, created on 28.11.2018
// //
#include <ops/specials_cuda.h> #include <ops/specials_cuda.h>
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void execOesTadKernelKey(void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
__shared__ int xLength;
__shared__ int xTadLength;
__shared__ int numTads;
if (threadIdx.x == 0) {
xLength = shape::length(xShapeInfo);
xTadLength = shape::length(tadShapeInfo);
numTads = xLength / xTadLength;
}
__syncthreads();
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
auto dx = x + tadOffsets[r];
auto dy = y + tadOffsets[r];
// this is general loop, we go uncached
int iterations = xTadLength;
for (int i = 0; i < iterations; i++) {
if (i % 2 == 0) {
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto top = 2 * tid + 1;
if (top < xTadLength) {
auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);
if (!descending == (dx[t0] > dx[t1])) {
X dt0 = dx[t0];
dx[t0] = dx[t1];
dx[t1] = dt0;
Y dy0 = dy[t0];
dy[t0] = dy[t1];
dy[t1] = dy0;
}
}
}
} else {
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto top = 2 * tid + 2;
if (top < xTadLength) {
auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);
if (!descending == (dx[t0] > dx[t1])) {
X dt0 = dx[t0];
dx[t0] = dx[t1];
dx[t1] = dt0;
Y dy0 = dy[t0];
dy[t0] = dy[t1];
dy[t1] = dy0;
}
}
}
}
__syncthreads();
}
}
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__device__ __global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) { bool descending) {
@ -56,7 +127,7 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
int iterations = xTadLength; int iterations = xTadLength;
if (cached) { if (cached) {
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto t0 = getDevicePosition(tadShapeInfo, tid, xTadLength); auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
shmem[tid] = dx[t0]; shmem[tid] = dx[t0];
} }
@ -70,8 +141,8 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto top = 2 * tid + 1; auto top = 2 * tid + 1;
if (top < xTadLength) { if (top < xTadLength) {
auto t0 = cached ? top - 1 : getDevicePosition(tadShapeInfo, top - 1, xTadLength); auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
auto t1 = cached ? top : getDevicePosition(tadShapeInfo, top, xTadLength); auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);
if (!descending == (dx[t0] > dx[t1])) { if (!descending == (dx[t0] > dx[t1])) {
T dt0 = dx[t0]; T dt0 = dx[t0];
@ -84,8 +155,8 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto top = 2 * tid + 2; auto top = 2 * tid + 2;
if (top < xTadLength) { if (top < xTadLength) {
auto t0 = cached ? top - 1 : getDevicePosition(tadShapeInfo, top - 1, xTadLength); auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
auto t1 = cached ? top : getDevicePosition(tadShapeInfo, top, xTadLength); auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);
if (!descending == (dx[t0] > dx[t1])) { if (!descending == (dx[t0] > dx[t1])) {
T dt0 = dx[t0]; T dt0 = dx[t0];
@ -102,23 +173,13 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
if (cached) { if (cached) {
dx = x + tadOffsets[r]; dx = x + tadOffsets[r];
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) { for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto t0 = getDevicePosition(tadShapeInfo, tid, xTadLength); auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
dx[t0] = shmem[tid]; dx[t0] = shmem[tid];
} }
} }
} }
} }
//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) {
oesTadKernel<T>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream, __host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream,
@ -128,6 +189,18 @@ __host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream,
bool descending) { bool descending) {
execOesTadKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending); execOesTadKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
nd4j::DebugHelper::checkErrorCode(stream, "oesTad(...) failed");
} }
template <typename X, typename Y>
__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream,
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) {
execOesTadKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
}
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT oesTadGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);

View File

@ -65,13 +65,7 @@ CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) {
REQUIRE_TRUE(product == alphaLen, 0, "PRELU OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str()); REQUIRE_TRUE(product == alphaLen, 0, "PRELU OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
// ***** end of validation ***** // // ***** end of validation ***** //
if(alphaShape != expectedAlphaShape) helpers::prelu(block.launchContext(), *input, alphaShape != expectedAlphaShape ? alpha->reshape(alpha->ordering(), expectedAlphaShape) : *alpha, *output);
alpha = alpha->reshape(alpha->ordering(), expectedAlphaShape);
helpers::prelu(block.launchContext(), *input, *alpha, *output);
if(alphaShape != expectedAlphaShape)
delete alpha;
return Status::OK(); return Status::OK();
} }
@ -128,9 +122,10 @@ CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) {
REQUIRE_TRUE(product == alphaLen, 0, "PRELU_BP OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str()); REQUIRE_TRUE(product == alphaLen, 0, "PRELU_BP OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
// ***** end of validation ***** // // ***** end of validation ***** //
if(alphaShape != expectedAlphaShape) { if(alphaShape != expectedAlphaShape) {
alpha = alpha->reshape(alpha->ordering(), expectedAlphaShape); alpha = new NDArray(alpha->reshape(alpha->ordering(), expectedAlphaShape));
dLdA = dLdA->reshape(dLdA->ordering(), expectedAlphaShape); dLdA = new NDArray(dLdA->reshape(dLdA->ordering(), expectedAlphaShape));
} }
helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA); helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA);

View File

@ -29,7 +29,6 @@ namespace nd4j {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
nd4j_printf("Comparing [%f] to [%f]\n", x->e<float>(0), y->e<float>(0));
if (x->e<float>(0) < y->e<float>(0)) if (x->e<float>(0) < y->e<float>(0))
return ND4J_STATUS_TRUE; return ND4J_STATUS_TRUE;
else else

View File

@ -31,7 +31,7 @@ namespace nd4j {
auto condition = INPUT_VARIABLE(0); auto condition = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
if (z->isEmpty()) if (z->isEmpty())
return ND4J_STATUS_OK; return Status::OK();
if (block.width() == 3) { if (block.width() == 3) {
auto x = INPUT_VARIABLE(1); auto x = INPUT_VARIABLE(1);
@ -44,12 +44,10 @@ namespace nd4j {
// FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y // FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y
for (int e = 0; e < condition->lengthOf(); e++) { for (int e = 0; e < condition->lengthOf(); e++) {
if (y->isR()) { if (y->isR()) {
auto r = !condition->e<bool>(e) ? y->e<double>(e) auto r = !condition->e<bool>(e) ? y->e<double>(e) : x->e<double>(e);
: x->e<double>(e);
z->p(e, r); z->p(e, r);
} else { } else {
auto r = !condition->e<bool>(e) ? y->e<Nd4jLong>(e) auto r = !condition->e<bool>(e) ? y->e<Nd4jLong>(e) : x->e<Nd4jLong>(e);
: x->e<Nd4jLong>(e);
z->p(e, r); z->p(e, r);
} }
} }
@ -86,7 +84,7 @@ namespace nd4j {
helpers::_where(block.launchContext(), *condition, *output, block.workspace()); helpers::_where(block.launchContext(), *condition, *output, block.workspace());
} }
return ND4J_STATUS_OK; return Status::OK();
} }
DECLARE_SHAPE_FN(Where) { DECLARE_SHAPE_FN(Where) {

View File

@ -120,7 +120,7 @@ namespace nd4j {
} }
} }
return ND4J_STATUS_OK; return Status::OK();
} }
DECLARE_SHAPE_FN(where_np) { DECLARE_SHAPE_FN(where_np) {

View File

@ -81,11 +81,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput); auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
ConvolutionUtils::conv2d(block, inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
delete inputReshaped;
delete outputReshaped;
delete weightsReshaped;
return Status::OK(); return Status::OK();
} }
@ -217,13 +213,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
ConvolutionUtils::conv2dBP(block, inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
delete inputReshaped;
delete gradIReshaped;
delete gradOReshaped;
delete weightsReshaped;
delete gradWReshaped;
return Status::OK(); return Status::OK();
} }

View File

@ -151,10 +151,10 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
std::vector<int> permutForOutput; std::vector<int> permutForOutput;
if(!isNCDHW) if (isNCDHW)
input = input->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
else
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
else
input = new NDArray(input->permute({0,4,1,2,3}));
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
@ -447,21 +447,23 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
std::vector<int> gradOaxesForDot; std::vector<int> gradOaxesForDot;
if(!isNDHWC) { if(!isNDHWC) {
input = input->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = gradI->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradOaxesForDot = {0,1,2,3}; // bS, oD, oH, oW gradOaxesForDot = {0,1,2,3}; // bS, oD, oH, oW
input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
} }
else else {
gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW
}
// ----- calculation of gradW and gradB ----- // // ----- calculation of gradW and gradB ----- //
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC] MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
//----- calculation of gradO -----//
if(gradB) { if(gradB) {
if(gradB->rankOf() == 2) if(gradB->rankOf() == 2)
gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}); gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot); // sum over bS oD oH oW gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot); // sum over bS oD oH oW
if(gradB != OUTPUT_VARIABLE(2)) if(gradB != OUTPUT_VARIABLE(2))
delete gradB; delete gradB;

View File

@ -64,7 +64,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(!isNCHW) if(!isNCHW)
output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
if(isSameMode) // SAME if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@ -211,8 +211,9 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
// -----prepare permutation arrays and axes for dot product ----- // // -----prepare permutation arrays and axes for dot product ----- //
std::vector<int> inputAxesForDot; std::vector<int> inputAxesForDot;
if(!isNCHW) { if(!isNCHW) {
gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
inputAxesForDot = {0, 1, 2}; // bS, iH, iW inputAxesForDot = {0, 1, 2}; // bS, iH, iW
} }
else else
@ -228,7 +229,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
// ----- calculation of gradB ----- // // ----- calculation of gradB ----- //
if(gradB) { if(gradB) {
if(gradB->rankOf() == 2) if(gradB->rankOf() == 2)
gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}); gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3}); // sum over bS, oH, oW gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3}); // sum over bS, oH, oW
if(gradB != OUTPUT_VARIABLE(2)) if(gradB != OUTPUT_VARIABLE(2))
delete gradB; delete gradB;
@ -237,7 +238,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
if(!isNCHW) if(!isNCHW)
delete gradO; delete gradO;
return ND4J_STATUS_OK; return Status::OK();
} }
DECLARE_SHAPE_FN(deconv2d_bp) { DECLARE_SHAPE_FN(deconv2d_bp) {

View File

@ -64,7 +64,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(!isNCDHW) if(!isNCDHW)
output = output->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
if(isSameMode) // SAME if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
@ -225,8 +225,9 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// -----prepare permutation arrays and axes for dot product ----- // // -----prepare permutation arrays and axes for dot product ----- //
std::vector<int> inputAxesForDot; std::vector<int> inputAxesForDot;
if(!isNCDHW) { if(!isNCDHW) {
gradO = gradO->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
inputAxesForDot = {0, 1, 2, 3}; // bS, iD, iH, iW inputAxesForDot = {0, 1, 2, 3}; // bS, iD, iH, iW
} }
else else
@ -240,7 +241,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradB ----- // // ----- calculation of gradB ----- //
if(gradB) { if(gradB) {
if(gradB->rankOf() == 2) if(gradB->rankOf() == 2)
gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}); gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
if(gradB != OUTPUT_VARIABLE(2)) if(gradB != OUTPUT_VARIABLE(2))
delete gradB; delete gradB;

View File

@ -71,7 +71,7 @@ namespace ops {
int pad_top = 0, pad_left = 0; int pad_top = 0, pad_left = 0;
int out_rows = 0, out_cols = 0; int out_rows = 0, out_cols = 0;
helpers::_dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols); helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
REQUIRE_TRUE(out_rows > 0 && out_cols > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", out_rows, out_cols); REQUIRE_TRUE(out_rows > 0 && out_cols > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", out_rows, out_cols);
@ -126,7 +126,7 @@ namespace ops {
int pad_top = 0, pad_left = 0; int pad_top = 0, pad_left = 0;
int out_rows = 0, out_cols = 0; int out_rows = 0, out_cols = 0;
helpers::_dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols); helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
std::array<Nd4jLong, 4> shape = {{batch_size, out_rows, out_cols, depth}}; std::array<Nd4jLong, 4> shape = {{batch_size, out_rows, out_cols, depth}};
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data()); newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data());

View File

@ -60,8 +60,8 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
if(!isNCHW) { if(!isNCHW) {
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
} }
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
@ -71,7 +71,6 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0); ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
//output->printBuffer("output op");
if(!isNCHW) { if(!isNCHW) {
delete input; delete input;
@ -177,10 +176,11 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCHW) { if(!isNCHW) {
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
} }
if(isSameMode) // SAME if(isSameMode) // SAME
@ -205,9 +205,6 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
delete gradI; delete gradI;
delete gradO; delete gradO;
} }
// delete columns;
// delete columns2d;
// delete gradOVector;
return Status::OK(); return Status::OK();

View File

@ -61,8 +61,8 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
if(!isNCDHW) { if(!isNCDHW) {
input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
output = output->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
} }
if(isSameMode) // SAME if(isSameMode) // SAME
@ -180,9 +180,9 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCDHW) { if(!isNCDHW) {
input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = gradI->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradO = gradO->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
} }
if(isSameMode) // SAME if(isSameMode) // SAME

View File

@ -60,8 +60,8 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2); const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2);
if(!isNCHW) { if(!isNCHW) {
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
} }
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
@ -175,9 +175,9 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCHW) { if(!isNCHW) {
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
} }
if(isSameMode) // SAME if(isSameMode) // SAME
@ -203,9 +203,6 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
delete gradI; delete gradI;
delete gradO; delete gradO;
} }
// delete columns;
// delete columns2d;
// delete gradOVector;
return Status::OK(); return Status::OK();
} }

View File

@ -63,8 +63,8 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW); // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
if(!isNCDHW) { if(!isNCDHW) {
input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
output = output->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
} }
if(isSameMode) // SAME if(isSameMode) // SAME
@ -182,9 +182,9 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCDHW) { if(!isNCDHW) {
input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = gradI->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradO = gradO->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
} }
if(isSameMode) // SAME if(isSameMode) // SAME
@ -211,9 +211,6 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
delete gradI; delete gradI;
delete gradO; delete gradO;
} }
// delete columns;
// delete columns2d;
// delete gradOVector;
return Status::OK(); return Status::OK();
} }

View File

@ -55,8 +55,8 @@ namespace nd4j {
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW
if(!isNCHW) { if(!isNCHW) {
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
} }
const auto inY = static_cast<int>(input->sizeAt(2)); const auto inY = static_cast<int>(input->sizeAt(2));
@ -175,9 +175,9 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCHW) { if(!isNCHW) {
input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
} }
// if(isSameMode) // SAME // if(isSameMode) // SAME
@ -216,10 +216,6 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
delete gradI; delete gradI;
delete gradO; delete gradO;
} }
// delete columns;
// delete columns2d;
// delete gradOVector;
// delete denomVec;
return Status::OK(); return Status::OK();
} }

View File

@ -78,7 +78,7 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
auto weightsBroad = weights; auto weightsBroad = weights;
if(!weights->isScalar() && !weights->isSameShape(&E)) { if(!weights->isScalar() && !weights->isSameShape(&E)) {
if(E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1) if(E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1)
weightsBroad = weights->reshape(weights->ordering(), {weights->lengthOf()}); weightsBroad = new NDArray(weights->reshape(weights->ordering(), {weights->lengthOf()}));
else else
weightsBroad = new NDArray(weights->tileToShape(E.getShapeInfo())); weightsBroad = new NDArray(weights->tileToShape(E.getShapeInfo()));
} }

View File

@ -74,7 +74,7 @@ namespace ops {
} }
if(mask != nullptr){ if(mask != nullptr){
NDArray* reshapedMask; NDArray reshapedMask;
if(weights->rankOf() == 4){ if(weights->rankOf() == 4){
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
}else{ }else{
@ -87,8 +87,7 @@ namespace ops {
// before going through the softmax, we effectively push all masked positions to zero after softmax. // before going through the softmax, we effectively push all masked positions to zero after softmax.
// //
// we are using 1e9 to mean effectively infinity // we are using 1e9 to mean effectively infinity
*weights += (*reshapedMask - 1) * 1e9; *weights += (reshapedMask - 1) * 1e9;
delete reshapedMask;
} }
nd4j::ops::softmax softmax; nd4j::ops::softmax softmax;
@ -175,14 +174,13 @@ namespace ops {
preSoftmax /= factor; preSoftmax /= factor;
if(mask != nullptr){ if(mask != nullptr){
NDArray* reshapedMask; NDArray reshapedMask;
if(preSoftmax.rankOf() == 4){ if(preSoftmax.rankOf() == 4){
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1}); reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
}else{ }else{
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1}); reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1});
} }
preSoftmax += (*reshapedMask - 1) * 1e9; preSoftmax += (reshapedMask - 1) * 1e9;
delete reshapedMask;
} }
NDArray weights('c', weightShape, values->dataType(), block.launchContext()); NDArray weights('c', weightShape, values->dataType(), block.launchContext());

View File

@ -70,7 +70,7 @@ namespace nd4j {
float beta = T_ARG(2); float beta = T_ARG(2);
int depth = INT_ARG(0); int depth = INT_ARG(0);
helpers::lrnBP(*input, *gradO, *gradI, depth, bias, alpha, beta); helpers::lrnBP(block, *input, *gradO, *gradI, depth, bias, alpha, beta);
return Status::OK(); return Status::OK();
} }

View File

@ -98,9 +98,9 @@ namespace ops {
auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
// Apply Attention // Apply Attention
NDArray attnResults('c', {projectedQueries->sizeAt(0), projectedValues->sizeAt(1), projectedValues->sizeAt(2), projectedQueries->sizeAt(3)}, projectedValues->dataType(), block.launchContext()); NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
nd4j::ops::dot_product_attention attention; nd4j::ops::dot_product_attention attention;
attention.execute({projectedQueries, projectedKeys, projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {}); attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {});
// Project attention results // Project attention results
attnResults.permutei({0, 3, 1, 2}); attnResults.permutei({0, 3, 1, 2});
@ -111,11 +111,9 @@ namespace ops {
mmul.execute({&attnResults, Wo},{&projRes}, {}, {}, {}); mmul.execute({&attnResults, Wo},{&projRes}, {}, {}, {});
projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize}); projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize});
projRes.permutei({0, 2, 1}); projRes.permutei({0, 2, 1});
output->assign(projRes);
delete projectedQueries; // FIXME: bad for performance
delete projectedKeys; output->assign(projRes);
delete projectedValues;
return Status::OK(); return Status::OK();
} }
@ -227,9 +225,9 @@ namespace ops {
auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
// Apply Attention // Apply Attention
NDArray attnResults('c', {projectedQueries->sizeAt(0), projectedValues->sizeAt(1), projectedValues->sizeAt(2), projectedQueries->sizeAt(3)}, projectedValues->dataType(), block.launchContext()); NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
nd4j::ops::dot_product_attention attention; nd4j::ops::dot_product_attention attention;
attention.execute({projectedQueries, projectedKeys, projectedValues, mask}, {&attnResults}, {}, {normalization, 0}, {}); attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults}, {}, {normalization, 0}, {});
// Project attention results // Project attention results
attnResults.permutei({0, 3, 1, 2}); attnResults.permutei({0, 3, 1, 2});
@ -237,31 +235,25 @@ namespace ops {
// dLdWo // dLdWo
auto epsPerm = eps->permute({0, 2, 1}); auto epsPerm = eps->permute({0, 2, 1});
auto epsPostReshape = epsPerm->reshape(eps->ordering(), {miniBatchSize * queryCount, outSize}); auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
nd4j::ops::matmul_bp matmulBp; nd4j::ops::matmul_bp matmulBp;
NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext()); NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext());
matmulBp.execute({&attnResults, Wo, epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {}); matmulBp.execute({&attnResults, Wo, &epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {});
// dLdAttn // dLdAttn
dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues->sizeAt(2)}); dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)});
dLdPreWo.permutei({0, 2, 3, 1}); dLdPreWo.permutei({0, 2, 3, 1});
nd4j::ops::dot_product_attention_bp attentionBp; nd4j::ops::dot_product_attention_bp attentionBp;
NDArray dLdProjectedQueries(projectedQueries->shapeInfo(), false, block.launchContext()); NDArray dLdProjectedQueries(projectedQueries.shapeInfo(), false, block.launchContext());
NDArray dLdProjectedKeys(projectedKeys->shapeInfo(), false, block.launchContext()); NDArray dLdProjectedKeys(projectedKeys.shapeInfo(), false, block.launchContext());
NDArray dLdProjectedValues(projectedValues->shapeInfo(), false, block.launchContext()); NDArray dLdProjectedValues(projectedValues.shapeInfo(), false, block.launchContext());
attentionBp.execute({projectedQueries, projectedKeys, projectedValues, &dLdPreWo, mask},{&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, {normalization}, {}); attentionBp.execute({&projectedQueries, &projectedKeys, &projectedValues, &dLdPreWo, mask},{&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, {normalization}, {});
AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, dLdWq, block.launchContext()); AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, dLdWq, block.launchContext());
AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, block.launchContext()); AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, block.launchContext());
AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, dLdWv, block.launchContext()); AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, dLdWv, block.launchContext());
delete projectedQueries;
delete projectedKeys;
delete projectedValues;
delete epsPerm;
delete epsPostReshape;
return Status::OK(); return Status::OK();
} }

View File

@ -51,7 +51,7 @@ CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) {
REQUIRE_TRUE(0.f <= x->e<float>(i) && x->e<float>(i) <= 1.f, 0, "BETAINC op: all elements of x array must be within [0, 1] range!"); REQUIRE_TRUE(0.f <= x->e<float>(i) && x->e<float>(i) <= 1.f, 0, "BETAINC op: all elements of x array must be within [0, 1] range!");
} }
*output = helpers::betaInc(block.launchContext(), *a, *b, *x); helpers::betaInc(block.launchContext(), *a, *b, *x, *output);
return Status::OK(); return Status::OK();
} }

View File

@ -48,10 +48,7 @@ namespace nd4j {
//nd4j_debug("Reshaping to: [%i, %i]\n", -1, (int) bias->lengthOf()); //nd4j_debug("Reshaping to: [%i, %i]\n", -1, (int) bias->lengthOf());
auto tArr = input->reshape(input->ordering(), shape); auto tArr = input->reshape(input->ordering(), shape);
auto zArr = z->reshape(z->ordering(), shape); auto zArr = z->reshape(z->ordering(), shape);
tArr->addRowVector(bias, zArr); tArr.addRowVector(bias, &zArr);
delete tArr;
delete zArr;
} }
STORE_RESULT(*z); STORE_RESULT(*z);
@ -87,13 +84,12 @@ namespace nd4j {
// cnn case // cnn case
if (input->rankOf() == 4) { if (input->rankOf() == 4) {
auto epsilonNext2d = epsilonNext->permute({1, 0, 2, 3}); auto epsilonNext2d = epsilonNext->permute({1, 0, 2, 3});
epsilonNext2d->reshapei('c', {(int) bias->lengthOf(), -1}); epsilonNext2d.reshapei('c', {(int) bias->lengthOf(), -1});
auto sum = epsilonNext2d->reduceAlongDimension(reduce::Sum, {1}); auto sum = epsilonNext2d.reduceAlongDimension(reduce::Sum, {1});
gradB->assign(sum); gradB->assign(sum);
delete sum; delete sum;
delete epsilonNext2d;
} else if (input->rankOf() == 2) { } else if (input->rankOf() == 2) {
// regular fully-connected case // regular fully-connected case
auto sum = epsilonNext->reduceAlongDimension(reduce::Sum, {0}); auto sum = epsilonNext->reduceAlongDimension(reduce::Sum, {0});

View File

@ -0,0 +1,56 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_check_numerics)
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(check_numerics, 2, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto message = INPUT_VARIABLE(1);
auto output = OUTPUT_VARIABLE(0);
auto allFinite = input->reduceNumber(reduce::BoolOps::IsFinite);
REQUIRE_TRUE(allFinite.e<bool>(0), 0, "CheckNumerics: %s", message->e<std::string>(0).c_str());
if (!block.isInplace())
output->assign(input);
return Status::OK();
}
DECLARE_SHAPE_FN(check_numerics) {
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0))));
}
DECLARE_TYPES(check_numerics) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, nd4j::DataType::UTF8)
->setAllowedOutputTypes({ALL_FLOATS});
}
}
}
#endif

View File

@ -56,7 +56,7 @@ namespace nd4j {
} }
DECLARE_SHAPE_FN(crop_and_resize) { DECLARE_SHAPE_FN(crop_and_resize) {
auto in = inputShape->at(0); auto in = inputShape->at(1);
Nd4jLong outputShape[4]; Nd4jLong outputShape[4];
@ -77,8 +77,13 @@ namespace nd4j {
} }
DECLARE_TYPES(crop_and_resize) { DECLARE_TYPES(crop_and_resize) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS}); // ->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(1, {FLOAT32}) // as TF
->setAllowedInputTypes(2, {ALL_INTS})
->setAllowedInputTypes(3, {ALL_INTS})
->setAllowedOutputTypes({FLOAT32}); // as TF
// ->setAllowedOutputTypes({ALL_FLOATS});
} }
} }
} }

View File

@ -47,9 +47,9 @@ namespace ops {
auto o = OUTPUT_VARIABLE(0); auto o = OUTPUT_VARIABLE(0);
if (a->lengthOf() == 3) { if (a->lengthOf() == 3) {
helpers::_cross(block.launchContext(), a, b, o); helpers::cross(block.launchContext(), a, b, o);
} else { } else {
helpers::_crossBatched(block.launchContext(), a, b, o); helpers::crossBatched(block.launchContext(), a, b, o);
} }
return Status::OK(); return Status::OK();

Some files were not shown because too many files have changed in this diff Show More