Merge master to upstream (#7945)

* Shugeo strided slice zeros (#14)

* Modified strided_slice op to properly work with empty-like shapes.

* Fixed test for reduce_mean with empty-like input.

* [WIP] Last merge (#15)

* correct logsoftmax loss (#2)

* Small SameDiff listener fix (#4)

* Various fixes (#6)

* #7839 Fix for asXMatrix and tests

* #7866 EmbeddingSequenceLayer dtype fix + test

* #7856 SameDiff save/load stream methods

* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration

* EvaluationBinary 3d/4d

* More evaluation 3d/4d tests

* #7847 Evaluation empty checks

* Small test fix

* #7848 Fix median edge case

* Improve DL4J samediff layer tests

* [WIP] FastText wrapper implemented (#8)

* FastText implemented

* Some fixes

* Fix shapes for wordsNearest

* Validation of input vectors

* Fixes

* Fixed test

* Thread tagged

* Some tweaks

* setContextClassLoader for DeallocatorServiceThread
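
A minimal sketch of the idea behind this fix (the class and method names here are illustrative, not the project's actual code): a long-lived service thread should have its context class loader set explicitly before it starts, otherwise class lookup from that thread can fail once the creating context goes away.

```java
// Illustrative sketch only - not the actual DeallocatorServiceThread code.
public class ServiceThreadSketch {
    static Thread startServiceThread(Runnable task) {
        Thread t = new Thread(task, "DeallocatorServiceThread-sketch");
        t.setDaemon(true);
        // The fix named above: pin a known-good class loader before start(),
        // instead of inheriting whatever the creating thread happened to have.
        t.setContextClassLoader(ServiceThreadSketch.class.getClassLoader());
        t.start();
        return t;
    }

    public static void main(String[] args) throws InterruptedException {
        Thread t = startServiceThread(() -> { });
        t.join();
        System.out.println(t.getContextClassLoader() != null);
    }
}
```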

* Numpy format tests (#1)

* Various fixes (#11)

* #7852 SameDiff gather fix

* #7892 SameDiff placeholder to constant conversion

* #7890 validate input rank for MLN/CG init methods

* Fix broken permute shape calculation

* Permute and gather fixes

* Tests

* #7850 LogSumExp fix + test

* Handful of test fixes

* Empty arrays with non-scalar shapes (#10)

* minor rearrangements for lambdas

* empty tensors with non-scalar shapes

* numpy empty tensors with non-scalar shapes

* few more empty tweaks

* Small fixes

* conv3d signature update

* micro fix in batchnorm mkldnn

* Import fixes

* Fix

* MKL-DNN update

* Small fill fix

* fill with empty input + test

* Fixes

* Small error improvement

* Fix

* one special test

* couple of fixes for lstm

* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone

* Fixes

* FP16

* Unsigned

* BFloat16

* Fill op - empty tweaks

* - couple of fixes for empty arrays construction
- stack updated

* strided slice fix

* one transform test

* provide method for reducing shapeInfo when the input array is empty

* Fixed reduceAlongDimensions to use empty input properly.

* couple of broadcast tests

* couple more broadcast tests + tweak to make them pass

* add check of non-empty to methods producing sub-arrays

* Fixed reshapeC with zeros in shape.

* complete empty check in reduce_... legacy ops

* Concat and cumsum/prod

* Tweak to empty shape inference on import

* add empty check to the rest of reduce legacy ops

* one more test

* correct typo in evalReduceShapeInfoEmpty

* Added tests for reduce_* ops to tests with zero shapes.

* few more tests for empty reductions

* Fixed strided_slice op with empty case and tests.

* one more empty reduction test

* Fixed strided_slice test.

* add empty check to NDArray::reshapei

* infOrMax

* empty min/max with infinity tests

* made unstack work correctly with empty arrays

* few IndexReduce tests + tweaks for empty shapes

* add test for empty concat

* few tests fixed

* Validation fix for reductions on empty shapes

* Reverse fix

* Reduction shape calc fixes

* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs

* Range fix

* - NDArray constructor updated for scalars/empty arrays
- few tests fixed

* More fixes

* Empty creator fixes

* concat fix

* concat fix

* TF import tests: allow 'both all NaN' and 'both all inf' to pass

* Slice, zero fraction, and reshape fixes

* transpose, gather

* Zero fraction

* scalar cast fix

* Empty reduction axis support

* few more tests fixed

* Fixed input checks for concat op to conform with TF, and tests.

* few tests fixed

* matmul scalar shape fix

* Fixed checks for data type and scalarity with concat to allow non-empty scalars with vector concats.

* broadcast bool fix

* few more tests

* few more tests

* correct evalReduceShapeInfoEmpty

* argmax/argmin + tests

* one more empty edge case + one more test

* argmax/argmin/realdiv_bp tweaks

* empty reshape test + fix

* Helper fixes

* Small fixes

* Gather test fix

* Gather test fix

* Small fixes

* reduce scalar zero values

* scalar mean workaround

* Remove debug code

* along dim mean workaround

* one more test

* - equalsTo() tweak for empty arrays
- one more test

* broadcast tweaks

* [WIP] Fixing outstanding issues for NLP (#9)

* Avoid using uninitialized objects

* Test fixed.

* Redundant method avoided for models like FastText

* KMeans++ implementation

* KMeans++ implementation
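
For context, k-means++ seeding (the algorithm named above) picks the first center uniformly at random, then each subsequent center with probability proportional to the squared distance from the nearest already-chosen center. A self-contained sketch of the technique, not the project's implementation:

```java
import java.util.*;

// Illustrative k-means++ seeding sketch - not the DL4J code.
public class KMeansPlusPlusSketch {
    static double[][] seed(double[][] points, int k, Random rng) {
        List<double[]> centers = new ArrayList<>();
        centers.add(points[rng.nextInt(points.length)]);
        while (centers.size() < k) {
            // d2[i] = squared distance from point i to its nearest center.
            double[] d2 = new double[points.length];
            double total = 0;
            for (int i = 0; i < points.length; i++) {
                double best = Double.MAX_VALUE;
                for (double[] c : centers)
                    best = Math.min(best, sqDist(points[i], c));
                d2[i] = best;
                total += best;
            }
            // Sample the next center proportional to d2.
            double r = rng.nextDouble() * total, acc = 0;
            int chosen = points.length - 1;
            for (int i = 0; i < points.length; i++) {
                acc += d2[i];
                if (acc >= r) { chosen = i; break; }
            }
            centers.add(points[chosen]);
        }
        return centers.toArray(new double[0][]);
    }

    static double sqDist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += (a[i] - b[i]) * (a[i] - b[i]);
        return s;
    }

    public static void main(String[] args) {
        double[][] pts = {{0, 0}, {0, 1}, {10, 10}, {10, 11}, {20, 0}, {21, 0}};
        System.out.println(seed(pts, 3, new Random(42)).length);
    }
}
```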

* Disable parallel execution

* KMeans++

* Tests

* Dev branch merge (#16)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Fix some issues on master (#17)

* Fix DataVec test issue

* Fix issue with dl4j SameDiff output layer

* Dtype fix for lambda layers

* #7912 BertIterator dtype fix (use float32 not global default)

* [WIP] Next set of CUDA stuff (#7)

New CUDA implementations and improvements

* bad file

* Dev branch master merge (#23)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Compatibility of deserialization (#18)

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* SameDiff: add activation gradient checking support for debugging (#19)

* SameDiff gradient checker: first pass on activation gradient checks

* Fixes + tests for activation gradient checking

* Javadoc

* [WIP] Some nd4j data type corrections (#20)

* Adjust data type

* Set correct Data type.

* Size of proper data type.

* fix averaged cpu load (#22)

* SameDiff ops, TF import and fixes (#24)

* CheckNumerics tests + fixes + misc fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fake quant

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* FakeQuantWithMinMaxArgs

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* CheckNumerics fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Exception tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for out of scope stack allocated var use

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ignores

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Ignore for known failing test (already logged issue)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Merge upstream to fork (#25)

* Add thousand-separator commas to TotalParams (#7915)

* Add thousand-separator commas to TotalParams

The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.

* Add thousand-separator commas to MultiLayerNetwork

Corresponding change to MultiLayerNetwork

Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
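
The underlying formatting mechanism is Java's `','` flag to `String.format`; a small sketch (the class name is illustrative, and `Locale.US` is pinned so the grouping separator is a comma regardless of the default locale):

```java
import java.util.Locale;

// Sketch of thousand-separator formatting for large parameter counts.
public class ParamCountFormat {
    static String withCommas(long totalParams) {
        // "%,d" inserts locale-specific grouping separators;
        // Locale.US pins them to commas.
        return String.format(Locale.US, "%,d", totalParams);
    }

    public static void main(String[] args) {
        System.out.println(withCommas(25557032L));
    }
}
```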

* Update contributing and issue/PR templates (#7934)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix link to AdaDelta paper (#7942)

Fix link to AdaDelta paper hosted on matthewzeiler.com

Signed-off-by: Jxtps

* Fixes, and ignores for known/logged failing issues (#7943)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff + DL4J/SameDiff: Multiple fixes (#28)

* #7919 HDF5 attribute buffer length fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7909 Arbiter constructor exception ux improvements

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7925 RNN output layer length checks

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7939 Add listener for validating inputs are not incorrectly modified

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #7939 Integrate NonInplaceValidationListener into tests

* #7844 DL4J SameDiff fixes for variable minibatch size

* DL4J SameDiff fixes - ensure gradient for input placeholder is available

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweaks to ExternalErrorsFunction - use placeholders, make more robust

* Another fix

* More fixes

* More SameDiff/DL4J fixes

* Scope out scalar array creation in BaseScalarOp

* Remove debug code

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] Final dev branch merge (#29)

* SameDiff: convertDataType and gradient check util improvements (#12)

* GradCheck util improvements

* StopGradient constructor + test

* SameDiff: Add datatype conversion

* Javadoc and add DataType.isNumerical()

* Small fix

* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)

* TFGraphTestAllHelper: check intermediates in execution order

* Add missing debug listener

* [WIP] lstmBlock fix + other changes (#13)

- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite

* Small test fix

* CheckNumerics op wrapper

* Compatibility of deserialization (#18)

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* SameDiff: add activation gradient checking support for debugging (#19)

* SameDiff gradient checker: first pass on activation gradient checks

* Fixes + tests for activation gradient checking

* Javadoc

* [WIP] Some nd4j data type corrections (#20)

* Adjust data type

* Set correct Data type.

* Size of proper data type.

* fix averaged cpu load (#22)

* [WIP] Multiple dataset iterators (#27)

* Splitting dataset into an arbitrary number of parts
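
A sketch of the ratio-based splitting used with `DataSetIteratorSplitter` later in this change (e.g. `new double[]{0.5, 0.3, 0.2}` over 1000 examples): fractional ratios are converted to per-part counts, with the last part absorbing rounding error. Illustrative only, not the library's internals:

```java
import java.util.Arrays;

// Convert split ratios over a known total into per-part example counts.
public class SplitRatios {
    static int[] toCounts(int total, double[] ratios) {
        int[] counts = new int[ratios.length];
        int used = 0;
        for (int i = 0; i < ratios.length - 1; i++) {
            counts[i] = (int) (total * ratios[i]);
            used += counts[i];
        }
        // Last part absorbs rounding error so the counts sum to total.
        counts[ratios.length - 1] = total - used;
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(toCounts(1000, new double[]{0.5, 0.3, 0.2})));
    }
}
```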

* Fixes

* Multiple split of iterator

* Test

* Test

* Some fixes

* signature change

* one more tweak

Signed-off-by: raver119 <raver119@gmail.com>

* one more test for sequential use of DataSetIteratorSplitter

Signed-off-by: raver119 <raver119@gmail.com>

* Fixes

* Fixes

* one more test for Alexander

Signed-off-by: raver119 <raver119@gmail.com>

* Some fixes

* Some fixes

* one more test for Alexander

Signed-off-by: raver119 <raver119@gmail.com>

* minor test fix

Signed-off-by: raver119 <raver119@gmail.com>

* Some fixes

* Some fixes

* couple of assertions tweaked

Signed-off-by: raver119 <raver119@gmail.com>

* MDS splitter test :/

Signed-off-by: raver119 <raver119@gmail.com>

* Minor refactoring

* Multi dataset

* Some fixes

* More tests

* Small number of test fixes/improvements (failures on CI) (#31)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* [WIP] More CUDA stuff (#26)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* LRN BP CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* less memory

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed bug with crop_and_resize op helper.

* get rid of unnecessary index-calculation function

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed sort with nth_element cuda-based helper.

* Refactored nth_element.

* Refactored nth_element op and tests.

* Modified usage of dim array with sortTad routine.

* Refactored main routine of helper for non_max_image_suppression op.

* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.

* fix vol2col cuda kernel

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* topK concept

Signed-off-by: raver119 <raver119@gmail.com>

* unsorted topK with scanWidth of 1

Signed-off-by: raver119 <raver119@gmail.com>

* correct vol2col tests

* sorted/unsorted topK

Signed-off-by: raver119 <raver119@gmail.com>

* implementation and fixing col2im/col2vol

* Corrected usage flags for input/output with reverse op.

* dup is const now

Signed-off-by: raver119 <raver119@gmail.com>

* percentile op

Signed-off-by: raver119 <raver119@gmail.com>

* group tests for maxpool2d

Signed-off-by: Yurii <yurii@skymind.io>

* special test for george

Signed-off-by: raver119 <raver119@gmail.com>

* less threads for sortTad

Signed-off-by: raver119 <raver119@gmail.com>

* provide conv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* remove author in sort tad kernel code

Signed-off-by: Yurii <yurii@skymind.io>

* provide depthwise_conv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* - max_pooling_with_argmax
- null check for special use

Signed-off-by: raver119 <raver119@gmail.com>

* dts cuda

Signed-off-by: raver119 <raver119@gmail.com>

* provide sconv2d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* std cuda

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression op to conform to the TF implementation.

* Improved suppression helper.

* provide pooling3d for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* minor lstm rearrangements

Signed-off-by: raver119 <raver119@gmail.com>

* more of minor lstm rearrangements

Signed-off-by: raver119 <raver119@gmail.com>

* (bi)dynamic_rnn

Signed-off-by: raver119 <raver119@gmail.com>

* templates init order

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression op.

* Added cuda kernel for non_max_suppression.

* CPU sort by key/value

Signed-off-by: raver119 <raver119@gmail.com>

* CPU sort TAD by key/value

Signed-off-by: raver119 <raver119@gmail.com>

* CPU sort TAD by key/value tests

Signed-off-by: raver119 <raver119@gmail.com>

* Eliminate compiler error with cuda implementation.

* - repaired gradCheck in cuda
- provide conv2d_bp for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* missed signature

Signed-off-by: raver119 <raver119@gmail.com>

* provide depthwise_conv2d_bp for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of lup helper with cuda kernel. Initial commit.

* further work on backprops for convolutions

Signed-off-by: Yurii <yurii@skymind.io>

* CUDA linear sort by key/val

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA tad sort by key/val

Signed-off-by: raver119 <raver119@gmail.com>

* start providing of backprop for pooling2d/3d

Signed-off-by: Yurii <yurii@skymind.io>

* Added atomicAdd for bool datatype.

* dynamic partition concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic partition concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic partition scalar CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* important comment

Signed-off-by: raver119 <raver119@gmail.com>

* fix pooling2d/3d backprop helpers

Signed-off-by: Yurii <yurii@skymind.io>

* Added non-linear test with dynamic_partition.

* Improved test for dynamic_partition.

* dynamic_partition TAD concept

Signed-off-by: raver119 <raver119@gmail.com>

* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix

Signed-off-by: raver119 <raver119@gmail.com>

* - rewrite cpu code for upsampling2d/3d
- write cuda code for upsampling2d/3d

Signed-off-by: Yurii <yurii@skymind.io>

* dynamic_stitch CUDA vector case

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic_stitch CUDA TAD case concept

Signed-off-by: raver119 <raver119@gmail.com>

* dynamic_stitch CUDA TAD case impl

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for dynamic_stitch 3D-4D cases.

* minor tests tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed type check for dynamic stitch.

* min/max bp

Signed-off-by: raver119 <raver119@gmail.com>

* rewrite code for upsampling2d/3d cpu

Signed-off-by: Yurii <yurii@skymind.io>

* reduce min/max/norm_max bp

Signed-off-by: raver119 <raver119@gmail.com>

* lup implementation. Additional enhancements.

* provide code for upsampling2d/3d backprop

Signed-off-by: Yurii <yurii@skymind.io>

* weightedCrossEntropyWithLogits

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed template math atomicMul for 64bit ints.

* Refactored dynamic_partition_bp op.

* inverseBroadcast fix

Signed-off-by: raver119 <raver119@gmail.com>

* DynamicPartitionBP test datatype fixed.

* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA

Signed-off-by: raver119 <raver119@gmail.com>
master
Alex Black 2019-06-28 01:37:04 +10:00 committed by raver119
parent cae4fc9760
commit 1170827c18
331 changed files with 17959 additions and 7363 deletions

@@ -31,7 +31,7 @@ public class TaskCreatorProvider {
             }
             return c.newInstance();
         } catch (Exception e){
-            throw new RuntimeException("Could not create new instance of task creator class: " + c, e);
+            throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
         }
     }

@@ -83,7 +83,7 @@ public class DataSetIteratorFactoryProvider implements DataProvider {
                     (Class<? extends DataSetIteratorFactory>) Class.forName(value);
             return clazz.newInstance();
         } catch (Exception e) {
-            throw new RuntimeException(e);
+            throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
         }
     }
 }

@@ -79,7 +79,7 @@ public class DataSetIteratorFactoryProvider implements DataProvider {
                     (Class<? extends DataSetIteratorFactory>) Class.forName(value);
             return clazz.newInstance();
         } catch (Exception e) {
-            throw new RuntimeException(e);
+            throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
         }
     }
 }

@@ -54,7 +54,7 @@ public abstract class BaseNetScoreFunction implements ScoreFunction {
                 ds.configure(dataSourceProperties);
             }
         } catch (Exception e){
-            throw new RuntimeException(e);
+            throw new RuntimeException("Error creating DataSource instance - missing no-arg constructor?", e);
         }
         return score(model, ds.testData());
     }

@@ -188,10 +188,15 @@ public class ComputationGraphTaskCreator implements TaskCreator {
         //For DataSetIterator: wraps in a MultiDataSetIterator, hence method can be used for both
         MultiDataSetIterator iterator;
         if(dataSource != null){
-            DataSource dsInstance = dataSource.newInstance();
-            if(dataSourceProperties != null)
-                dsInstance.configure(dataSourceProperties);
-            iterator = ScoreUtil.getMultiIterator(dsInstance.trainData());
+            try {
+                DataSource dsInstance = dataSource.newInstance();
+                if (dataSourceProperties != null)
+                    dsInstance.configure(dataSourceProperties);
+                iterator = ScoreUtil.getMultiIterator(dsInstance.trainData());
+            } catch (Exception e){
+                throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName() +
+                        " - no zero-arg constructor?",e);
+            }
         } else {
             iterator = ScoreUtil.getMultiIterator(dataProvider.trainData(candidate.getDataParameters()));
         }

@@ -190,7 +190,8 @@ public class MultiLayerNetworkTaskCreator implements TaskCreator {
         try{
             dsInstance = dataSource.newInstance();
         } catch (Exception e){
-            throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName());
+            throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName() +
+                    " - no zero-arg constructor?",e);
         }
         if(dataSourceProperties != null)
             dsInstance.configure(dataSourceProperties);

@@ -26,6 +26,7 @@ import org.datavec.api.writable.NDArrayWritable;
 import org.datavec.api.writable.Text;
 import org.datavec.api.writable.Writable;
 import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
@@ -78,14 +79,14 @@ public class TestNDArrayWritableTransforms {
         assertEquals(expColNames, tp.getFinalSchema().getColumnNames());
-        List<Writable> in = Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)),
-                new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0)));
+        List<Writable> in = Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)),
+                new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE)));
         List<Writable> out = tp.execute(in);
         List<Writable> exp =
-                Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)),
-                        new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0)),
-                        new NDArrayWritable(Nd4j.linspace(0, 9, 10).addi(2.0)));
+                Arrays.<Writable>asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)),
+                        new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE)),
+                        new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE, 0, 10, 1).addi(2.0).reshape(1,10)));
         assertEquals(exp, out);
     }

@@ -20,9 +20,15 @@ import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.junit.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
public class DataSetSplitterTests extends BaseDL4JTest {
@Test
@@ -39,7 +45,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
         int gcntTest = 0;
         int global = 0;
         // emulating epochs here
-        for (int e = 0; e < numEpochs; e++){
+        for (int e = 0; e < numEpochs; e++) {
             int cnt = 0;
             while (train.hasNext()) {
                 val data = train.next().getFeatures();
@@ -79,7 +85,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
         int gcntTest = 0;
         int global = 0;
         // emulating epochs here
-        for (int e = 0; e < numEpochs; e++){
+        for (int e = 0; e < numEpochs; e++) {
             int cnt = 0;
             while (train.hasNext()) {
                 val data = train.next().getFeatures();
@@ -117,7 +123,7 @@ public class DataSetSplitterTests extends BaseDL4JTest {
         int gcntTest = 0;
         int global = 0;
         // emulating epochs here
-        for (int e = 0; e < numEpochs; e++){
+        for (int e = 0; e < numEpochs; e++) {
             int cnt = 0;
             while (train.hasNext()) {
                 val data = train.next().getFeatures();
@@ -144,4 +150,245 @@ public class DataSetSplitterTests extends BaseDL4JTest {
assertEquals(1000 * numEpochs, global);
}
@Test
public void testSplitter_4() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, 1000, new double[]{0.5, 0.3, 0.2});
List<DataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
int iterNo = 0;
int perEpoch = 0;
for (val partIterator : iteratorList) {
int cnt = 0;
partIterator.reset();
while (partIterator.hasNext()) {
val data = partIterator.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
(float) perEpoch, data.getFloat(0), 1e-5);
//gcntTrain++;
global++;
cnt++;
++perEpoch;
}
++iterNo;
}
}
assertEquals(1000* numEpochs, global);
}
@Test
public void testSplitter_5() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, new int[]{900, 100});
List<DataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
int iterNo = 0;
int perEpoch = 0;
for (val partIterator : iteratorList) {
partIterator.reset();
while (partIterator.hasNext()) {
int cnt = 0;
val data = partIterator.next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
(float) perEpoch, data.getFloat(0), 1e-5);
//gcntTrain++;
global++;
cnt++;
++perEpoch;
}
++iterNo;
}
}
assertEquals(1000 * numEpochs, global);
}
@Test
public void testSplitter_6() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new DataSetIteratorSplitter(back, new int[]{800, 100, 100});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0);
val testIter = splitter.getIterators().get(1);
val validationIter = splitter.getIterators().get(2);
// we're going to have multiple epochs
int numEpochs = 10;
for (int e = 0; e < numEpochs; e++) {
int globalIter = 0;
trainIter.reset();
testIter.reset();
validationIter.reset();
boolean trained = false;
while (trainIter.hasNext()) {
trained = true;
val ds = trainIter.next();
assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", trained);
assertEquals(800, globalIter);
// test set is used every epoch
boolean tested = false;
//testIter.reset();
while (testIter.hasNext()) {
tested = true;
val ds = testIter.next();
assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", tested);
assertEquals(900, globalIter);
// validation set is used every 5 epochs
if (e % 5 == 0) {
boolean validated = false;
//validationIter.reset();
while (validationIter.hasNext()) {
validated = true;
val ds = validationIter.next();
assertNotNull(ds);
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f);
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", validated);
}
// all 3 iterators have exactly 1000 elements combined
if (e % 5 == 0)
assertEquals(1000, globalIter);
else
assertEquals(900, globalIter);
trainIter.reset();
}
}
@Test
public void testUnorderedSplitter_1() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, new int[]{500, 500});
List<DataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
// Get data from second part, then rewind for the first one.
int cnt = 0;
int partNumber = 1;
while (iteratorList.get(partNumber).hasNext()) {
int farCnt = (1000 / 2) * (partNumber) + cnt;
val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data.getFloat(0), 1e-5);
cnt++;
global++;
}
iteratorList.get(partNumber).reset();
partNumber = 0;
cnt = 0;
while (iteratorList.get(0).hasNext()) {
val data = iteratorList.get(0).next().getFeatures();
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++, data.getFloat(0), 1e-5);
global++;
}
}
}
@Test
public void testUnorderedSplitter_2() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, new int[]{2});
List<DataSetIterator> iteratorList = splitter.getIterators();
for (int partNumber = 0 ; partNumber < iteratorList.size(); ++partNumber) {
int cnt = 0;
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
cnt++;
}
}
}
@Test
public void testUnorderedSplitter_3() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new DataSetIteratorSplitter(back, new int[]{10});
List<DataSetIterator> iteratorList = splitter.getIterators();
Random random = new Random();
int[] indexes = new int[iteratorList.size()];
for (int i = 0; i < indexes.length; ++i) {
indexes[i] = random.nextInt(iteratorList.size());
}
for (int partNumber : indexes) {
int cnt = 0;
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
assertEquals("Train failed on iteration " + cnt, (float) (500*partNumber + cnt), data.getFloat(0), 1e-5);
cnt++;
}
}
}
@Test
public void testUnorderedSplitter_4() {
val back = new DataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new DataSetIteratorSplitter(back, new int[]{80, 10, 5});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0); // 0..79
val testIter = splitter.getIterators().get(1); // 80 ..89
val validationIter = splitter.getIterators().get(2); // 90..94
// we're skipping train/test and go for validation first. we're that crazy, right.
int valCnt = 0;
while (validationIter.hasNext()) {
val ds = validationIter.next();
assertNotNull(ds);
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90, ds.getFeatures().getFloat(0), 1e-5);
valCnt++;
}
assertEquals(5, valCnt);
}
}

@@ -18,11 +18,17 @@ package org.deeplearning4j.datasets.iterator;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.tools.DataSetGenerator;
import org.deeplearning4j.datasets.iterator.tools.MultiDataSetGenerator;
import org.junit.Test;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import static org.junit.Assert.assertEquals;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
/**
*
@@ -150,4 +156,309 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest {
assertEquals(1000 * numEpochs, global);
}
@Test
public void testMultiSplitter_1() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0);
val testIter = splitter.getIterators().get(1);
val validationIter = splitter.getIterators().get(2);
// we're going to have multiple epochs
int numEpochs = 10;
for (int e = 0; e < numEpochs; e++) {
int globalIter = 0;
trainIter.reset();
testIter.reset();
validationIter.reset();
boolean trained = false;
while (trainIter.hasNext()) {
trained = true;
val ds = trainIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", trained);
assertEquals(800, globalIter);
// test set is used every epoch
boolean tested = false;
//testIter.reset();
while (testIter.hasNext()) {
tested = true;
val ds = testIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", tested);
assertEquals(900, globalIter);
// validation set is used every 5 epochs
if (e % 5 == 0) {
boolean validated = false;
//validationIter.reset();
while (validationIter.hasNext()) {
validated = true;
val ds = validationIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", validated);
}
// all 3 iterators have exactly 1000 elements combined
if (e % 5 == 0)
assertEquals(1000, globalIter);
else
assertEquals(900, globalIter);
trainIter.reset();
}
}
@Test
public void testSplitter_5() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{900, 100});
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
int iterNo = 0;
int perEpoch = 0;
for (val partIterator : iteratorList) {
partIterator.reset();
while (partIterator.hasNext()) {
int cnt = 0;
val data = partIterator.next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e,
(float) perEpoch, data[i].getFloat(0), 1e-5);
}
//gcntTrain++;
global++;
cnt++;
++perEpoch;
}
++iterNo;
}
}
assertEquals(1000 * numEpochs, global);
}
@Test
public void testSplitter_6() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{800, 100, 100});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0);
val testIter = splitter.getIterators().get(1);
val validationIter = splitter.getIterators().get(2);
// we're going to have multiple epochs
int numEpochs = 10;
for (int e = 0; e < numEpochs; e++) {
int globalIter = 0;
trainIter.reset();
testIter.reset();
validationIter.reset();
boolean trained = false;
while (trainIter.hasNext()) {
trained = true;
val ds = trainIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", trained);
assertEquals(800, globalIter);
// test set is used every epoch
boolean tested = false;
//testIter.reset();
while (testIter.hasNext()) {
tested = true;
val ds = testIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", tested);
assertEquals(900, globalIter);
// validation set is used every 5 epochs
if (e % 5 == 0) {
boolean validated = false;
//validationIter.reset();
while (validationIter.hasNext()) {
validated = true;
val ds = validationIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Failed at iteration [" + globalIter + "]", (double) globalIter,
ds.getFeatures()[i].getDouble(0), 1e-5f);
}
globalIter++;
}
assertTrue("Failed at epoch [" + e + "]", validated);
}
// all 3 iterators have exactly 1000 elements combined
if (e % 5 == 0)
assertEquals(1000, globalIter);
else
assertEquals(900, globalIter);
trainIter.reset();
}
}
@Test
public void testUnorderedSplitter_1() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{500, 500});
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
val numEpochs = 10;
int global = 0;
// emulating epochs here
for (int e = 0; e < numEpochs; e++) {
// Get data from second part, then rewind for the first one.
int cnt = 0;
int partNumber = 1;
while (iteratorList.get(partNumber).hasNext()) {
int farCnt = (1000 / 2) * (partNumber) + cnt;
val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) farCnt, data[i].getFloat(0), 1e-5);
}
cnt++;
global++;
}
iteratorList.get(partNumber).reset();
partNumber = 0;
cnt = 0;
while (iteratorList.get(0).hasNext()) {
val data = iteratorList.get(0).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt + "; epoch: " + e, (float) cnt++,
data[i].getFloat(0), 1e-5);
}
global++;
}
}
}
@Test
public void testUnorderedSplitter_2() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{2});
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
for (int partNumber = 0 ; partNumber < iteratorList.size(); ++partNumber) {
int cnt = 0;
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt), data[i].getFloat(0), 1e-5);
}
cnt++;
}
}
}
@Test
public void testUnorderedSplitter_3() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{10});
List<MultiDataSetIterator> iteratorList = splitter.getIterators();
Random random = new Random();
int[] indexes = new int[iteratorList.size()];
for (int i = 0; i < indexes.length; ++i) {
indexes[i] = random.nextInt(iteratorList.size());
}
for (int partNumber : indexes) {
int cnt = 0;
while (iteratorList.get(partNumber).hasNext()) {
val data = iteratorList.get(partNumber).next().getFeatures();
for (int i = 0; i < data.length; ++i) {
assertEquals("Train failed on iteration " + cnt, (float) (500 * partNumber + cnt),
data[i].getFloat(0), 1e-5);
}
cnt++;
}
}
}
@Test
public void testUnorderedSplitter_4() {
val back = new MultiDataSetGenerator(1000, new int[]{32, 100}, new int[]{32, 5});
// we're going to mimic train+test+validation split
val splitter = new MultiDataSetIteratorSplitter(back, new int[]{80, 10, 5});
assertEquals(3, splitter.getIterators().size());
val trainIter = splitter.getIterators().get(0); // 0..79
val testIter = splitter.getIterators().get(1); // 80 ..89
val validationIter = splitter.getIterators().get(2); // 90..94
// skip train/test and consume the validation part first, to exercise out-of-order access
int valCnt = 0;
while (validationIter.hasNext()) {
val ds = validationIter.next();
assertNotNull(ds);
for (int i = 0; i < ds.getFeatures().length; ++i) {
assertEquals("Validation failed on iteration " + valCnt, (float) valCnt + 90,
ds.getFeatures()[i].getFloat(0), 1e-5);
}
valCnt++;
}
assertEquals(5, valCnt);
}
}


@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.dropout.TestDropout;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@@ -196,4 +197,43 @@ public class TestRnnLayers extends BaseDL4JTest {
}
}
@Test
public void testMismatchedInputLabelLength(){
for( int i=0; i<2; i++ ){
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
.list()
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
switch (i){
case 0:
lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build());
break;
case 1:
lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build());
break;
default:
throw new RuntimeException();
}
MultiLayerConfiguration conf = lb.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
try{
net.fit(in,l);
} catch (Throwable t){
String msg = t.getMessage();
assertTrue(msg, msg.contains("sequence length") && msg.contains("input") && msg.contains("label"));
}
}
}
}


@@ -249,7 +249,6 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
}
@Test
@Ignore("AB 2019/05/31 - Failing on CI and locally - see issues 7820 and 7657")
public void testCorrectness1() {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
Nd4j.getRandom().setSeed(123);
@@ -270,30 +269,18 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
.useAdaGrad(false).build();
b.fit(data);
System.out.println(b.getData());
/*double[] expectedData = new double[]{15.5392794313924, 19.25226403656672, -5.194955746137196, -31.787679714614757, 48.8674725273665,
24.92775755686273, -22.621939920239065, -29.790772278125395, 19.027362415188914, -16.013800175884274,
-27.454680593309185, 1.2929960811295493, -40.45000061571038, 61.23261682914338, 5.62278768938746,
-28.16665244970911, -20.05502814088798, 12.803274346870865, -24.877262522905497, 45.115883138175874,
21.597495694710616, 18.63254779638783, -4.029728632528419, -0.4596087279592638, -42.35340705500429,
-69.24727547461491, 40.94332685199673, -24.60866142208024, 17.689874972878723, -3.6779759693605314,
-30.91803590368529, 10.645452930824145, 36.58583235020565, -64.74975614289316, -39.364099390585956,
72.54886481127016, -35.30663155696714, 19.37116912936714, -7.790876543092118, 19.6586396288508,
58.1332709511154, -18.49217368496203, -3.5050200971182424, 5.662891294031322, 39.69533295638775,
-15.114610550011662, -32.42366951357609, 17.039297537056537, 42.25610885633673, -2.7013781552769904,
-16.338582630617925, 41.734027526336874, 20.941332646863426, -3.2145240561108244, -45.36033539684912};*/
double[] expectedData = {40.93810899235225, 50.90183660191448, -14.298857560948981, -86.2012232604988, 129.51281793466023,
66.29136854264247, -61.650213611972326, -80.42836756633497, 50.28325210727952, -44.29008119040566,
-74.82748570869279, 2.0170536250746807, -109.21462846594635, 162.3973196127918, 14.000621153511705,
-76.30892822919527, -54.251704596942275, 33.99763310539589, -67.6307009607032, 119.50868525237786,
57.17786598853867, 49.1489174572297, -11.25663463504983, -2.38899196609398, -114.27194947404686,
-185.93832011474473, 108.9022579845252, -66.14099037301474, 47.13683038425694, -10.037893631405792,
-83.88458799629637, 26.985651418254996, 96.68139337135332, -174.2832443285551, -106.0999118697521,
193.02622700008175, -94.88003359113081, 51.39502524568139, -20.96021960048648, 52.32291574424741,
154.33973608321477, -50.90644802585217, -10.345744416395354, 13.721222143380892, 105.2111073677489,
-41.339268919407345, -87.73042354938127, 45.306865238870046, 112.53877133856602, -8.44454352074299,
-44.660828600669056, 110.72662022978719, 55.74660833987147, -9.613556053471232, -122.19953914048916};
double[] expectedData = new double[]{ 63.8206, 80.4013, -19.4424, -140.4326, 198.7239,
106.1148, -96.6273, -124.3634, 78.4174, -83.6621,
-121.8706, 3.0888, -172.8560, 255.1262, 20.7021,
-120.7942, -78.1829, 56.6021, -112.3294, 185.4084,
88.5330, 78.0497, -18.8673, -11.0155, -175.1564,
-297.8463, 174.2511, -103.8793, 72.5455, -15.8498,
-134.5235, 42.3300, 154.0391, -280.1010, -167.9765,
306.9938, -150.9666, 83.4419, -36.0877, 83.9992,
245.1813, -81.5018, -14.8430, 16.1557, 166.8651,
-65.9247, -138.1783, 72.5444, 176.3088, -25.6732,
-69.6843, 167.3360, 87.6238, -18.5874, -187.3806};
INDArray expectedArray = Nd4j.createFromArray(expectedData).reshape(11,5);
for (int i = 0; i < expectedArray.rows(); ++i)


@@ -18,6 +18,7 @@ package org.deeplearning4j.util;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -30,7 +31,7 @@ public class TimeSeriesUtilsTest extends BaseDL4JTest {
@Test
public void testMovingAverage() {
INDArray a = Nd4j.arange(0, 20);
INDArray a = Nd4j.arange(0, 20).castTo(DataType.DOUBLE);
INDArray result = Nd4j.create(new double[] {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f,
12.5f, 13.5f, 14.5f, 15.5f, 16.5f, 17.5f});


@@ -24,6 +24,7 @@ import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
@@ -42,14 +43,20 @@ public class DataSetIteratorSplitter {
protected DataSetIterator backedIterator;
protected final long totalExamples;
protected final double ratio;
protected final double[] ratios;
protected final long numTrain;
protected final long numTest;
protected final long numArbitrarySets;
protected final int[] splits;
protected AtomicLong counter = new AtomicLong(0);
protected AtomicBoolean resetPending = new AtomicBoolean(false);
protected DataSet firstTrain = null;
protected int partNumber = 0;
/**
* The only constructor
*
@@ -71,17 +78,94 @@ public class DataSetIteratorSplitter {
this.backedIterator = baseIterator;
this.totalExamples = totalBatches;
this.ratio = ratio;
this.ratios = null;
this.numTrain = (long) (totalExamples * ratio);
this.numTest = totalExamples - numTrain;
this.numArbitrarySets = 2;
this.splits = null;
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, long totalBatches, double[] ratios) {
for (double ratio : ratios) {
if (!(ratio > 0.0 && ratio < 1.0))
throw new ND4JIllegalStateException("Ratio value should be in range 0.0 < X < 1.0");
}
if (totalBatches <= 0)
throw new ND4JIllegalStateException("totalBatches must be a positive value");
if (!baseIterator.resetSupported())
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
this.backedIterator = baseIterator;
this.totalExamples = totalBatches;
this.ratio = 0.0;
this.ratios = ratios;
this.numTrain = 0; //(long) (totalExamples * ratio);
this.numTest = 0; //totalExamples - numTrain;
this.numArbitrarySets = ratios.length;
this.splits = new int[this.ratios.length];
for (int i = 0; i < this.splits.length; ++i) {
this.splits[i] = (int)(totalExamples * ratios[i]);
}
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public DataSetIteratorSplitter(@NonNull DataSetIterator baseIterator, int[] splits) {
/*if (!(simpleRatio > 0.0 && simpleRatio < 1.0))
throw new ND4JIllegalStateException("Ratio value should be in range of 0.0 > X < 1.0");*/
int totalBatches = 0;
for (val v:splits)
totalBatches += v;
if (totalBatches <= 0)
throw new ND4JIllegalStateException("totalBatches must be a positive value");
if (!baseIterator.resetSupported())
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
this.backedIterator = baseIterator;
this.totalExamples = totalBatches;
this.ratio = 0.0;
this.ratios = null;
this.numTrain = 0; //(long) (totalExamples * ratio);
this.numTest = 0; //totalExamples - numTrain;
this.splits = splits;
this.numArbitrarySets = splits.length;
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public List<DataSetIterator> getIterators() {
List<DataSetIterator> retVal = new ArrayList<>();
int partN = 0;
int bottom = 0;
for (final int split : splits) {
ScrollableDataSetIterator partIterator =
new ScrollableDataSetIterator(partN++, backedIterator, counter, resetPending, firstTrain,
new int[]{bottom,split});
bottom += split;
retVal.add(partIterator);
}
return retVal;
}
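The ratios-based constructor above converts fractional ratios into per-partition batch counts via `splits[i] = (int)(totalExamples * ratios[i])`. A minimal standalone sketch of just that arithmetic (plain Java; names are illustrative, not the DL4J API):

```java
import java.util.Arrays;

class SplitBounds {
    // Mirrors the constructor's loop: each partition gets the truncated
    // share of the total batch count. Because of truncation, the counts
    // may sum to slightly less than totalBatches.
    static int[] computeSplits(long totalBatches, double[] ratios) {
        for (double r : ratios) {
            if (!(r > 0.0 && r < 1.0))
                throw new IllegalArgumentException("Ratio value should be in range 0.0 < X < 1.0");
        }
        int[] splits = new int[ratios.length];
        for (int i = 0; i < ratios.length; i++)
            splits[i] = (int) (totalBatches * ratios[i]);
        return splits;
    }
}
```

For example, 1000 batches split with ratios {0.8, 0.1, 0.1} yields the {800, 100, 100} partitioning the tests above rely on.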
/**
* This method returns train iterator instance
*
* @return
*/
@Deprecated
public DataSetIterator getTrainIterator() {
return new DataSetIterator() {
@Override
@@ -184,6 +268,7 @@ public class DataSetIteratorSplitter {
*
* @return
*/
@Deprecated
public DataSetIterator getTestIterator() {
return new DataSetIterator() {
@Override


@@ -21,9 +21,12 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
@@ -43,6 +46,9 @@ public class MultiDataSetIteratorSplitter {
protected final double ratio;
protected final long numTrain;
protected final long numTest;
protected final double[] ratios;
protected final long numArbitrarySets;
protected final int[] splits;
protected AtomicLong counter = new AtomicLong(0);
@@ -71,15 +77,87 @@ public class MultiDataSetIteratorSplitter {
this.ratio = ratio;
this.numTrain = (long) (totalExamples * ratio);
this.numTest = totalExamples - numTrain;
this.ratios = null;
this.numArbitrarySets = 0;
this.splits = null;
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, long totalBatches, double[] ratios) {
for (double ratio : ratios) {
if (!(ratio > 0.0 && ratio < 1.0))
throw new ND4JIllegalStateException("Ratio value should be in range 0.0 < X < 1.0");
}
if (totalBatches <= 0)
throw new ND4JIllegalStateException("totalBatches must be a positive value");
if (!baseIterator.resetSupported())
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
this.backedIterator = baseIterator;
this.totalExamples = totalBatches;
this.ratio = 0.0;
this.numTrain = (long) (totalExamples * ratio);
this.numTest = totalExamples - numTrain;
this.ratios = ratios;
this.numArbitrarySets = ratios.length;
this.splits = new int[this.ratios.length];
for (int i = 0; i < this.splits.length; ++i) {
this.splits[i] = (int)(totalExamples * ratios[i]);
}
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public MultiDataSetIteratorSplitter(@NonNull MultiDataSetIterator baseIterator, int[] splits) {
int totalBatches = 0;
for (val v:splits)
totalBatches += v;
if (totalBatches <= 0)
throw new ND4JIllegalStateException("totalBatches must be a positive value");
if (!baseIterator.resetSupported())
throw new ND4JIllegalStateException("Underlying iterator doesn't support reset, so it can't be used for runtime-split");
this.backedIterator = baseIterator;
this.totalExamples = totalBatches;
this.ratio = 0.0;
this.numTrain = (long) (totalExamples * ratio);
this.numTest = totalExamples - numTrain;
this.ratios = null;
this.numArbitrarySets = splits.length;
this.splits = splits;
log.warn("IteratorSplitter is used: please ensure you don't use randomization/shuffle in underlying iterator!");
}
public List<MultiDataSetIterator> getIterators() {
List<MultiDataSetIterator> retVal = new ArrayList<>();
int partN = 0;
int bottom = 0;
for (final int split : splits) {
ScrollableMultiDataSetIterator partIterator =
new ScrollableMultiDataSetIterator(partN++, backedIterator, counter, firstTrain,
new int[]{bottom,split});
bottom += split;
retVal.add(partIterator);
}
return retVal;
}
/**
* This method returns train iterator instance
*
* @return
*/
@Deprecated
public MultiDataSetIterator getTrainIterator() {
return new MultiDataSetIterator() {
@Override
@@ -162,6 +240,7 @@ public class MultiDataSetIteratorSplitter {
*
* @return
*/
@Deprecated
public MultiDataSetIterator getTestIterator() {
return new MultiDataSetIterator() {
@Override


@@ -0,0 +1,158 @@
package org.deeplearning4j.datasets.iterator;
import lombok.val;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
public class ScrollableDataSetIterator implements DataSetIterator {
private int thisPart = 0;
private int top = 0;
private int bottom = 0;
protected DataSetIterator backedIterator;
protected AtomicLong counter = new AtomicLong(0);
protected AtomicBoolean resetPending = new AtomicBoolean(false);
protected DataSet firstTrain = null;
protected MultiDataSet firstMultiTrain = null;
private double ratio;
private long totalExamples;
private long itemsPerPart;
private long current;
public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
AtomicBoolean resetPending, DataSet firstTrain, double ratio,
int totalExamples) {
this.thisPart = num;
this.backedIterator = backedIterator;
this.counter = counter;
this.resetPending = resetPending;
this.firstTrain = firstTrain;
this.ratio = ratio;
this.totalExamples = totalExamples;
this.itemsPerPart = (long)(totalExamples * ratio);
this.current = 0;
}
public ScrollableDataSetIterator(int num, DataSetIterator backedIterator, AtomicLong counter,
AtomicBoolean resetPending, DataSet firstTrain,
int[] itemsPerPart) {
this.thisPart = num;
this.bottom = itemsPerPart[0];
this.top = bottom + itemsPerPart[1];
this.itemsPerPart = top;
this.backedIterator = backedIterator;
this.counter = counter;
//this.resetPending = resetPending;
this.firstTrain = firstTrain;
//this.totalExamples = totalExamples;
this.current = 0;
}
@Override
public DataSet next(int i) {
throw new UnsupportedOperationException();
}
@Override
public List<String> getLabels() {
return backedIterator.getLabels();
}
@Override
public int inputColumns() {
return backedIterator.inputColumns();
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
@Override
public int totalOutcomes() {
return backedIterator.totalOutcomes();
}
@Override
public boolean resetSupported() {
return backedIterator.resetSupported();
}
@Override
public boolean asyncSupported() {
return backedIterator.asyncSupported();
}
@Override
public void reset() {
resetPending.set(true);
}
@Override
public int batch() {
return backedIterator.batch();
}
@Override
public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
backedIterator.setPreProcessor(dataSetPreProcessor);
}
@Override
public DataSetPreProcessor getPreProcessor() {
return backedIterator.getPreProcessor();
}
@Override
public boolean hasNext() {
if (resetPending.get()) {
if (resetSupported()) {
backedIterator.reset();
counter.set(0);
current = 0;
resetPending.set(false);
} else
throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
}
if (current >= top)
return false;
if (!backedIterator.hasNext())
return false;
return counter.get() < itemsPerPart;
}
@Override
public DataSet next() {
counter.incrementAndGet();
if ((current == 0) && (bottom != 0)) {
backedIterator.reset();
long cnt = current;
for (; cnt < bottom; ++cnt) {
if (backedIterator.hasNext())
backedIterator.next();
}
current = cnt+1;
}
else current++;
val p = backedIterator.next();
return p;
}
}
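The windowing in ScrollableDataSetIterator above boils down to exposing a [bottom, top) slice of the backing sequence, fast-forwarding past the elements below the window before the first element is served. A simplified standalone sketch over a plain java.util.Iterator (illustrative only; the real class also shares a counter and reset flag across partitions, and skips lazily in next() rather than in hasNext()):

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

class WindowIterator<T> implements Iterator<T> {
    private final Iterator<T> backed;
    private final int bottom; // first backing-sequence index exposed
    private final int top;    // exclusive upper bound, bottom + length
    private int current = 0;  // position in the backing sequence

    WindowIterator(Iterator<T> backed, int bottom, int length) {
        this.backed = backed;
        this.bottom = bottom;
        this.top = bottom + length;
    }

    @Override
    public boolean hasNext() {
        // Fast-forward past the prefix below the window.
        while (current < bottom && backed.hasNext()) {
            backed.next();
            current++;
        }
        return current < top && backed.hasNext();
    }

    @Override
    public T next() {
        if (!hasNext()) throw new NoSuchElementException();
        current++;
        return backed.next();
    }
}
```

With bottom = 3 and length = 4 over the sequence 0..9, the window yields exactly 3, 4, 5, 6, which is the contract the splitter's partitions depend on.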


@@ -0,0 +1,121 @@
package org.deeplearning4j.datasets.iterator;
import lombok.val;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import javax.naming.OperationNotSupportedException;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
public class ScrollableMultiDataSetIterator implements MultiDataSetIterator {
private int thisPart = 0;
private int top = 0;
private int bottom = 0;
protected MultiDataSetIterator backedIterator;
protected AtomicLong counter = new AtomicLong(0);
protected AtomicBoolean resetPending = new AtomicBoolean(false);
protected DataSet firstTrain = null;
protected MultiDataSet firstMultiTrain = null;
private double ratio;
private long totalExamples;
private long itemsPerPart;
private long current;
public ScrollableMultiDataSetIterator(int num, MultiDataSetIterator backedIterator, AtomicLong counter,
MultiDataSet firstTrain, int[] itemsPerPart) {
this.thisPart = num;
this.bottom = itemsPerPart[0];
this.top = bottom + itemsPerPart[1];
this.itemsPerPart = top;
this.counter = counter;
//this.resetPending = resetPending;
this.firstTrain = null;
this.firstMultiTrain = firstTrain;
//this.totalExamples = totalExamples;
this.current = 0;
this.backedIterator = backedIterator;
}
@Override
public boolean resetSupported() {
return backedIterator.resetSupported();
}
@Override
public boolean asyncSupported() {
return backedIterator.asyncSupported();
}
@Override
public void reset() {
resetPending.set(true);
}
@Override
public void setPreProcessor(MultiDataSetPreProcessor dataSetPreProcessor) {
backedIterator.setPreProcessor(dataSetPreProcessor);
}
@Override
public MultiDataSetPreProcessor getPreProcessor() {
throw new UnsupportedOperationException();
}
@Override
public boolean hasNext() {
if (resetPending.get()) {
if (resetSupported()) {
backedIterator.reset();
counter.set(0);
current = 0;
resetPending.set(false);
} else
throw new UnsupportedOperationException("Reset isn't supported by underlying iterator");
}
if (current >= top)
return false;
if (!backedIterator.hasNext())
return false;
return counter.get() < itemsPerPart;
}
@Override
public MultiDataSet next() {
counter.incrementAndGet();
if ((current == 0) && (bottom != 0)) {
backedIterator.reset();
long cnt = current;
for (; cnt < bottom; ++cnt) {
if (backedIterator.hasNext())
backedIterator.next();
}
current = cnt+1;
}
else current++;
val p = backedIterator.next();
return p;
}
@Override
public MultiDataSet next(int i) {
throw new UnsupportedOperationException();
}
}


@@ -47,6 +47,8 @@ import static org.bytedeco.hdf5.global.hdf5.*;
@Slf4j
public class Hdf5Archive implements Closeable {
public static final int MAX_BUFFER_SIZE_BYTES = (int)Math.pow(2, 28); //256 MB
/**
* HDF5 library is not thread safe - possible to crash if multiple reads etc are performed concurrently
* in multiple threads. This object is used for locking read etc activity using synchronized blocks
@@ -338,7 +340,7 @@ public class Hdf5Archive implements Closeable {
private String readAttributeAsJson(Attribute attribute) throws UnsupportedKerasConfigurationException {
synchronized (Hdf5Archive.LOCK_OBJECT) {
VarLenType vl = attribute.getVarLenType();
int bufferSizeMult = 1;
int currBufferLength = 2048;
String s;
/* TODO: find a less hacky way to do this.
* Reading variable length strings (from attributes) is a giant
@@ -349,8 +351,8 @@ public class Hdf5Archive implements Closeable {
* buffer and repeat.
*/
while (true) {
byte[] attrBuffer = new byte[bufferSizeMult * 2000];
BytePointer attrPointer = new BytePointer(attrBuffer);
byte[] attrBuffer = new byte[currBufferLength];
BytePointer attrPointer = new BytePointer(currBufferLength);
attribute.read(vl, attrPointer);
attrPointer.get(attrBuffer);
s = new String(attrBuffer);
@@ -362,9 +364,11 @@ public class Hdf5Archive implements Closeable {
} catch (IOException e) {
//OK - we don't know how long the buffer needs to be, so we'll try again with larger buffer
}
bufferSizeMult *= 2;
if (bufferSizeMult > 1024) {
throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute");
if(currBufferLength == MAX_BUFFER_SIZE_BYTES){
throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute: size exceeds " + currBufferLength + " bytes");
} else {
currBufferLength = (int)Math.min(MAX_BUFFER_SIZE_BYTES, currBufferLength * 4L);
}
}
vl.deallocate();
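The change above replaces the old doubling multiplier (capped at 1024×2000 bytes) with a start-at-2 KB, quadruple-per-retry policy bounded by MAX_BUFFER_SIZE_BYTES. A standalone sketch of just that sizing loop (the tryRead predicate stands in for the real HDF5 attribute read, which fails when the buffer is too small):

```java
import java.util.function.IntPredicate;

class GrowingBufferRead {
    static final int MAX_BUFFER_SIZE_BYTES = (int) Math.pow(2, 28); // 256 MB

    // Returns the buffer length at which the read finally succeeded,
    // or throws once the ceiling has been tried and still failed.
    static int readWithGrowth(IntPredicate tryRead) {
        int currBufferLength = 2048;
        while (true) {
            if (tryRead.test(currBufferLength))
                return currBufferLength;
            if (currBufferLength == MAX_BUFFER_SIZE_BYTES)
                throw new IllegalStateException("Could not read abnormally long HDF5 attribute");
            // Quadruple, but never exceed the cap (the long arithmetic
            // avoids int overflow on the last step before the cap).
            currBufferLength = (int) Math.min(MAX_BUFFER_SIZE_BYTES, currBufferLength * 4L);
        }
    }
}
```

An attribute needing ~100 KB is satisfied after three retries (2048 → 8192 → 32768 → 131072), so short attributes still pay only one small allocation while pathological ones terminate deterministically at the cap.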


@@ -21,6 +21,7 @@ import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.ClusterUtils;
@@ -62,12 +63,13 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
private ClusterSet clusterSet;
private List<Point> initialPoints;
private transient ExecutorService exec;
private boolean useKmeansPlusPlus;
protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy) {
protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
this.clusteringStrategy = clusteringStrategy;
this.exec = MultiThreadUtils.newExecutorService();
this.useKmeansPlusPlus = useKmeansPlusPlus;
}
/**
@@ -75,8 +77,8 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
* @param clusteringStrategy
* @return
*/
public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy) {
return new BaseClusteringAlgorithm(clusteringStrategy);
public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus);
}
/**
@@ -86,7 +88,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
*/
public ClusterSet applyTo(List<Point> points) {
resetState(points);
initClusters();
initClusters(useKmeansPlusPlus);
iterations();
return clusterSet;
}
@@ -130,7 +132,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
* Initialize the
* cluster centers at random
*/
protected void initClusters() {
protected void initClusters(boolean kMeansPlusPlus) {
log.info("Generating initial clusters");
List<Point> points = new ArrayList<>(initialPoints);
@@ -152,7 +154,10 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
//Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
double r = random.nextFloat() * dxs.maxNumber().doubleValue();
double summed = Nd4j.sum(dxs).getDouble(0);
double r = kMeansPlusPlus ? random.nextDouble() * summed:
random.nextFloat() * dxs.maxNumber().doubleValue();
for (int i = 0; i < dxs.length(); i++) {
double distance = dxs.getDouble(i);
Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
@@ -170,6 +175,7 @@ public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializabl
new IterationInfo(currentIteration, initialClusterSetInfo));
}
protected void applyClusteringStrategy() {
if (!isStrategyApplicableNow())
return;


@@ -79,8 +79,8 @@ public class ClusterUtils {
int nClusters = clusterSet.getClusterCount();
for (int i = 0; i < nClusters; i++) {
final Cluster cluster = clusterSet.getClusters().get(i);
tasks.add(new Runnable() {
public void run() {
//tasks.add(new Runnable() {
// public void run() {
try {
final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
refreshClusterCenter(cluster, clusterInfo);
@@ -88,10 +88,10 @@ public class ClusterUtils {
} catch (Throwable t) {
log.warn("Error refreshing cluster centers", t);
}
}
});
// }
//});
}
MultiThreadUtils.parallelTasks(tasks, executorService);
//MultiThreadUtils.parallelTasks(tasks, executorService);
}
public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
@@ -146,28 +146,29 @@ public class ClusterUtils {
List<Runnable> tasks = new ArrayList<>();
for (int i = 0; i < pointsCount; i++) {
final int i2 = i;
tasks.add(new Runnable() {
public void run() {
//tasks.add(new Runnable() {
// public void run() {
try {
Point point = points.get(i2);
double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
: Math.pow(newCluster.getDistanceToCenter(point), 2);
dxs.putScalar(i2, clusterSet.isInverse() ? dist : dist);
dxs.putScalar(i2, /*clusterSet.isInverse() ? dist :*/ dist);
} catch (Throwable t) {
log.warn("Error computing squared distance from nearest cluster", t);
}
}
});
// }
//});
}
MultiThreadUtils.parallelTasks(tasks, executorService);
//MultiThreadUtils.parallelTasks(tasks, executorService);
for (int i = 0; i < pointsCount; i++) {
double previousMinDistance = previousDxs.getDouble(i);
if (clusterSet.isInverse()) {
if (dxs.getDouble(i) < previousMinDistance)
if (dxs.getDouble(i) < previousMinDistance) {
dxs.putScalar(i, previousMinDistance);
}
} else if (dxs.getDouble(i) > previousMinDistance)
dxs.putScalar(i, previousMinDistance);
}
@@ -175,6 +176,23 @@ public class ClusterUtils {
return dxs;
}
public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet,
final List<Point> points, INDArray previousDxs) {
final int pointsCount = points.size();
final INDArray dxs = Nd4j.create(pointsCount);
final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
double sum = 0.0;
for (int i = 0; i < pointsCount; i++) {
Point point = points.get(i);
double dist = Math.pow(newCluster.getDistanceToCenter(point), 2);
sum += dist;
dxs.putScalar(i, sum);
}
return dxs;
}
/**
*
* @param clusterSet
@@ -194,27 +212,27 @@ public class ClusterUtils {
List<Runnable> tasks = new ArrayList<>();
for (int i = 0; i < clusterCount; i++) {
final Cluster cluster = clusterSet.getClusters().get(i);
tasks.add(new Runnable() {
public void run() {
//tasks.add(new Runnable() {
// public void run() {
try {
info.getClustersInfos().put(cluster.getId(),
computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
} catch (Throwable t) {
log.warn("Error computing cluster set info", t);
}
}
});
//}
//});
}
MultiThreadUtils.parallelTasks(tasks, executorService);
//MultiThreadUtils.parallelTasks(tasks, executorService);
tasks = new ArrayList<>();
//tasks = new ArrayList<>();
for (int i = 0; i < clusterCount; i++) {
final int clusterIdx = i;
final Cluster fromCluster = clusterSet.getClusters().get(i);
tasks.add(new Runnable() {
public void run() {
//tasks.add(new Runnable() {
//public void run() {
try {
for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
Cluster toCluster = clusterSet.getClusters().get(k);
@@ -230,12 +248,12 @@ public class ClusterUtils {
} catch (Throwable t) {
log.warn("Error computing distances", t);
}
}
});
// }
//});
}
MultiThreadUtils.parallelTasks(tasks, executorService);
//MultiThreadUtils.parallelTasks(tasks, executorService);
return info;
}

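The commented-out blocks above all follow one pattern: build a list of Runnables (one per cluster or point) and hand them to MultiThreadUtils.parallelTasks, which blocks until all complete. A hedged plain-JDK sketch of that fork/join pattern, with illustrative names:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

// Sketch of the run-all-and-wait pattern the now-commented-out
// MultiThreadUtils.parallelTasks(...) calls implemented.
public class ParallelTasksSketch {

    // Submit every task to the executor and block until all have completed.
    public static void runAll(List<Runnable> tasks, ExecutorService exec) throws InterruptedException {
        List<Callable<Object>> callables = new ArrayList<>();
        for (Runnable r : tasks) {
            callables.add(Executors.callable(r));
        }
        exec.invokeAll(callables); // blocks until every task finishes
    }

    // Small demo: n tasks each bump a shared counter; returns the final count.
    public static int demo(int nTasks) {
        ExecutorService exec = Executors.newFixedThreadPool(2);
        AtomicInteger counter = new AtomicInteger();
        List<Runnable> tasks = new ArrayList<>();
        for (int i = 0; i < nTasks; i++) {
            tasks.add(counter::incrementAndGet);
        }
        try {
            runAll(tasks, exec);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        exec.shutdown();
        return counter.get();
    }
}
```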

@@ -37,8 +37,8 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
*
* @param clusteringStrategy
*/
protected KMeansClustering(ClusteringStrategy clusteringStrategy) {
super(clusteringStrategy);
protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) {
super(clusteringStrategy, useKMeansPlusPlus);
}
/**
@@ -50,11 +50,11 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
* @return
*/
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction,
boolean inverse) {
boolean inverse, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy =
FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse);
clusteringStrategy.endWhenIterationCountEquals(maxIterationCount);
return new KMeansClustering(clusteringStrategy);
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
}
/**
@@ -66,10 +66,10 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
* @return
*/
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
boolean inverse, boolean allowEmptyClusters) {
boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse)
.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
return new KMeansClustering(clusteringStrategy);
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
}
@@ -81,8 +81,8 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
* @param distanceFunction the distance function to use for grouping
* @return
*/
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction) {
return setup(clusterCount, maxIterationCount, distanceFunction, false);
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) {
return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus);
}
/**
@@ -94,17 +94,17 @@ public class KMeansClustering extends BaseClusteringAlgorithm {
* @return
*/
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
boolean allowEmptyClusters) {
boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
return new KMeansClustering(clusteringStrategy);
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
}
public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
boolean allowEmptyClusters) {
boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
return new KMeansClustering(clusteringStrategy);
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
}
}


@@ -16,6 +16,7 @@
package org.deeplearning4j.clustering.kmeans;
import lombok.val;
import org.apache.commons.lang3.time.StopWatch;
import org.deeplearning4j.clustering.BaseDL4JTest;
import org.deeplearning4j.clustering.algorithm.Distance;
@@ -28,22 +29,25 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.junit.Assert.*;
/**
* Created by agibsonccc on 7/2/17.
*/
public class KMeansTest extends BaseDL4JTest {
private boolean[] useKMeansPlusPlus = {true, false};
@Test
public void testKMeans() {
Nd4j.getRandom().setSeed(7);
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN);
List<Point> points = Point.toPoints(Nd4j.randn(5, 5));
ClusterSet clusterSet = kMeansClustering.applyTo(points);
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
System.out.println(pointClassification);
for (boolean mode : useKMeansPlusPlus) {
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN, mode);
List<Point> points = Point.toPoints(Nd4j.randn(5, 5));
ClusterSet clusterSet = kMeansClustering.applyTo(points);
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
System.out.println(pointClassification);
}
}
@Test
@@ -51,20 +55,22 @@ public class KMeansTest extends BaseDL4JTest {
Nd4j.getRandom().setSeed(7);
int numClusters = 5;
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, true);
List<Point> points = Point.toPoints(Nd4j.rand(5, 300));
ClusterSet clusterSet = kMeansClustering.applyTo(points);
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
for (boolean mode : useKMeansPlusPlus) {
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode);
List<Point> points = Point.toPoints(Nd4j.rand(5, 300));
ClusterSet clusterSet = kMeansClustering.applyTo(points);
PointClassification pointClassification = clusterSet.classifyPoint(points.get(0));
KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN);
ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points);
PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0));
System.out.println("Cosine " + pointClassification);
System.out.println("Euclidean " + pointClassificationEuclidean);
KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode);
ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points);
PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0));
System.out.println("Cosine " + pointClassification);
System.out.println("Euclidean " + pointClassificationEuclidean);
assertEquals(pointClassification.getCluster().getPoints().get(0),
pointClassificationEuclidean.getCluster().getPoints().get(0));
assertEquals(pointClassification.getCluster().getPoints().get(0),
pointClassificationEuclidean.getCluster().getPoints().get(0));
}
}
@Ignore
@@ -73,22 +79,24 @@
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
Nd4j.getRandom().setSeed(7);
int numClusters = 20;
StopWatch watch = new StopWatch();
watch.start();
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, true);
List<Point> points = Point.toPoints(Nd4j.linspace(0, 5000*300, 5000*300).reshape(5000,300 ));
for (boolean mode : useKMeansPlusPlus) {
StopWatch watch = new StopWatch();
watch.start();
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode);
List<Point> points = Point.toPoints(Nd4j.linspace(0, 5000 * 300, 5000 * 300).reshape(5000, 300));
ClusterSet clusterSet = kMeansClustering.applyTo(points);
watch.stop();
System.out.println("Elapsed for clustering : " + watch);
ClusterSet clusterSet = kMeansClustering.applyTo(points);
watch.stop();
System.out.println("Elapsed for clustering : " + watch);
watch.reset();
watch.start();
for (Point p : points) {
PointClassification pointClassification = clusterSet.classifyPoint(p);
watch.reset();
watch.start();
for (Point p : points) {
PointClassification pointClassification = clusterSet.classifyPoint(p);
}
watch.stop();
System.out.println("Elapsed for search: " + watch);
}
watch.stop();
System.out.println("Elapsed for search: " + watch);
}
@Test
@@ -97,41 +105,43 @@
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
Nd4j.getRandom().setSeed(7);
int numClusters = 20;
StopWatch watch = new StopWatch();
watch.start();
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false);
for (boolean mode : useKMeansPlusPlus) {
StopWatch watch = new StopWatch();
watch.start();
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false, mode);
List<Point> points = Point.toPoints(Nd4j.linspace(0, 10000*300, 10000*300).reshape(10000,300 ));
List<Point> points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
ClusterSet clusterSet = kMeansClustering.applyTo(points);
watch.stop();
System.out.println("Elapsed for clustering : " + watch);
ClusterSet clusterSet = kMeansClustering.applyTo(points);
watch.stop();
System.out.println("Elapsed for clustering : " + watch);
watch.reset();
watch.start();
for (Point p : points) {
PointClassification pointClassification = clusterSet.classifyPoint(p);
watch.reset();
watch.start();
for (Point p : points) {
PointClassification pointClassification = clusterSet.classifyPoint(p);
}
watch.stop();
System.out.println("Elapsed for search: " + watch);
watch.reset();
watch.start();
kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false, mode);
points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300));
clusterSet = kMeansClustering.applyTo(points);
watch.stop();
System.out.println("Elapsed for clustering : " + watch);
watch.reset();
watch.start();
for (Point p : points) {
PointClassification pointClassification = clusterSet.classifyPoint(p);
}
watch.stop();
System.out.println("Elapsed for search: " + watch);
}
watch.stop();
System.out.println("Elapsed for search: " + watch);
watch.reset();
watch.start();
kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false);
points = Point.toPoints(Nd4j.linspace(0, 10000*300, 10000*300).reshape(10000,300 ));
clusterSet = kMeansClustering.applyTo(points);
watch.stop();
System.out.println("Elapsed for clustering : " + watch);
watch.reset();
watch.start();
for (Point p : points) {
PointClassification pointClassification = clusterSet.classifyPoint(p);
}
watch.stop();
System.out.println("Elapsed for search: " + watch);
}
@Test
@@ -141,45 +151,47 @@
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
Nd4j.getRandom().setSeed(7);
int numClusters = 3;
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, true);
double[] data = new double[]{
15, 16,
16, 18.5,
17, 20.2,
16.4, 17.12,
17.23, 18.12,
43, 43,
44.43, 45.212,
45.8, 54.23,
46.313, 43.123,
50.21, 46.3,
99, 99.22,
100.32, 98.123,
100.32, 97.423,
102, 93.23,
102.23, 94.23
};
List<Point> points = Point.toPoints(Nd4j.createFromArray(data).reshape(15, 2));
for (boolean mode : useKMeansPlusPlus) {
KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode);
double[] data = new double[]{
15, 16,
16, 18.5,
17, 20.2,
16.4, 17.12,
17.23, 18.12,
43, 43,
44.43, 45.212,
45.8, 54.23,
46.313, 43.123,
50.21, 46.3,
99, 99.22,
100.32, 98.123,
100.32, 97.423,
102, 93.23,
102.23, 94.23
};
List<Point> points = Point.toPoints(Nd4j.createFromArray(data).reshape(15, 2));
ClusterSet clusterSet = kMeansClustering.applyTo(points);
ClusterSet clusterSet = kMeansClustering.applyTo(points);
INDArray row0 = Nd4j.createFromArray(new double[]{16.6575, 18.4850});
INDArray row1 = Nd4j.createFromArray(new double[]{32.6050, 31.1500});
INDArray row2 = Nd4j.createFromArray(new double[]{75.9348, 74.1990});
INDArray row0 = Nd4j.createFromArray(new double[]{16.6575, 18.4850});
INDArray row1 = Nd4j.createFromArray(new double[]{32.6050, 31.1500});
INDArray row2 = Nd4j.createFromArray(new double[]{75.9348, 74.1990});
/*List<Cluster> clusters = clusterSet.getClusters();
assertEquals(row0, clusters.get(0).getCenter().getArray());
assertEquals(row1, clusters.get(1).getCenter().getArray());
assertEquals(row2, clusters.get(2).getCenter().getArray());*/
PointClassification pointClassification = null;
for (Point p : points) {
pointClassification = clusterSet.classifyPoint(p);
System.out.println("Point: " + p.getArray() + " " + " assigned to cluster: " + pointClassification.getCluster().getCenter().getArray());
List<Cluster> clusters = clusterSet.getClusters();
for (int i = 0; i < clusters.size(); ++i)
System.out.println("Choice: " + clusters.get(i).getCenter().getArray());
PointClassification pointClassification = null;
for (Point p : points) {
pointClassification = clusterSet.classifyPoint(p);
System.out.println("Point: " + p.getArray() + " " + " assigned to cluster: " + pointClassification.getCluster().getCenter().getArray());
List<Cluster> clusters = clusterSet.getClusters();
for (int i = 0; i < clusters.size(); ++i)
System.out.println("Choice: " + clusters.get(i).getCenter().getArray());
}
}
/*assertEquals(Nd4j.createFromArray(new double[]{75.9348, 74.1990}),
pointClassification.getCluster().getCenter().getArray());*/
@@ -233,4 +245,39 @@
System.out.println();
}
}
@Test
public void testInitClusters() {
Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
Nd4j.getRandom().setSeed(7);
{
KMeansClustering kMeansClustering = KMeansClustering.setup(5, 1, Distance.EUCLIDEAN, true);
double[][] dataArray = {{1000000.0, 2.8E7, 5.5E7, 8.2E7}, {2.8E7, 5.5E7, 8.2E7, 1.09E8}, {5.5E7, 8.2E7, 1.09E8, 1.36E8},
{8.2E7, 1.09E8, 1.36E8, 1.63E8}, {1.09E8, 1.36E8, 1.63E8, 1.9E8}, {1.36E8, 1.63E8, 1.9E8, 2.17E8},
{1.63E8, 1.9E8, 2.17E8, 2.44E8}, {1.9E8, 2.17E8, 2.44E8, 2.71E8}, {2.17E8, 2.44E8, 2.71E8, 2.98E8},
{2.44E8, 2.71E8, 2.98E8, 3.25E8}, {2.71E8, 2.98E8, 3.25E8, 3.52E8}, {2.98E8, 3.25E8, 3.52E8, 3.79E8},
{3.25E8, 3.52E8, 3.79E8, 4.06E8}, {3.52E8, 3.79E8, 4.06E8, 4.33E8}, {3.79E8, 4.06E8, 4.33E8, 4.6E8},
{4.06E8, 4.33E8, 4.6E8, 4.87E8}, {4.33E8, 4.6E8, 4.87E8, 5.14E8}, {4.6E8, 4.87E8, 5.14E8, 5.41E8},
{4.87E8, 5.14E8, 5.41E8, 5.68E8}, {5.14E8, 5.41E8, 5.68E8, 5.95E8}, {5.41E8, 5.68E8, 5.95E8, 6.22E8},
{5.68E8, 5.95E8, 6.22E8, 6.49E8}, {5.95E8, 6.22E8, 6.49E8, 6.76E8}, {6.22E8, 6.49E8, 6.76E8, 7.03E8},
{6.49E8, 6.76E8, 7.03E8, 7.3E8}, {6.76E8, 7.03E8, 7.3E8, 7.57E8}, {7.03E8, 7.3E8, 7.57E8, 7.84E8}};
INDArray data = Nd4j.createFromArray(dataArray);
List<Point> points = Point.toPoints(data);
ClusterSet clusterSet = kMeansClustering.applyTo(points);
double[] centroid1 = {2.44e8, 2.71e8, 2.98e8, 3.25e8};
double[] centroid2 = {5.14e8, 5.41e8, 5.68e8, 5.95e8};
double[] centroid3 = {1.63e8, 1.9e8, 2.17e8, 2.44e8};
double[] centroid4 = {6.76e8, 7.03e8, 7.3e8, 7.57e8};
double[] centroid5 = {4.06e8, 4.33e8, 4.6e8, 4.87e8};
assertArrayEquals(centroid1, clusterSet.getClusters().get(0).getCenter().getArray().toDoubleVector(), 1e-4);
assertArrayEquals(centroid2, clusterSet.getClusters().get(1).getCenter().getArray().toDoubleVector(), 1e-4);
assertArrayEquals(centroid3, clusterSet.getClusters().get(2).getCenter().getArray().toDoubleVector(), 1e-4);
assertArrayEquals(centroid4, clusterSet.getClusters().get(3).getCenter().getArray().toDoubleVector(), 1e-4);
assertArrayEquals(centroid5, clusterSet.getClusters().get(4).getCenter().getArray().toDoubleVector(), 1e-4);
}
}
}


@@ -23,6 +23,8 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
import org.junit.Rule;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
@@ -857,4 +859,34 @@
}
}
@Test
public void testBackwardsCompatibleWord2Vec() {
File model_v3 = Resources.asFile("deeplearning4j-nlp/model_beta3.zip");
File model_v4 = Resources.asFile("deeplearning4j-nlp/model_beta4.zip");
Word2Vec word2Vec1 = WordVectorSerializer.readWord2VecModel(model_v3, true);
Word2Vec word2Vec2 = WordVectorSerializer.readWord2VecModel(model_v4, true);
try {
assertEquals(word2Vec1.toJson(), word2Vec2.toJson());
} catch (Exception e) {
fail(e.getMessage());
}
}
@Test
public void testBackwardsCompatibleSequenceVectors() {
File model_v3 = Resources.asFile("deeplearning4j-nlp/seqv_beta3.csv");
File model_v4 = Resources.asFile("deeplearning4j-nlp/seqv_beta4.csv");
try {
SequenceVectors vectors1 = WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), model_v3);
SequenceVectors vectors2 = WordVectorSerializer.readSequenceVectors(new VocabWordFactory(), model_v4);
assertEquals(vectors1.vocab().numWords(), vectors2.vocab().numWords());
for (int i = 0; i < vectors1.vocab().numWords(); ++i) {
assertEquals(vectors1.vocab().words().toArray()[i], vectors2.vocab().words().toArray()[i]);
}
} catch (Exception e) {
fail(e.getMessage());
}
}
}


@@ -249,7 +249,7 @@ public class BertIterator implements MultiDataSetIterator {
} else {
throw new RuntimeException();
}
l[0] = Nd4j.create(Nd4j.defaultFloatingPointType(), mbPadded, numClasses);
l[0] = Nd4j.create(DataType.FLOAT, mbPadded, numClasses);
for( int i=0; i<mb; i++ ){
l[0].putScalar(i, classLabels[i], 1.0);
}
@@ -277,9 +277,9 @@
if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX){
labelArr = Nd4j.create(DataType.INT, mbPadded, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL){
labelArr = Nd4j.create(Nd4j.defaultFloatingPointType(), mbPadded, vocabSize, outLength);
labelArr = Nd4j.create(DataType.FLOAT, mbPadded, vocabSize, outLength);
} else if(unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC){
labelArr = Nd4j.create(Nd4j.defaultFloatingPointType(), outLength, mbPadded, vocabSize);
labelArr = Nd4j.create(DataType.FLOAT, outLength, mbPadded, vocabSize);
} else {
throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
}

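The rank-2 class-label array built above (l[0].putScalar(i, classLabels[i], 1.0)) is plain one-hot encoding over [minibatch, numClasses]. A minimal sketch with plain arrays standing in for the INDArray (names illustrative, not part of this patch):

```java
// Sketch of the one-hot label construction above: row i gets a 1.0 at column
// classLabels[i], zeros elsewhere.
public class OneHotSketch {
    public static double[][] oneHot(int[] classLabels, int numClasses) {
        double[][] labels = new double[classLabels.length][numClasses];
        for (int i = 0; i < classLabels.length; i++) {
            labels[i][classLabels[i]] = 1.0;
        }
        return labels;
    }
}
```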

@@ -201,7 +201,7 @@ public class CnnSentenceDataSetIterator implements DataSetIterator {
List<String> tokens = new ArrayList<>();
while (t.hasMoreTokens()) {
String token = t.nextToken();
if (!wordVectors.hasWord(token)) {
if (!wordVectors.outOfVocabularySupported() && !wordVectors.hasWord(token)) {
switch (unknownWordHandling) {
case RemoveWord:
continue;


@@ -1312,10 +1312,12 @@ public class SequenceVectors<T extends SequenceElement> extends WordVectorsImpl<
int rest = batchSequences.size() % batchSize;
int chunks = ((batchSequences.size() >= batchSize) ? batchSequences.size() / batchSize : 0) + ((rest > 0)? 1 : 0);
for (int j = 0; j < chunks; ++j) {
if (elementsLearningAlgorithm instanceof SkipGram)
((SkipGram)elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
else if (elementsLearningAlgorithm instanceof CBOW)
((CBOW)elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
if (trainElementsVectors) {
if (elementsLearningAlgorithm instanceof SkipGram)
((SkipGram) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
else if (elementsLearningAlgorithm instanceof CBOW)
((CBOW) elementsLearningAlgorithm).iterateSample(batchSequences.get(j));
}
if (trainSequenceVectors) {
if (sequenceLearningAlgorithm instanceof DBOW)

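The chunks expression in the loop above is a ceiling division: the number of batches needed to cover n sequences at batch size b. Isolated for clarity, written the same way as the diff (whole batches plus one for any remainder); names are illustrative:

```java
// Sketch of the chunk-count computation above: ceil(n / b) for n sequences
// split into batches of size b.
public class ChunkCountSketch {
    public static int chunks(int n, int b) {
        int rest = n % b;
        return (n >= b ? n / b : 0) + (rest > 0 ? 1 : 0);
    }
}
```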

@@ -32,7 +32,7 @@ import java.io.Serializable;
*
* @author Adam Gibson
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class", defaultImpl = VocabWord.class)
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
setterVisibility = JsonAutoDetect.Visibility.NONE)
public class VocabWord extends SequenceElement implements Serializable {


@@ -224,6 +224,7 @@ public class TestBertIterator extends BaseDL4JTest {
@Test(timeout = 20000L)
public void testMinibatchPadding() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);


@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.api;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
@@ -73,4 +74,6 @@
*/
double getGradientNormalizationThreshold();
void setDataType(DataType dataType);
}


@@ -93,4 +93,9 @@ public abstract class GraphVertex implements Cloneable, Serializable {
*/
public abstract MemoryReport getMemoryReport(InputType... inputTypes);
public void setDataType(DataType dataType) {
//No-op for most layers
}
}


@@ -146,4 +146,9 @@ public class LayerVertex extends GraphVertex {
//TODO preprocessor memory
return layerConf.getLayer().getMemoryReport(it);
}
@Override
public void setDataType(DataType dataType){
layerConf.getLayer().setDataType(dataType);
}
}


@@ -223,6 +223,11 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable {
"Not supported: all layers with parameters should override this method");
}
@Override
public void setDataType(DataType dataType) {
//No-op for most layers
}
/**
* This is a report of the estimated memory consumption for the given layer
*


@@ -96,7 +96,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
if (!map.containsKey(inputNum)) {
//Lazily define extra input variable as required
SDVariable var = sameDiff.var("var_" + inputNum, 1); //TODO is this shape safe?
SDVariable var = sameDiff.var("var_" + inputNum, dataType, -1); //TODO is this shape safe?
map.put(inputNum, var);
}


@@ -62,6 +62,7 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
protected IUpdater biasUpdater;
protected GradientNormalization gradientNormalization;
protected double gradientNormalizationThreshold = Double.NaN;
protected DataType dataType;
/**
* Define the vertex
@@ -234,4 +235,9 @@
public double getGradientNormalizationThreshold() {
return gradientNormalizationThreshold;
}
@Override
public void setDataType(DataType dataType) {
this.dataType = dataType;
}
}


@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.misc;
import lombok.AllArgsConstructor;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.regularization.Regularization;
@@ -63,4 +64,9 @@ public class DummyConfig implements TrainingConfig {
public double getGradientNormalizationThreshold() {
return 1.0;
}
@Override
public void setDataType(DataType dataType) {
}
}


@@ -512,6 +512,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
for(; i<topologicalOrder.length; i++ ){
String name = indices.getIdxToName().get(i);
org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name);
n.setDataType(netDtype);
numParamsForVertex[i] = n.numParams(true);
numParams += numParamsForVertex[i];
}


@@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
@@ -35,6 +36,7 @@ import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import java.util.Arrays;
import java.util.List;
/**
@@ -60,10 +62,16 @@ public class RnnLossLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.Rn
assertInputSet(true);
if (input.rank() != 3)
throw new UnsupportedOperationException(
"Input is not rank 3. Got input with rank " + input.rank() + " " + layerId());
"Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " +
input.rank() + " with shape " + Arrays.toString(input.shape()) + " for layer " + layerId());
if (labels == null)
throw new IllegalStateException("Labels are not set (null)");
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnLossLayer input and labels: " +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);
INDArray input2d = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
INDArray labels2d = TimeSeriesUtils.reshape3dTo2d(labels, workspaceMgr, ArrayType.BP_WORKING_MEM);
INDArray maskReshaped;

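Both loss-layer hunks rely on TimeSeriesUtils.reshape3dTo2d to flatten rank-3 [minibatch, size, sequenceLength] activations into a 2d array so a standard 2d loss can be applied per time step. A hedged sketch of one common layout; the exact row ordering used by reshape3dTo2d may differ, and all names here are illustrative:

```java
// Flatten [minibatch, size, seqLen] to [minibatch * seqLen, size]: each
// (example, time step) pair becomes one 2d row of feature size `size`.
public class Reshape3dTo2dSketch {
    public static double[][] reshape(double[][][] in) {
        int mb = in.length, size = in[0].length, ts = in[0][0].length;
        double[][] out = new double[mb * ts][size];
        for (int t = 0; t < ts; t++)
            for (int m = 0; m < mb; m++)
                for (int s = 0; s < size; s++)
                    out[t * mb + m][s] = in[m][s][t];
        return out;
    }
}
```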

@@ -23,6 +23,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -57,8 +58,13 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
}
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(2) == labels.size(2), "Sequence lengths do not match for RnnOutputLayer input and labels: " +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);
INDArray inputTemp = input;
this.input = TimeSeriesUtils.reshape3dTo2d(input, workspaceMgr, ArrayType.BP_WORKING_MEM);
Pair<Gradient, INDArray> gradAndEpsilonNext = super.backpropGradient(epsilon, workspaceMgr); //Also applies dropout
this.input = inputTemp;
INDArray epsilon2d = gradAndEpsilonNext.getSecond();


@@ -39,9 +39,7 @@ import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.*;
/**
* Implementation of a SameDiff graph vertex.
@@ -96,12 +94,11 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
@Override
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
if(sameDiff == null){
doInit();
}
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
// sameDiff.clearExecutionCache();
if(sameDiff == null){
doInit();
}
config.validateInput(inputs);
for(int i=0; i<inputs.length; i++ ){
String name = config.getVertexParams().getInputs().get(i);
@@ -121,6 +118,10 @@
}
Map<String,INDArray> out = sameDiff.exec(null, outputKey);
INDArray result = out.get(outputKey);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
}
}
@@ -131,27 +132,42 @@
INDArray[] dLdIns;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
// sameDiff.clearExecutionCache();
if(sameDiff == null){
doInit();
}
if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
List<String> inputs = config.getVertexParams().getInputs();
String[] inArr = inputs.toArray(new String[inputs.size()]);
sameDiff.createGradFunction(inArr);
}
config.validateInput(inputs);
//Set inputs
for(int i=0; i<inputs.length; i++ ){
String name = config.getVertexParams().getInputs().get(i);
final String maskName = name + "_mask";
sameDiff.associateArrayWithVariable(inputs[i].dup(), sameDiff.getVariable(name));
if(maskArrays != null && maskArrays[i] != null) {
sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName);
}else{
sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName);
Map<String,INDArray> phMap = new HashMap<>();
List<String> inputs = config.getVertexParams().getInputs();
int i=0;
for(String s : inputs){
phMap.put(s, this.inputs[i++]);
}
if(maskArrays != null){
for( int j=0; j<maskArrays.length; j++ ){
String name = inputs.get(j);
final String maskName = name + "_mask";
if(maskArrays[j] != null) {
sameDiff.associateArrayWithVariable(maskArrays[j].dup(), maskName);
}
}
}
fn.updateVariable(outputVar.getVarName(), epsilon.dup());
String epsName = fn.getGradPlaceholderName();
phMap.put(epsName, epsilon);
for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}
sameDiff.execBackwards(null);
sameDiff.execBackwards(phMap);
for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s);
@@ -159,10 +175,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
g.gradientForVariable().put(s, dl4jGrad);
}
dLdIns = new INDArray[inputs.length];
for(int i=0; i<inputs.length; i++ ){
String name = config.getVertexParams().getInputs().get(i);
dLdIns[i] = sameDiff.grad(name).getArr();
dLdIns = new INDArray[inputs.size()];
for(int j=0; j<inputs.size(); j++ ){
String name = inputs.get(j);
dLdIns[j] = sameDiff.grad(name).getArr();
}
}
@@ -171,6 +187,9 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
}
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return new Pair<>(g, dLdIns);
}

@@ -35,6 +35,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.*;
@@ -78,25 +79,32 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(false);
if(sameDiff == null){
doInit();
}
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
if(sameDiff == null){
doInit();
}
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input);
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
if(maskArray != null){
sameDiff.associateArrayWithVariable(maskArray, sameDiff.getVariable(MASK_KEY));
}else{
sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY));
phMap.put(MASK_KEY, maskArray);
}
for(String s : paramTable.keySet() ) {
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}
Map<String,INDArray> out = sameDiff.exec(null, outputKey);
Map<String,INDArray> out = sameDiff.exec(phMap, outputKey);
INDArray result = out.get(outputKey);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
}
}
@@ -110,24 +118,36 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
INDArray dLdIn;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
// sameDiff.clearExecutionCache();
if(sameDiff == null){
doInit();
}
if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
sameDiff.createGradFunction(INPUT_KEY);
}
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input);
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
if(maskArray != null){
sameDiff.associateArrayWithVariable(maskArray, sameDiff.getVariable(MASK_KEY));
}else{
sameDiff.associateArrayWithVariable(SameDiffGraphVertex.createMask(dataType, input.shape()), sameDiff.getVariable(MASK_KEY));
}
fn.updateVariable(outputVar.getVarName(), epsilon.dup());
for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}
sameDiff.execBackwards(Collections.<String, INDArray>emptyMap());
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
phMap.put(fn.getGradPlaceholderName(), epsilon);
if(maskArray != null){
phMap.put(MASK_KEY, maskArray);
}
List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName());
for(String s : paramTable.keySet()){
requiredGrads.add(sameDiff.grad(s).getVarName());
}
sameDiff.execBackwards(phMap, requiredGrads);
for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s);
@@ -138,6 +158,11 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
}
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
}
@@ -225,8 +250,9 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
sameDiff = SameDiff.create();
Map<String, INDArray> p = paramTable();
val inputShape = input.shape().clone();
SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape);
long[] inputShape = input.shape().clone();
inputShape[0] = -1;
SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape);
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
Map<String, SDVariable> params = new LinkedHashMap<>();
for (String s : paramShapes.keySet()) {
@@ -235,7 +261,8 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
params.put(s, v);
}
SDVariable mask = sameDiff.constant(MASK_KEY, SameDiffGraphVertex.createMask(dataType, inputShape));
long[] maskShape = ArrayUtil.nTimes((long)inputShape.length, -1);
SDVariable mask = sameDiff.placeHolder(MASK_KEY, dataType, maskShape);
SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, params, mask);
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");

@@ -87,35 +87,43 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
private INDArray activateHelper(boolean activations, LayerWorkspaceMgr workspaceMgr){
assertInputSet(false);
//Check where the output occors. If it's a simple loss layer (no params) this could
//Check where the output occurs. If it's a simple loss layer (no params) this could
// just be the input!
if(activations && INPUT_KEY.equals(layerConf().activationsVertexName())){
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
}
if(sameDiff == null){
doInit();
}
//TODO optimize
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
sameDiff.associateArrayWithVariable(input.dup(), sameDiff.getVariable(INPUT_KEY));
if(layerConf().labelsRequired() && labels != null) {
sameDiff.associateArrayWithVariable(labels.dup(), sameDiff.getVariable(LABELS_KEY));
if(sameDiff == null){
doInit();
}
for(String s : paramTable.keySet() ) {
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}
INDArray score = sameDiff.execAndEndResult();
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
if(!activations && layerConf().labelsRequired() && labels != null) {
phMap.put(LABELS_KEY, labels);
}
String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
INDArray out = sameDiff.execSingle(phMap, s);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
if(activations) {
INDArray result = sameDiff.getArrForVarName(layerConf().activationsVertexName());
Preconditions.checkNotNull(result, "Activations (result) array for variable \"%s\" was " +
Preconditions.checkNotNull(out, "Activations (result) array for variable \"%s\" was " +
"null - error during execution or this variable (as defined by method activationsVertexName()) " +
"does not exist", layerConf().activationsVertexName());
return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
return workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
} else {
return score;
return out;
}
}
}
@@ -127,23 +135,26 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
Preconditions.checkState(!layerConf().labelsRequired() || labels != null, "Cannot execute backprop: Labels are not set. " +
"If labels are not required for this SameDiff output layer, override SameDiffOutputLayer.labelsRequired()" +
" to return false instead");
if(sameDiff == null){
//Usually doInit will be called in forward pass; not necessarily the case in output layers
// (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
doInit();
}
Gradient g = new DefaultGradient();
INDArray dLdIn;
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
INDArray castInput = input.castTo(Nd4j.defaultFloatingPointType());
if(sameDiff == null){
//Usually doInit will be called in forward pass; not necessarily the case in output layers
// (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
doInit();
}
if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
sameDiff.createGradFunction(INPUT_KEY);
}
INDArray castInput = input.castTo(dataType);
if(castInput.isAttached())
castInput = castInput.dup();
sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY));
if(layerConf().labelsRequired()) {
INDArray castLabels = labels.castTo(Nd4j.defaultFloatingPointType());
INDArray castLabels = labels.castTo(dataType);
if(castLabels.isAttached())
castLabels = castLabels.dup();
sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY));
@@ -154,7 +165,17 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
}
sameDiff.execBackwards(Collections.<String, INDArray>emptyMap());
List<String> gradVarNames = new ArrayList<>();
for(String s : paramTable.keySet()){
gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName());
}
gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName());
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
phMap.put(LABELS_KEY, labels);
sameDiff.execBackwards(phMap, gradVarNames);
for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s);
@@ -165,6 +186,10 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
dLdIn = sameDiff.grad(INPUT_KEY).getArr();
}
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
sameDiff.clearOpInputs();
return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn)); //TODO OPTIMIZE THIS
}
@@ -252,18 +277,20 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff = SameDiff.create();
Map<String, INDArray> p = paramTable();
val inputShape = input.shape().clone();
SDVariable inputVar = sameDiff.var(INPUT_KEY, dataType, inputShape);
long[] inputShape = input.shape().clone();
inputShape[0] = -1;
SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape);
SDVariable labelVar = null;
if(layerConf().labelsRequired()){
long[] labelShape = labels == null ? new long[]{1} : labels.shape().clone();
labelVar = sameDiff.var(LABELS_KEY, dataType, labelShape);
long[] labelShape = labels == null ? new long[]{-1, -1} : labels.shape().clone();
labelShape[0] = -1;
labelVar = sameDiff.placeHolder(LABELS_KEY, dataType, labelShape);
}
Map<String, long[]> paramShapes = layerConf().getLayerParams().getParamShapes();
Map<String, SDVariable> params = new LinkedHashMap<>();
for (String s : paramShapes.keySet()) {
val ps = paramShapes.get(s);
SDVariable v = sameDiff.var(s, ps);
SDVariable v = sameDiff.var(s, dataType, ps);
params.put(s, v);
}
SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, labelVar, params);

@@ -660,6 +660,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
val nParamsPerLayer = new long[nLayers];
for (int i = 0; i < nLayers; i++) {
NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i);
conf.getLayer().setDataType(netDtype);
nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf);
paramLength += nParamsPerLayer[i];
}

@@ -152,7 +152,7 @@ public class HardwareMetric implements Serializable {
return builder.logicalProcessorCount(processor.getLogicalProcessorCount())
.physicalProcessorCount(processor.getPhysicalProcessorCount())
.name(name)
.averagedCpuLoad((long) processor.getSystemCpuLoad() * 100)
.averagedCpuLoad((long)(processor.getSystemCpuLoad() * 100))
.ioWaitTime(iowait).gpuMetrics(gpuMetric)
.hostName(networkParams.getHostName()).diskInfo(diskInfoMap)
.currentMemoryUse(globalMemory.getTotal() - globalMemory.getAvailable())
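The `averagedCpuLoad` change above fixes a Java cast-precedence bug: a unary cast binds more tightly than `*`, so the fractional CPU load was truncated to a long before the multiplication by 100. A minimal self-contained sketch of the difference, using a hypothetical `getSystemCpuLoad()` stub returning a 0..1 fraction in place of the OSHI processor call:

```java
// Demonstrates why `(long) load * 100` and `(long)(load * 100)` differ.
// getSystemCpuLoad() is a hypothetical stand-in for the OSHI call in the
// original code; it returns CPU load as a fraction between 0 and 1.
public class CpuLoadCastDemo {
    static double getSystemCpuLoad() {
        return 0.5; // hypothetical: 50% load (0.5 is exactly representable)
    }

    // Old code: cast applies to the load alone, truncating 0.5 to 0,
    // then multiplies by 100 -> always 0 (or 100 at exactly full load).
    static long buggy() {
        return (long) getSystemCpuLoad() * 100;
    }

    // Fixed code: multiply first, then cast -> 50 as intended.
    static long fixed() {
        return (long) (getSystemCpuLoad() * 100);
    }

    public static void main(String[] args) {
        System.out.println("buggy = " + buggy()); // prints 0
        System.out.println("fixed = " + fixed()); // prints 50
    }
}
```

The same precedence rule applies to any unary cast in Java; parenthesizing the whole arithmetic expression before casting is the usual fix.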

@@ -48,8 +48,6 @@ if(WIN32)
SET(CMAKE_NINJA_FORCE_RESPONSE_FILE 1 CACHE INTERNAL "")
endif()
if ("${LIBND4J_ALL_OPS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DLIBND4J_ALL_OPS=true")
else()
@@ -234,21 +232,21 @@ if(CUDA_BLAS)
endif()
endif()
if (NOT BUILD_TESTS)
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/*.cpp ../include/execution/*.h)
file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp)
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/cuda/*.cu ../include/helpers/*.h)
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/impl/*.cpp ../include/memory/cuda/*.cu ../include/memory/*.h)
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.cu ../include/graph/*.h)
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cuda/*.cu ../include/ops/declarable/helpers/impl/*.cpp)
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
if (NOT BUILD_TESTS)
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
@@ -258,26 +256,12 @@ if(CUDA_BLAS)
else()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../include/exceptions/*.cpp ../include/exceptions/*.h)
file(GLOB_RECURSE EXEC_SOURCES false ../include/execution/impl/*.cpp ../include/execution/*.cu ../include/execution/*.h)
file(GLOB_RECURSE TYPES_SOURCES false ../include/types/*.cpp ../include/types/*.h)
file(GLOB_RECURSE ARRAY_SOURCES false ../include/array/impl/*.cpp ../include/array/cuda/*.cu ../include/array/*.h)
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/impl/*.cpp ../include/memory/cuda/*.cu ../include/memory/*.h)
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.cu ../include/graph/*.h)
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cuda/*.cu)
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/impl/*.cpp ../include/helpers/*.cu ../include/helpers/*.cupp ../include/helpers/*.h)
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES})
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES})
endif()
@@ -308,7 +292,7 @@ elseif(CPU_BLAS)
file(GLOB_RECURSE MEMORY_SOURCES false ../include/memory/*.cpp ../include/memory/*.h)
file(GLOB_RECURSE GRAPH_SOURCES false ../include/graph/*.cpp ../include/graph/*.h)
file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../include/ops/declarable/generic/*.cpp)
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp)
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES false ../include/ops/declarable/helpers/cpu/*.cpp ../include/ops/declarable/helpers/impl/*.cpp)
file(GLOB_RECURSE OPS_SOURCES false ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
file(GLOB_RECURSE INDEXING_SOURCES false ../include/indexing/*.cpp ../include/indexing/*.h)
file(GLOB_RECURSE HELPERS_SOURCES false ../include/helpers/*.cpp ../include/helpers/*.h)

@@ -372,8 +372,8 @@ namespace nd4j {
/**
* if _bufferD==nullptr return _buffer, else return _bufferD
*/
FORCEINLINE void* specialBuffer();
FORCEINLINE void* getSpecialBuffer() const;
void* specialBuffer();
void* getSpecialBuffer() const;
/**
* returns device buffer if compilation is for cuda case, otherwise returns host buffer
@@ -429,16 +429,16 @@ namespace nd4j {
/**
* permutes the dimensions in array according to "dimensions" array, new array points on _buffer of this array
*/
NDArray* permute(const std::initializer_list<int>& dimensions) const;
NDArray* permute(const std::vector<int>& dimensions) const;
NDArray* permute(const int* dimensions, const int rank) const;
NDArray permute(const std::initializer_list<int>& dimensions) const;
NDArray permute(const std::vector<int>& dimensions) const;
NDArray permute(const int* dimensions, const int rank) const;
void permute(const int* dimensions, const int rank, NDArray& target) const;
void permute(const std::vector<int>& dimensions, NDArray& target) const;
NDArray* permute(const std::initializer_list<Nd4jLong>& dimensions) const;
NDArray* permute(const std::vector<Nd4jLong>& dimensions) const;
NDArray* permute(const Nd4jLong* dimensions, const int rank) const;
NDArray permute(const std::initializer_list<Nd4jLong>& dimensions) const;
NDArray permute(const std::vector<Nd4jLong>& dimensions) const;
NDArray permute(const Nd4jLong* dimensions, const int rank) const;
void permute(const Nd4jLong* dimensions, const int rank, NDArray& target) const;
void permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const;
@@ -508,7 +508,7 @@ namespace nd4j {
/**
* returns new copy of this array, optionally in different order
*/
NDArray *dup(const char newOrder = 'a');
NDArray *dup(const char newOrder = 'a') const;
/**
* returns sum of all elements of array
@@ -687,7 +687,7 @@ namespace nd4j {
void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;
#if defined(__CUDABLAS__) && defined(BUILD_TESTS)
#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS)
template <typename Lambda>
FORCEINLINE void applyLambda(Lambda func, NDArray* target = nullptr);
@@ -790,8 +790,7 @@ namespace nd4j {
/**
* apply transpose operation to the copy of this array, that is this array remains unaffected
*/
NDArray* transpose() const;
NDArray transp() const;
NDArray transpose() const;
/**
* perform transpose operation and store result in target, this array remains unaffected
@@ -915,7 +914,7 @@ namespace nd4j {
*
* if permute have been applied before or there are weird strides, then new buffer is allocated for new array
*/
NDArray* reshape(const char order, const std::vector<Nd4jLong>& shape) const;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const;
/**
* calculate strides and set given order
@@ -2093,15 +2092,6 @@ Nd4jLong* NDArray::shapeInfo() {
return _shapeInfo;
}
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////
Nd4jLong* NDArray::specialShapeInfo() {
if (_shapeInfoD == nullptr)
@@ -2110,14 +2100,6 @@ Nd4jLong* NDArray::specialShapeInfo() {
return _shapeInfoD;
}
////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////
Nd4jLong NDArray::getBufferOffset() const {
return _offset;
@@ -2137,7 +2119,7 @@ Nd4jLong* NDArray::getSpecialShapeInfo() const{
}
#if defined(__CUDACC__) && defined(BUILD_TESTS)
#if defined(__CUDACC__) //&& defined(BUILD_TESTS)
// for CUDA we still need this stuff inline
#include "cuda/NDArrayLambda.hpp"
#endif

@@ -39,9 +39,9 @@ NDArray* NDArray::asT() const{
auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
auto l = this->lengthOf();
prepareSpecialUse({result}, {this});
NDArray::prepareSpecialUse({result}, {this});
NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result->getBuffer(), result->getShapeInfo(), result->getSpecialBuffer(), result->getSpecialShapeInfo(), nullptr, nullptr, nullptr);
registerSpecialUse({result}, {this});
NDArray::registerSpecialUse({result}, {this});
return result;
}
@@ -583,117 +583,130 @@ void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCop
void NDArray::assign(const double value) {
// just fire scalar
auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const float value) {
// just fire scalar
auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const float16 value) {
// just fire scalar
auto temp = NDArrayFactory::create(value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const bfloat16& value) {
// just fire scalar
auto temp = NDArrayFactory::create(value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const Nd4jLong value) {
// just fire scalar
auto temp = NDArrayFactory::create(value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const int value) {
// just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const int16_t value) {
// just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const uint8_t value) {
// just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const uint16_t value) {
// just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const uint32_t value) {
// just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const uint64_t value) {
// just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
prepareSpecialUse({this}, {&temp});
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
registerSpecialUse({this}, {&temp});
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const int8_t value) {
// just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::assign(const bool value) {
// just fire scalar
auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
NDArray::prepareSpecialUse({this}, {&temp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {&temp});
}
//////////////////////////////////////////////////////////////////////////
@@ -716,9 +729,9 @@ NDArray NDArray::varianceNumber(nd4j::variance::Ops op, bool biasCorrected) {
NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext());
NDArray::prepareSpecialUse({&res}, {this});
NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected);
NDArray::registerSpecialUse({&res}, {this});
return res;
}
@@ -918,9 +931,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::FloatOps op, void *extraParams) cons
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()));
NDArray result(shape, true, this->getContext());
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
NDArray::registerSpecialUse({&result}, {this});
return result;
}
@@ -932,9 +945,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::SameOps op, void *extraParams) const
NDArray result(dataType(), getContext());
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
NDArray::registerSpecialUse({&result}, {this});
return result;
}
@@ -947,9 +960,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::BoolOps op, void *extraParams) const
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::BOOL);
NDArray result(shape, true, this->getContext());
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
NDArray::registerSpecialUse({&result}, {this});
return result;
}
@@ -962,9 +975,9 @@ NDArray NDArray::reduceNumber(nd4j::reduce::LongOps op, void *extraParams) const
auto shape = ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64);
NDArray result(shape, true, this->getContext());
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
NDArray::registerSpecialUse({&result}, {this});
return result;
}
@@ -976,9 +989,9 @@ void NDArray::reduceNumber(nd4j::reduce::FloatOps op, NDArray& target, void *ext
if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType()))
throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!");
NDArray::prepareSpecialUse({&target}, {this});
NativeOpExecutioner::execReduceFloatScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
NDArray::registerSpecialUse({&target}, {this});
}
//////////////////////////////////////////////////////////////////////////
@@ -989,9 +1002,9 @@ void NDArray::reduceNumber(nd4j::reduce::SameOps op, NDArray& target, void *extr
if(!target.isScalar() || target.dataType() != dataType())
throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!");
NDArray::prepareSpecialUse({&target}, {this});
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
NDArray::registerSpecialUse({&target}, {this});
}
//////////////////////////////////////////////////////////////////////////
@@ -1002,9 +1015,9 @@ void NDArray::reduceNumber(nd4j::reduce::BoolOps op, NDArray& target, void *extr
if(!target.isScalar() || target.dataType() != DataType::BOOL)
throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!");
NDArray::prepareSpecialUse({&target}, {this});
NativeOpExecutioner::execReduceBoolScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
NDArray::registerSpecialUse({&target}, {this});
}
//////////////////////////////////////////////////////////////////////////
@@ -1015,9 +1028,9 @@ void NDArray::reduceNumber(nd4j::reduce::LongOps op, NDArray& target, void *extr
if(!target.isScalar() || target.dataType() != DataType::INT64)
throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!");
NDArray::prepareSpecialUse({&target}, {this});
NativeOpExecutioner::execReduceLongScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams, target.getBuffer(), target.getShapeInfo(), target.specialBuffer(), target.getSpecialShapeInfo());
NDArray::registerSpecialUse({&target}, {this});
}
//////////////////////////////////////////////////////////////////////////
@@ -1027,9 +1040,9 @@ NDArray NDArray::indexReduceNumber(nd4j::indexreduce::Ops op, ExtraArguments *ex
auto res = NDArrayFactory::create<Nd4jLong>(0);
NDArray::prepareSpecialUse({&res}, {this});
NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extraParams == nullptr ? nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo());
NDArray::registerSpecialUse({&res}, {this});
return res;
}
@@ -1240,17 +1253,10 @@ BUILD_SINGLE_TEMPLATE(template void* NDArray::templatedPointerShift, (const Nd4j
//////////////////////////////////////////////////////////////////////////
// this method makes a copy of this array, applies the transpose operation to the copy, and leaves this array unaffected
NDArray NDArray::transpose() const {
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
newArr.transposei();
return newArr;
}
@@ -1360,10 +1366,10 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
//////////////////////////////////////////////////////////////////////////
// create a new array with the given order and shape; the new array points to the same _buffer as this array
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const {
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
newArr.reshapei(order, shape);
return newArr;
}
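The transpose/reshape/permute methods in this file now return `NDArray` views by value instead of heap-allocated `NDArray*`, sharing the backing `DataBuffer` so callers no longer need a matching `delete`. A minimal standalone sketch of that ownership pattern, with hypothetical names (`MiniArray`, plain `std::vector` storage) rather than the real nd4j types:

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for NDArray: a view object returned by value that
// shares reference-counted storage, so reshaping never copies the data.
struct MiniArray {
    std::shared_ptr<std::vector<double>> buffer; // shared backing storage
    std::vector<long long> shape;                // per-view shape metadata

    // Analogue of the by-value reshape above: same buffer, new shape.
    MiniArray reshape(std::vector<long long> newShape) const {
        return MiniArray{buffer, std::move(newShape)};
    }
};
```

Writes through any view are visible through every other view of the same buffer, which mirrors the `_isView = true` semantics in the permute overloads below.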
@@ -1420,43 +1426,43 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const int* dimensions, const int rank) const {
// evaluate shapeInfo for output (permuted) array ret
auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace());
NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), getBufferOffset());
ret._isView = true;
return ret;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const {
int tempDims[MAX_RANK];
shape::convertT<Nd4jLong, int>(const_cast<Nd4jLong *>(dimensions), tempDims, rank);
return permute(tempDims, rank);
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const std::vector<int>& dimensions) const {
auto data = dimensions.data();
auto size = dimensions.size();
return permute(data, size);
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const {
return permute(dimensions.data(), dimensions.size());
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const std::initializer_list<int>& dimensions) const {
std::vector<int> vec(dimensions);
return permute(vec);
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::permute(const std::initializer_list<Nd4jLong>& dimensions) const {
std::vector<Nd4jLong> vec(dimensions);
return permute(vec);
}
@@ -1528,10 +1534,9 @@ bool NDArray::isUnitary() {
throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !");
auto tr = this->transpose();
auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f);
bool result = trMul->isIdentityMatrix();
delete trMul;
return result;
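isUnitary above computes this × thisᵀ via MmulHelper::mmul and checks the product against the identity; mathematically that is just testing M·Mᵀ = I. A self-contained sketch for a real-valued 2×2 row-major matrix (hypothetical helper, not part of nd4j):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Checks M * M^T == I for a 2x2 row-major matrix m = {m00, m01, m10, m11}.
inline bool isOrthogonal2x2(const std::vector<double>& m) {
    const double p00 = m[0] * m[0] + m[1] * m[1]; // row0 . row0
    const double p01 = m[0] * m[2] + m[1] * m[3]; // row0 . row1
    const double p11 = m[2] * m[2] + m[3] * m[3]; // row1 . row1
    auto near = [](double a, double b) { return std::fabs(a - b) < 1e-9; };
    // the product M*M^T is symmetric, so checking p01 once covers p10 as well
    return near(p00, 1.0) && near(p01, 0.0) && near(p11, 1.0);
}
```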
@@ -1777,11 +1782,11 @@ NDArray NDArray::operator*(const T& scalar) const {
auto tmp = NDArrayFactory::create(dataType(), scalar, getContext());
NDArray result(_shapeInfo, DataTypeUtils::pickPairwiseResultType(dataType(), DataTypeUtils::fromT<T>()), false, getContext());
NDArray::prepareSpecialUse({&result}, {this, &tmp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&result}, {this, &tmp});
return result;
}
template NDArray NDArray::operator*(const double& scalar) const;
@@ -1811,6 +1816,7 @@ NDArray NDArray::operator/(const T& scalar) const {
NDArray::prepareSpecialUse({&result}, {this, &tmp});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&result}, {this, &tmp});
return result;
}
template NDArray NDArray::operator/(const double& scalar) const;
@@ -2050,14 +2056,14 @@ void NDArray::operator+=(const NDArray& other) {
throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());
if (!this->isScalar() && other.isScalar()) {
NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {this, &other});
}
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {this, &other});
}
else{
Nd4jLong *bShape = nullptr;
@@ -2084,14 +2090,14 @@ void NDArray::operator-=(const NDArray& other) {
throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());
if (!this->isScalar() && other.isScalar()) {
NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {this, &other});
}
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {this, &other});
}
else{
Nd4jLong *bShape = nullptr;
@@ -2117,14 +2123,14 @@ void NDArray::operator*=(const NDArray& other) {
throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());
if (!this->isScalar() && other.isScalar()) {
NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {this, &other});
}
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {this, &other});
}
else{
Nd4jLong *bShape = nullptr;
@@ -2154,14 +2160,14 @@ void NDArray::operator/=(const NDArray& other) {
}
if (!this->isScalar() && other.isScalar()) {
NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {this, &other});
}
else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray::prepareSpecialUse({this}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {this, &other});
}
else{
Nd4jLong *bShape = nullptr;
@@ -2264,9 +2270,9 @@ NDArray NDArray::operator-(const NDArray& other) const {
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
NDArray::prepareSpecialUse({&result}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&result}, {this, &other});
return result;
}
@@ -2285,9 +2291,9 @@ NDArray NDArray::operator*(const NDArray& other) const {
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, this->getContext());
NDArray::prepareSpecialUse({&result}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&result}, {this, &other});
return result;
}
@@ -2308,9 +2314,9 @@ NDArray NDArray::operator/(const NDArray& other) const {
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
NDArray::prepareSpecialUse({&result}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&result}, {this, &other});
return result;
}
@@ -2326,9 +2332,9 @@ NDArray NDArray::operator-() const {
NDArray result(getShapeInfo(), false, getContext());
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformSame(getContext(), nd4j::transform::Neg, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
NDArray::registerSpecialUse({&result}, {this});
return result;
}
@@ -2631,7 +2637,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::vector<int>& di
if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
NDArray::prepareSpecialUse({result}, {this, other});
NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({result}, {this, other});
return;
}
@@ -2688,7 +2694,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
NDArray::prepareSpecialUse({result}, {this, other});
NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({result}, {this, other});
return;
}
@@ -2896,7 +2902,7 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
Nd4jLong *shapeInfoNew;
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
// we can do this only if no permute has been applied, or there are no irregular strides
if (canReshape) {
@@ -2948,11 +2954,9 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray* othe
if (target->dataType() != this->dataType() && target->dataType() != other->dataType())
throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !");
NDArray::prepareSpecialUse({target}, {this, other});
NativeOpExecutioner::execPairwiseTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
NDArray::registerSpecialUse({target}, {this, other});
if (extraParams != nullptr)
synchronize("NDArray::applyPairwiseTransform");
@@ -2969,9 +2973,9 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *
if (dataType() != other->dataType())
throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !");
NDArray::prepareSpecialUse({target}, {this, other});
NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
NDArray::registerSpecialUse({target}, {this, other});
}
//////////////////////////////////////////////////////////////////////////
@@ -3070,22 +3074,23 @@ void NDArray::assign(const NDArray& other) {
if (other.isScalar()) {
if(this->isScalar()) {
NDArray::preparePrimaryUse({this}, {&other});
BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES);
NDArray::registerPrimaryUse({this}, {&other});
this->syncToDevice();
}
else {
if (dataType() != other.dataType()) {
auto tmp = other.cast(dataType());
NDArray::prepareSpecialUse({this}, {tmp});
NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {});
delete tmp;
}
else {
NDArray::prepareSpecialUse({this}, {&other});
NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({this}, {&other});
}
}
}
@@ -3101,16 +3106,16 @@ void NDArray::assign(const NDArray& other) {
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
else {
NDArray::prepareSpecialUse({this}, {&other});
NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr);
NDArray::registerSpecialUse({this}, {&other});
}
}
}
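The assign path above dispatches on the source: a scalar source is broadcast into the target (execScalar with CopyPws), while an equal-length source is copied element-wise, contiguously when orders and ews allow, or through execTransformAny otherwise. A hedged flat-buffer sketch of that dispatch order, using a hypothetical `assignLike` over plain vectors (no dtype casting or strides):

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical analogue of NDArray::assign for flat buffers.
inline void assignLike(std::vector<double>& dst, const std::vector<double>& src) {
    if (src.size() == 1) {
        // scalar source: broadcast-fill the whole target
        std::fill(dst.begin(), dst.end(), src[0]);
    } else if (src.size() == dst.size()) {
        // matching lengths: element-wise copy (the contiguous fast path)
        std::copy(src.begin(), src.end(), dst.begin());
    } else {
        // mirrors the shape-mismatch error path in assign
        throw std::invalid_argument("assignLike: lengths do not match");
    }
}
```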
////////////////////////////////////////////////////////////////////////
// This method returns new copy of this NDArray, optionally in different order
NDArray* NDArray::dup(const char newOrder) const {
if (isEmpty())
return NDArrayFactory::empty_(dataType(), getContext());
@@ -3170,7 +3175,7 @@ std::string NDArray::e(const Nd4jLong i) const {
if (!isS())
throw std::runtime_error("Can't get std::string out of non-string array");
NDArray::preparePrimaryUse({}, {this});
// get the "virtual" offset; it is not a real offset, since it does not take the string lengths into account
auto offset = getOffset(i);
@@ -3208,8 +3213,8 @@ T NDArray::e(const Nd4jLong i) const {
const auto rp = getOffset(i);
NDArray::preparePrimaryUse({}, {this});
NDArray::registerPrimaryUse({}, {this});
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), rp), LIBND4J_TYPES);
}
@@ -3226,8 +3231,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j) const {
const Nd4jLong coords[2] = {i, j};
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
NDArray::preparePrimaryUse({}, {this});
NDArray::registerPrimaryUse({}, {this});
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
@@ -3246,8 +3251,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
const Nd4jLong coords[3] = {i, j, k};
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
NDArray::preparePrimaryUse({}, {this});
NDArray::registerPrimaryUse({}, {this});
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
@@ -3266,8 +3271,8 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLon
const Nd4jLong coords[4] = {i, j, k, l};
const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
NDArray::preparePrimaryUse({}, {this});
NDArray::registerPrimaryUse({}, {this});
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(getBuffer(), xOffset), LIBND4J_TYPES);
@@ -3300,9 +3305,9 @@ void NDArray::applyTransform(nd4j::transform::FloatOps op, NDArray *target, Extr
if (!target->isR())
throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types");
NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
NDArray::registerSpecialUse({target}, {this});
}
////////////////////////////////////////////////////////////////////////
@@ -3314,9 +3319,9 @@ void NDArray::applyTransform(nd4j::transform::AnyOps op, NDArray *target, ExtraA
if (target == nullptr)
target = this;
NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
registerSpecialUse({target}, {this});
NDArray::registerSpecialUse({target}, {this});
}
////////////////////////////////////////////////////////////////////////
@@ -3331,9 +3336,9 @@ void NDArray::applyTransform(nd4j::transform::SameOps op, NDArray *target, Extra
if (target->dataType() != dataType())
throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array");
prepareSpecialUse({target}, {this});
NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
registerSpecialUse({target}, {this});
NDArray::registerSpecialUse({target}, {this});
}
////////////////////////////////////////////////////////////////////////
@@ -3347,9 +3352,9 @@ void NDArray::applyTransform(nd4j::transform::StrictOps op, NDArray *target, Ext
if (!this->isR() || !target->isR() || (this->dataType() != target->dataType()))
throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !");
registerSpecialUse({target}, {this});
NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
prepareSpecialUse({target}, {this});
NDArray::registerSpecialUse({target}, {this});
}
////////////////////////////////////////////////////////////////////////
@@ -3363,9 +3368,9 @@ void NDArray::applyTransform(nd4j::transform::BoolOps op, NDArray *target, Extra
if (!target->isB())
throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types");
prepareSpecialUse({target}, {this});
NDArray::prepareSpecialUse({target}, {this});
NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr, nullptr, nullptr);
registerSpecialUse({target}, {this});
NDArray::registerSpecialUse({target}, {this});
}
////////////////////////////////////////////////////////////////////////
@@ -3375,9 +3380,9 @@ NDArray NDArray::transform(nd4j::transform::FloatOps op, void *extraParams) cons
NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext());
registerSpecialUse({&result}, {this});
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformFloat(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
prepareSpecialUse({&result}, {this});
NDArray::registerSpecialUse({&result}, {this});
return result;
}
@@ -3389,9 +3394,9 @@ NDArray NDArray::transform(nd4j::transform::SameOps op, void *extraParams) const
NDArray result(getShapeInfo(), false, getContext());
prepareSpecialUse({&result}, {this});
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
registerSpecialUse({&result}, {this});
NDArray::registerSpecialUse({&result}, {this});
return result;
}
@@ -3403,9 +3408,9 @@ NDArray NDArray::transform(nd4j::transform::StrictOps op, void *extraParams) con
NDArray result(getShapeInfo(), false, getContext());
prepareSpecialUse({&result}, {this});
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformStrict(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
registerSpecialUse({&result}, {this});
NDArray::registerSpecialUse({&result}, {this});
return result;
}
@@ -3417,9 +3422,9 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const
NDArray result(ordering(), getShapeAsVector(), nd4j::DataType::BOOL, getContext());
prepareSpecialUse({&result}, {this});
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
registerSpecialUse({&result}, {this});
NDArray::registerSpecialUse({&result}, {this});
return result;
}
@@ -3435,9 +3440,9 @@ void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArra
if(target->dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar->getShapeInfo()) && !(target->dataType() == dataType() || target->dataType() == scalar->dataType()))
throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!");
prepareSpecialUse({target}, {this, scalar});
NDArray::prepareSpecialUse({target}, {this, scalar});
NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
registerSpecialUse({target}, {this, scalar});
NDArray::registerSpecialUse({target}, {this, scalar});
}
////////////////////////////////////////////////////////////////////////
@@ -3471,10 +3476,9 @@ void NDArray::applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, ND
throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!");
}
prepareSpecialUse({target}, {this, scalar});
NDArray::prepareSpecialUse({target}, {this, scalar});
NativeOpExecutioner::execScalarBool(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
registerSpecialUse({target}, {this, scalar});
NDArray::registerSpecialUse({target}, {this, scalar});
}
////////////////////////////////////////////////////////////////////////
@@ -3557,7 +3561,7 @@ NDArray* NDArray::applyReduce3(nd4j::reduce3::Ops op, const NDArray* other, cons
NDArray::prepareSpecialUse({result}, {this, other});
NativeOpExecutioner::execReduce3Scalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo());
registerSpecialUse({result}, {this, other});
NDArray::registerSpecialUse({result}, {this, other});
return result;
}
@@ -3635,9 +3639,9 @@ NDArray* NDArray::applyAllReduce3(nd4j::reduce3::Ops op, const NDArray *other, c
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
prepareSpecialUse({result}, {this, other});
NDArray::prepareSpecialUse({result}, {this, other});
NativeOpExecutioner::execReduce3All(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
registerSpecialUse({result}, {this, other});
NDArray::registerSpecialUse({result}, {this, other});
return result;
}
@@ -3780,9 +3784,9 @@ void NDArray::p(const Nd4jLong i, const T value) {
auto rp = getOffset(i);
const void *pV = reinterpret_cast<const void*>(const_cast<T *>(&value));
preparePrimaryUse({this}, {}, true);
NDArray::preparePrimaryUse({this}, {}, true);
BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->getBuffer(), rp, pV), LIBND4J_TYPES);
registerPrimaryUse({this}, {});
NDArray::registerPrimaryUse({this}, {});
}
template void NDArray::p(const Nd4jLong i, const double value);
@@ -3811,9 +3815,9 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) {
Nd4jLong coords[2] = {i, j};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
preparePrimaryUse({this}, {}, true);
NDArray::preparePrimaryUse({this}, {}, true);
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
registerPrimaryUse({this}, {});
NDArray::registerPrimaryUse({this}, {});
}
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value);
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value);
@@ -3837,13 +3841,13 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T va
if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2])
throw std::invalid_argument("NDArray::p(i,j,k, value): one of input indexes is out of array length or rank!=3 !");
preparePrimaryUse({this}, {}, true);
NDArray::preparePrimaryUse({this}, {}, true);
void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
Nd4jLong coords[3] = {i, j, k};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
registerPrimaryUse({this}, {});
NDArray::registerPrimaryUse({this}, {});
}
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value);
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value);
@@ -3870,9 +3874,9 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4j
Nd4jLong coords[4] = {i, j, k, l};
auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
preparePrimaryUse({this}, {}, true);
NDArray::preparePrimaryUse({this}, {}, true);
BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
registerPrimaryUse({this}, {});
NDArray::registerPrimaryUse({this}, {});
}
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value);
template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value);
@@ -3896,10 +3900,10 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
if (i >= _length)
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
preparePrimaryUse({this}, {&scalar}, true);
NDArray::preparePrimaryUse({this}, {&scalar}, true);
auto rp = getOffset(i);
BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (getBuffer(), rp, scalar.dataType(), scalar.getBuffer()), LIBND4J_TYPES);
registerPrimaryUse({this}, {&scalar});
NDArray::registerPrimaryUse({this}, {&scalar});
}
//////////////////////////////////////////////////////////////////////////
@@ -4195,7 +4199,7 @@ ResultSet* NDArray::allTensorsAlongDimension(const std::vector<int> &dimensions)
auto pack = ConstantTadHelper::getInstance()->tadForDimensions(_shapeInfo, const_cast<int*>(dimensions.data()), dimensions.size());
auto numTads = lengthOf() / shape::length(pack.primaryShapeInfo());
auto numTads = pack.numberOfTads();
for (int idx = 0; idx < numTads; idx++ ) {
auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + getBufferOffset());


@@ -1578,6 +1578,20 @@ public:
void *dx, Nd4jLong *dxShapeInfo,
bool descending);
void sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending);
void sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending);
void sortTad(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
@@ -1587,6 +1601,24 @@ public:
Nd4jLong *tadOffsets,
bool descending);
void sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending);
void sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending);
// special sort impl for sorting out COO indices and values
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank);


@@ -208,6 +208,23 @@ void* NDArray::specialBufferWithOffset(Nd4jLong offset) const {
return nullptr;
}
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
//////////////////////////////////////////////////////////////////////////
// change an array by repeating it the number of times given by reps.
NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {


@@ -27,6 +27,52 @@
namespace nd4j {
////////////////////////////////////////////////////////////////////////
template <>
NDArray NDArrayFactory::create<bool>(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context) {
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
ShapeDescriptor descriptor(nd4j::DataType::BOOL, order, shape);
if (descriptor.arrLength() != data.size()) {
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
}
bool* hostBuffer = nullptr;
ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool);
std::copy(data.begin(), data.end(), hostBuffer);
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(hostBuffer, data.size() * sizeof(bool), nd4j::DataType::BOOL, true, context->getWorkspace());
NDArray result(buffer, descriptor, context);
return result;
}
////////////////////////////////////////////////////////////////////////
template <typename T>
NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context) {
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
ShapeDescriptor descriptor(DataTypeUtils::fromT<T>(), order, shape);
if (descriptor.arrLength() != data.size()) {
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
}
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(data.data(), DataTypeUtils::fromT<T>(), descriptor.arrLength() * sizeof(T), context->getWorkspace());
NDArray result(buffer, descriptor, context);
return result;
}
NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
std::string s(str);
@@ -227,10 +273,13 @@ template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<float16> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bfloat16> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<unsigned int> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<unsigned long> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<Nd4jLong> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int8_t> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<int16_t> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint16_t> &data, nd4j::LaunchContext * context);
template NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context);
@@ -391,6 +440,7 @@ template NDArray NDArrayFactory::create(const std::vector<bfloat16> &values, nd4
template NDArray NDArrayFactory::create(const std::vector<Nd4jLong> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<int> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<int16_t> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<uint16_t> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<int8_t> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<uint8_t> &values, nd4j::LaunchContext * context);
template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::LaunchContext * context);
@@ -452,53 +502,6 @@ template NDArray NDArrayFactory::create(const std::vector<bool> &values, nd4j::L
return new NDArray(order, shape, dataType, context);
}
////////////////////////////////////////////////////////////////////////
template <typename T>
NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<T> &data, nd4j::LaunchContext * context) {
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
ShapeDescriptor descriptor(DataTypeUtils::fromT<T>(), order, shape);
if (descriptor.arrLength() != data.size()) {
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
}
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(data.data(), DataTypeUtils::fromT<T>(), descriptor.arrLength() * sizeof(T), context->getWorkspace());
NDArray result(buffer, descriptor, context);
return result;
}
////////////////////////////////////////////////////////////////////////
template <>
NDArray NDArrayFactory::create<bool>(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool> &data, nd4j::LaunchContext * context) {
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("NDArrayFactory::create: rank of NDArray can't exceed 32 !");
ShapeDescriptor descriptor(nd4j::DataType::BOOL, order, shape);
if (descriptor.arrLength() != data.size()) {
nd4j_printf("NDArrayFactory::create: data size [%i] doesn't match shape length [%lld]\n", data.size(), descriptor.arrLength());
throw std::runtime_error("NDArrayFactory::create: data size doesn't match shape");
}
bool* hostBuffer = nullptr;
ALLOCATE(hostBuffer, context->getWorkspace(), data.size(), bool);
std::copy(data.begin(), data.end(), hostBuffer);
std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(hostBuffer, data.size() * sizeof(bool), nd4j::DataType::BOOL, true, context->getWorkspace());
NDArray result(buffer, descriptor, context);
return result;
}
////////////////////////////////////////////////////////////////////////
template <typename T>
NDArray NDArrayFactory::create(T* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context) {


@@ -2736,6 +2736,60 @@ Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
}
void NativeOps::sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending) {
auto xType = ArrayOptions::dataType(xShapeInfo);
auto yType = ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByKey(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
void NativeOps::sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending) {
auto xType = ArrayOptions::dataType(xShapeInfo);
auto yType = ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortByValue(x, xShapeInfo, y, yShapeInfo, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
void NativeOps::sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
auto xType = ArrayOptions::dataType(xShapeInfo);
auto yType = ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByKey(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
void NativeOps::sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dx, Nd4jLong *dxShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
auto xType = ArrayOptions::dataType(xShapeInfo);
auto yType = ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, nd4j::DoubleMethods, ::sortTadByValue(x, xShapeInfo, y, yShapeInfo, dimension, dimensionLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);


@@ -192,8 +192,8 @@ void NDArray::setIdentity() {
if (isS())
throw std::runtime_error("NDArray::setIdentity: you can't use this method on String array!");
if (rankOf() != 2)
throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given.");
// if (rankOf() != 2)
// throw std::runtime_error("NDArray::setIdentity: method should work only for 2D tensors. But " + toStringValue(rankOf()) + " was given.");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
@@ -234,12 +234,15 @@ void NDArray::synchronize(const char* msg) const {
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
for (const auto& a : readList)
a->syncToDevice();
if(a != nullptr)
a->syncToDevice();
for (const auto& a : writeList) {
a->getDataBuffer()->allocateSpecial();
if (synchronizeWritables)
a->syncToDevice();
if (a != nullptr) {
a->getDataBuffer()->allocateSpecial();
if (synchronizeWritables)
a->syncToDevice();
}
}
}
@@ -247,22 +250,27 @@ void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& wri
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
for (const auto& p : readList)
p->tickReadDevice();
if(p != nullptr)
p->tickReadDevice();
for (const auto& p : writeList)
p->tickWriteDevice();
if (p != nullptr)
p->tickWriteDevice();
}
////////////////////////////////////////////////////////////////////////
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
for (const auto& a : readList)
if(a != nullptr)
a->syncToHost();
for (const auto& a : writeList) {
a->getDataBuffer()->allocatePrimary();
if (synchronizeWritables)
a->syncToHost();
if (a != nullptr) {
a->getDataBuffer()->allocatePrimary();
if (synchronizeWritables)
a->syncToHost();
}
}
}
@@ -270,10 +278,12 @@ void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& wri
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
for (const auto& p : readList)
p->tickReadHost();
if(p != nullptr)
p->tickReadHost();
for (const auto& p : writeList)
p->tickWriteHost();
if (p != nullptr)
p->tickWriteHost();
}
//////////////////////////////////////////////////////////////////////////
@@ -427,9 +437,26 @@ void NDArray::repeat(int dimension, NDArray& target) const {
NDArray::registerSpecialUse({&target}, {this});
}
////////////////////////////////////////////////////////////////////////
void* NDArray::specialBuffer() {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
////////////////////////////////////////////////////////////////////////
void* NDArray::getSpecialBuffer() const {
if (_buffer->special() == nullptr)
return getBuffer();
// FIXME: this should be fixed once CUDA backend added
return static_cast<int8_t*>(_buffer->special()) + (_offset * sizeOfT());
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const {\
void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const {
if(_length == 0)
{ printf("NDArray::printActualBuffer: array length is zero !\n"); return; }
@@ -477,7 +504,7 @@ template void NDArray::printCurrentBuffer<double>(const bool host, const char* m
#if defined(__CUDACC__) && !defined(BUILD_TESTS)
#include <cpu/NDArrayLambda.hpp>
//#include <cpu/NDArrayLambda.hpp>
#endif


@@ -2321,6 +2321,163 @@ void NativeOps::sort(Nd4jPointer *extraPointers,
}
void NativeOps::sortByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending) {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto xLength = shape::length(xShapeInfo);
auto xEWS = shape::elementWiseStride(xShapeInfo);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
// check if xLength is a power of 2, and use bitonic sort, if that's the case
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
int numBlocks = xLength / numThreads;
if (xLength % numThreads > 0 || numBlocks == 0)
numBlocks++;
dim3 launchDims(numBlocks, numThreads, 32768);
for (int k = 2; k <= xLength; k = 2*k) {
for (int j = k >> 1; j > 0; j = j >> 1) {
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
} else {
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
int numBlocks = xLength / numThreads;
if (xLength % numThreads > 0 || numBlocks == 0)
numBlocks++;
numBlocks = nd4j::math::nd4j_min<int>(512, numBlocks);
dim3 launchDims(numBlocks, numThreads, 32768);
int max = 2, dg = 0;
while (max < xLength) {
max <<= 1;
dg++;
}
max <<= 1;
for (int window = 2; window < max; window<<=1) {
int n = window;
int rev = 0;
do{
int half = n >> 1;
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
n>>=1;
rev = 1;
} while(n > 1);
}
}
}
void NativeOps::sortByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
bool descending) {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto xLength = shape::length(xShapeInfo);
auto xEWS = shape::elementWiseStride(xShapeInfo);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
// check if xLength is a power of 2, and use bitonic sort, if that's the case
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
int numBlocks = xLength / numThreads;
if (xLength % numThreads > 0 || numBlocks == 0)
numBlocks++;
dim3 launchDims(numBlocks, numThreads, 32768);
for (int k = 2; k <= xLength; k = 2*k) {
for (int j = k >> 1; j > 0; j = j >> 1) {
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicSortStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, j, k, xLength, descending), LIBND4J_TYPES, LIBND4J_TYPES);
}
}
} else {
int numThreads = nd4j::math::nd4j_min<int>(512, xLength);
int numBlocks = xLength / numThreads;
if (xLength % numThreads > 0 || numBlocks == 0)
numBlocks++;
numBlocks = nd4j::math::nd4j_min<int>(512, numBlocks);
dim3 launchDims(numBlocks, numThreads, 32768);
int max = 2, dg = 0;
while (max < xLength) {
max <<= 1;
dg++;
}
max <<= 1;
for (int window = 2; window < max; window<<=1) {
int n = window;
int rev = 0;
do{
int half = n >> 1;
BUILD_DOUBLE_SELECTOR(xType, yType, bitonicArbitraryStepGenericValue, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, n, xLength, rev, descending), LIBND4J_TYPES, LIBND4J_TYPES);
n>>=1;
rev = 1;
} while(n > 1);
}
}
}
void NativeOps::sortTadByKey(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dX, dXShapeInfo, dy, dyShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "sortTadKey(...) failed");
}
void NativeOps::sortTadByValue(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dy, Nd4jLong *dyShapeInfo,
int *dimension,
int dimensionLength,
bool descending) {
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
dim3 launchDims((int) tadPack.numberOfTads(), 256, 2048);
auto xType = nd4j::ArrayOptions::dataType(yShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(xShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, yType, oesTadGenericKey, (launchDims, stream, dy, dyShapeInfo, dX, dXShapeInfo, nullptr, dimensionLength, tadPack.platformShapeInfo(), tadPack.platformOffsets(), descending), LIBND4J_TYPES, LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "sortTadValue(...) failed");
}
void NativeOps::sortTad(Nd4jPointer *extraPointers,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
@ -2331,15 +2488,13 @@ void NativeOps::sortTad(Nd4jPointer *extraPointers,
bool descending) {
// to be implemented
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto context = extraPointers[0] == 0 ? LaunchContext::defaultContext(): reinterpret_cast<LaunchContext*>(extraPointers[0]);
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
dim3 launchDims((int) tadPack.numberOfTads(), 512, 33768);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
BUILD_SINGLE_SELECTOR(xType, oesTadGeneric, (launchDims, stream, dX, dXShapeInfo, nullptr, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
nd4j::DebugHelper::checkErrorCode(stream, "sortTad(...) failed");
}
void NativeOps::sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) {


@ -38,11 +38,11 @@ namespace nd4j {
ConstantDataBuffer() = default;
~ConstantDataBuffer() = default;
Nd4jLong sizeOf() const;
Nd4jLong length() const;
Nd4jPointer primary() const;
Nd4jPointer special() const;
ConstantDataBuffer& operator=(const ConstantDataBuffer& other) = default;
ConstantDataBuffer& operator=(ConstantDataBuffer&& other) noexcept = default;


@ -261,6 +261,8 @@ DataBuffer& DataBuffer::operator=(const DataBuffer& other) {
allocateBuffers();
copyBufferFrom(other);
return *this;
}
////////////////////////////////////////////////////////////////////////
@ -285,6 +287,8 @@ DataBuffer& DataBuffer::operator=(DataBuffer&& other) noexcept {
other._primaryBuffer = other._specialBuffer = nullptr;
other.setAllocFlags(false, false);
other._lenInBytes = 0;
return *this;
}
////////////////////////////////////////////////////////////////////////


@ -28,7 +28,7 @@
#include <op_boilerplate.h>
#include <dll.h>
#include <Environment.h>
#include <ArrayOptions.h>
#include <templatemath.h>
#include <shape.h>
#include <helpers/logger.h>
@ -62,7 +62,7 @@ namespace nd4j {
template <typename T>
FORCEINLINE static _CUDA_HD T nanOrZero();
// returns the difference between 1.0 and the next representable value of the given floating-point type
template <typename T>
FORCEINLINE static T eps();
@ -94,13 +94,13 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////////
///// IMPLEMENTATION OF INLINE METHODS /////
//////////////////////////////////////////////////////////////////////////
FORCEINLINE nd4j::DataType DataTypeUtils::pickFloatingType(nd4j::DataType typeX) {
// if proposed dataType is already floating point - return it
if (isR(typeX))
return typeX;
return Environment::getInstance()->defaultFloatDataType();
}
@ -213,13 +213,13 @@ FORCEINLINE _CUDA_HD uint32_t DataTypeUtils::min<uint32_t>() {
}
template<>
FORCEINLINE _CUDA_HD float DataTypeUtils::min<float>() {
return 1.175494e-38;
}
template<>
FORCEINLINE _CUDA_HD float16 DataTypeUtils::min<float16>() {
return (float16) 6.1035e-05;
}
template<>
@ -228,8 +228,8 @@ FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::min<bfloat16>() {
}
template<>
FORCEINLINE _CUDA_HD double DataTypeUtils::min<double>() {
return 2.2250738585072014e-308;
}
///////////////////////////////////////////////////////////////////
@ -280,17 +280,17 @@ FORCEINLINE _CUDA_HD Nd4jULong DataTypeUtils::max<Nd4jULong>() {
}
template <>
FORCEINLINE _CUDA_HD float DataTypeUtils::max<float>() {
return 3.402823e+38;
}
template <>
FORCEINLINE _CUDA_HD double DataTypeUtils::max<double>() {
return 1.7976931348623157E308;
}
template <>
FORCEINLINE _CUDA_HD float16 DataTypeUtils::max<float16>() {
return static_cast<float16>(65504.f);
}
@ -335,6 +335,8 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
return std::string("INT8");
case INT16:
return std::string("INT16");
case UINT16:
return std::string("UINT16");
case INT32:
return std::string("INT32");
case INT64:
@ -361,7 +363,7 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
template <typename T>
FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo) {
for (int e = 0; e < shape::shapeInfoLength(originalShapeInfo); e++) {
if (originalShapeInfo[e] < static_cast<Nd4jLong>(DataTypeUtils::max<T>())) {
newShapeInfo[e] = static_cast<T>(originalShapeInfo[e]);
@ -373,9 +375,9 @@ FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo,
}
///////////////////////////////////////////////////////////////////
// returns the difference between 1.0 and the next representable value of the given floating-point type
template <typename T>
FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
if (std::is_same<T, double>::value)
return std::numeric_limits<double>::epsilon();
else if (std::is_same<T, float>::value)
@ -406,7 +408,7 @@ FORCEINLINE T DataTypeUtils::eps() {
case nd4j::DataType::FLOAT8:
case nd4j::DataType::QINT8:
case nd4j::DataType::BOOL: return (size_t) 1;
case nd4j::DataType::BFLOAT16:
case nd4j::DataType::HALF:
case nd4j::DataType::INT16:


@ -26,6 +26,7 @@
#include <vector>
#include <array/DataType.h>
#include <pointercast.h>
#include <stdlib.h>
namespace nd4j {
class ND4J_EXPORT ExtraArguments {


@ -35,21 +35,21 @@ namespace nd4j {
TadPack() = default;
~TadPack() = default;
Nd4jLong* primaryShapeInfo() const;
Nd4jLong* primaryOffsets() const;
Nd4jLong* specialShapeInfo() const;
Nd4jLong* specialOffsets() const;
Nd4jLong numberOfTads() const;
int shapeInfoLength() const;
/**
* These methods return either primary or special pointers depending on platform binaries were compiled for
* @return
*/
Nd4jLong *platformShapeInfo() const;
Nd4jLong *platformOffsets() const;
};
}


@ -28,19 +28,19 @@ namespace nd4j {
_sizeOf = sizeOf;
}
Nd4jPointer ConstantDataBuffer::primary() const {
return _primaryBuffer;
}
Nd4jPointer ConstantDataBuffer::special() const {
return _specialBuffer;
}
Nd4jLong ConstantDataBuffer::sizeOf() const {
return _sizeOf;
}
Nd4jLong ConstantDataBuffer::length() const {
return _length;
}
return _length;
}


@ -54,7 +54,7 @@ namespace nd4j {
NDArray* NDArrayList::readRaw(int idx) {
if (_chunks.count(idx) < 1) {
nd4j_printf("Non-existent chunk requested: [%i]\n", idx);
throw std::invalid_argument("Bad index");
}
return _chunks[idx];
@ -120,7 +120,7 @@ namespace nd4j {
// storing reference
_chunks[idx] = array;
return Status::OK();
}
std::vector<Nd4jLong>& NDArrayList::shape() {
@ -152,8 +152,10 @@ namespace nd4j {
std::vector<bool> bargs;
int numElements = _elements.load();
for (int e = 0; e < numElements; e++) {
_chunks[e]->syncToDevice();
inputs.emplace_back(_chunks[e]);
}
iargs.push_back(_axis);


@ -29,34 +29,34 @@ namespace nd4j {
_numTads = numTads;
}
Nd4jLong* TadPack::primaryShapeInfo() const {
return reinterpret_cast<Nd4jLong *>(_tadShape.primary());
}
Nd4jLong* TadPack::primaryOffsets() const {
return reinterpret_cast<Nd4jLong *>(_tadOffsets.primary());
}
Nd4jLong* TadPack::specialShapeInfo() const {
return reinterpret_cast<Nd4jLong *>(_tadShape.special());
}
Nd4jLong* TadPack::specialOffsets() const {
return reinterpret_cast<Nd4jLong *>(_tadOffsets.special());
}
Nd4jLong TadPack::numberOfTads() const {
return _numTads;
}
Nd4jLong* TadPack::platformShapeInfo() const {
return nd4j::Environment::getInstance()->isCPU() ? primaryShapeInfo() : specialShapeInfo();
}
Nd4jLong* TadPack::platformOffsets() const {
return nd4j::Environment::getInstance()->isCPU() ? primaryOffsets() : specialOffsets();
}
int TadPack::shapeInfoLength() const {
return (int) shape::shapeInfoLength(primaryShapeInfo());
}
}


@ -27,7 +27,7 @@ namespace nd4j {
class AttentionHelper {
public:
static nd4j::NDArray multiHeadProject(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static void multiHeadProjectBp(const nd4j::NDArray* input, const nd4j::NDArray* projectionMatrix, const nd4j::NDArray* eps, nd4j::NDArray* dLdInput, nd4j::NDArray* dLdProjectionMatrix, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
};
}


@ -69,10 +69,10 @@ namespace nd4j {
}
void executeOnce() override {
auto xT = (_tA ? _x->transpose() : *_x);
auto yT = (_tB ? _y->transpose() : *_y);
MmulHelper::mmul(&xT, &yT, _z, _alpha, _beta);
}
std::string axis() override {


@ -39,31 +39,31 @@ NDArray Householder<T>::evalHHmatrix(const NDArray& x) {
T coeff;
T normX = x.reduceNumber(reduce::Norm2).e<T>(0);
if(normX*normX - x.e<T>(0) * x.e<T>(0) <= DataTypeUtils::min<T>() || x.lengthOf() == 1) {
normX = x.e<T>(0);
coeff = 0.f;
w = 0.f;
}
else {
if(x.e<T>(0) >= (T)0.f)
normX = -normX; // choose opposite sign to lessen roundoff error
T u0 = x.e<T>(0) - normX;
coeff = -u0 / normX;
w.assign(x / u0);
}
w.p(Nd4jLong(0), 1.f);
wT.assign(&w);
auto identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext());
identity.setIdentity(); // identity matrix
return identity - mmul(w, wT) * coeff;
}
@ -79,7 +79,7 @@ void Householder<T>::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff,
throw std::runtime_error("ops::helpers::Householder::evalHHmatrixData method: input tail vector must have length less than unity compared to input x vector!");
normX = x.reduceNumber(reduce::Norm2, nullptr).e<T>(0);
if(normX*normX - x.e<T>(0) * x.e<T>(0) <= DataTypeUtils::min<T>() || x.lengthOf() == 1) {
normX = x.e<T>(0);
@ -87,18 +87,18 @@ void Householder<T>::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff,
tail = (T)0.f;
}
else {
if(x.e<T>(0) >= (T)0.f)
normX = -normX; // choose opposite sign to lessen roundoff error
T u0 = x.e<T>(0) - normX;
coeff = -u0 / normX;
if(x.isRowVector())
tail.assign(x({0,0, 1,-1}) / u0);
else
tail.assign(x({1,-1, 0,0,}) / u0);
}
}
//////////////////////////////////////////////////////////////////////////
@ -107,20 +107,20 @@ void Householder<T>::evalHHmatrixDataI(const NDArray& x, T& coeff, T& normX) {
int rows = (int)x.lengthOf()-1;
int num = 1;
if(rows == 0) {
rows = 1;
num = 0;
}
auto tail = NDArrayFactory::create(x.ordering(), {rows, 1}, x.dataType(), x.getContext());
evalHHmatrixData(x, tail, coeff, normX);
if(x.isRowVector()) {
auto temp = x({0,0, num, x.sizeAt(1)}, true);
temp.assign(tail);
}
else {
auto temp = x({num,x.sizeAt(0), 0,0}, true);
temp.assign(tail);
}
@ -129,14 +129,14 @@ void Householder<T>::evalHHmatrixDataI(const NDArray& x, T& coeff, T& normX) {
//////////////////////////////////////////////////////////////////////////
template <typename T>
void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff) {
// if(matrix.rankOf() != 2)
//     throw "ops::helpers::Householder::mulLeft method: input array must be 2D matrix !";
if(matrix.sizeAt(0) == 1) {
matrix *= (T) 1.f - coeff;
}
else if(coeff != (T)0.f) {
auto bottomPart = new NDArray(matrix({1,matrix.sizeAt(0), 0,0}, true));
auto bottomPartCopy = *bottomPart;
@ -145,26 +145,22 @@ void Householder<T>::mulLeft(NDArray& matrix, const NDArray& tail, const T coeff
auto column = tail;
auto row = tail.transpose();
auto resultingRow = mmul(row, bottomPartCopy);
auto fistRow = matrix({0,1, 0,0}, true);
resultingRow += fistRow;
fistRow -= resultingRow * coeff;
*bottomPart -= mmul(column, resultingRow) * coeff;
}
else {
auto row = tail;
auto column = tail.transpose();
auto resultingRow = mmul(row, bottomPartCopy);
auto fistRow = matrix({0,1, 0,0}, true);
resultingRow += fistRow;
fistRow -= resultingRow * coeff;
*bottomPart -= mmul(column, resultingRow) * coeff;
}
delete bottomPart;
}
}
@ -176,10 +172,10 @@ void Householder<T>::mulRight(NDArray& matrix, const NDArray& tail, const T coef
// if(matrix.rankOf() != 2)
// throw "ops::helpers::Householder::mulRight method: input array must be 2D matrix !";
if(matrix.sizeAt(1) == 1)
matrix *= (T)1.f - coeff;
else if(coeff != (T)0.f) {
auto rightPart = new NDArray(matrix({0,0, 1,matrix.sizeAt(1)}, true));
@ -191,30 +187,25 @@ void Householder<T>::mulRight(NDArray& matrix, const NDArray& tail, const T coef
auto column = tail;
auto row = tail.transpose();
auto resultingCol = mmul(rightPartCopy, column);
resultingCol += *fistCol;
*fistCol -= resultingCol * coeff;
*rightPart -= mmul(resultingCol, row) * coeff;
}
else {
auto row = tail;
auto column = tail.transpose();
auto resultingCol = mmul(rightPartCopy, column);
resultingCol += *fistCol;
*fistCol -= resultingCol * coeff;
*rightPart -= mmul(resultingCol, row) * coeff;
}
delete rightPart;
delete fistCol;
}
}
template class ND4J_EXPORT Householder<float>;
template class ND4J_EXPORT Householder<float16>;
template class ND4J_EXPORT Householder<bfloat16>;


@ -157,8 +157,7 @@ bool JacobiSVD<T>::isBlock2x2NotDiag(NDArray& block, int p, int q, T& maxElem) {
if(_calcU) {
auto temp2 = rotation.transpose();
mulRotationOnRight(p, q, _u, temp2);
}
}
@ -251,9 +250,7 @@ void JacobiSVD<T>::svd2x2(const NDArray& block, int p, int q, NDArray& left, NDA
m.p<T>(1, 1, _z);
auto temp = right.transpose();
left.assign(mmul(rotation, temp));
}
@ -289,7 +286,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
else if(_rows < _cols) {
auto matrixT = matrix.transpose();
HHcolPivQR qr(matrixT / scale);
_m.assign(qr._qr({0,_rows, 0,_rows}));
_m.fillAsTriangular<T>(0., 0, 0, 'l');
_m.transposei();
@ -305,8 +302,6 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
if(_calcU)
_u.assign(qr._permut);
}
else {
@ -352,8 +347,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
if(_calcU) {
auto temp = rotLeft.transpose();
mulRotationOnRight(p, q, _u, temp);
}
mulRotationOnRight(p, q, _m, rotRight);


@ -920,7 +920,7 @@ void SVD<T>::evalData(const NDArray& matrix) {
auto temp1 = biDiag._HHbidiag.transpose();
auto temp2 = _m({0,_diagSize, 0,0}, true);
temp2.assign(temp1);
auto temp3 = _m({_m.sizeAt(0)-1,_m.sizeAt(0), 0,0}, true);
temp3.assign(0.);


@ -184,9 +184,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
if(pC->ordering() != 'f') {
auto temp = pA;
pA = new NDArray(pB ->permute({1,0}));
pB = new NDArray(temp->permute({1,0}));
pC = new NDArray(pC ->permute({1,0}));
toDelete.push_back(pA);
toDelete.push_back(pB);
toDelete.push_back(pC);
@ -251,7 +251,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
blocksPerGrid.y = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.y); // rows
}
//BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES)
}
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);
@ -339,7 +340,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
threadsPerBlock.x = 512;
blocksPerGrid.x = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.x); // rows
}
//BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES)
}
if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
@ -396,7 +398,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
NDArray::prepareSpecialUse({Z}, {X, Y});
//BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES)
auto cudaResult = cudaStreamSynchronize(*stream);
if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
@ -406,8 +409,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
return Z;
}
//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
}


@ -28,33 +28,27 @@
namespace nd4j {
nd4j::NDArray AttentionHelper::multiHeadProject(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix, nd4j::LaunchContext * context) {
auto miniBatchSize = input->sizeAt(0);
auto seqLength = input->sizeAt(2);
auto numHeads = projectionMatrix->sizeAt(0);
auto projectedSize = projectionMatrix->sizeAt(1);
auto inputPerm = input->permute({1, 0, 2});
auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);
nd4j::ops::matmul mmul;
mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
projected.permutei({2, 0, 1, 3});
return projected;
}
void AttentionHelper::multiHeadProjectBp(const nd4j::NDArray *input, const nd4j::NDArray *projectionMatrix,
const nd4j::NDArray *eps, nd4j::NDArray *dLdInput,
nd4j::NDArray *dLdProjectionMatrix, nd4j::LaunchContext * context) {
auto miniBatchSize = input->sizeAt(0);
@ -63,16 +57,16 @@ namespace nd4j {
auto projectedSize = projectionMatrix->sizeAt(1);
auto epsPerm = eps->permute({1, 2, 0, 3});
auto epsReshaped = epsPerm.reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength});
auto inputPerm = input->permute({1, 0, 2});
auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
nd4j::ops::matmul_bp mmulBp;
NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context);
NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context);
mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, {&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});
dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
dLdProjectionMatrix->assign(dLdProjectionPrep);
@ -80,12 +74,6 @@ namespace nd4j {
dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength});
dLdInputPrep.permutei({1, 0, 2});
dLdInput->assign(dLdInputPrep);
}
}


@ -29,13 +29,13 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
const int numInGradArrs = gradArrs.size();
// fill input gradient arrays in accordance with the type of loss function
switch(loss) {
case MEAN:
PRAGMA_OMP_PARALLEL_FOR_IF(numInGradArrs > 1)
for(int i = 0; i < numInGradArrs; ++i)
*gradArrs[i] = 1. / gradArrs[i]->lengthOf();
break;
case SUM:
@ -43,9 +43,9 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
for(int i = 0; i < numInGradArrs; ++i)
*gradArrs[i] = 1.;
break;
default:
throw std::invalid_argument("GradCheck::fillGradArrays: invalid type of loss function !");
}
}
@ -53,7 +53,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector<NDArray*>&
bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
const std::vector<bool>& whatArrsToCheck, const std::vector<double>& idxRange, const LossFunc loss ) {
const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP
const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP
const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();
@ -61,10 +61,11 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
// fill input gradient arrays in accordance to type of loss function
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
// back prop pass
ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
NDArray tmpScalar(nd4j::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
for(int i = 0; i < numInArrsFF; ++i) { // loop through input array
if(!whatArrsToCheck.empty() && static_cast<bool>(whatArrsToCheck[i]) == false)
@ -72,42 +73,42 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
const Nd4jLong idxStart = static_cast<Nd4jLong>(idxRange[0] * inArrsFF[i]->lengthOf());
const Nd4jLong idxEnd = static_cast<Nd4jLong>(idxRange[1] * inArrsFF[i]->lengthOf());
for(Nd4jLong j = idxStart; j < idxEnd; ++j) { // loop through all elements for current array
const double orig = inArrsFF[i]->e<double>(j);
// add epsilon, feed forward
inArrsFF[i]->p<double>(j, orig + EPSILON);
ResultSet* outArrsFF = opFF.execute(argsHolderFF);
int numOutArrs = outArrsFF->size();
double scorePlus = 0.;
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM)
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), reduce::Mean, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
scorePlus += tmpScalar.e<double>(0);
}
delete outArrsFF;
// subtract epsilon, feed forward
elem = orig - EPSILON;
inArrsFF[i]->p<double>(j, orig - EPSILON);
outArrsFF = opFF.execute(argsHolderFF);
double scoreMinus = 0.;
for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays
if(loss == SUM)
NativeOpExecutioner::execReduceSameScalar(LaunchContext::defaultContext(), reduce::Sum, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
outArrsFF->at(k)->reduceNumber(reduce::Sum, tmpScalar);
else
NativeOpExecutioner::execReduceFloatScalar(LaunchContext::defaultContext(), reduce::Mean, outArrsFF->at(k)->getBuffer(), outArrsFF->at(k)->getShapeInfo(), outArrsFF->at(k)->getSpecialBuffer(), outArrsFF->at(k)->getSpecialShapeInfo(), nullptr, tmpScalar.buffer(), tmpScalar.shapeInfo(), tmpScalar.specialBuffer(), tmpScalar.specialShapeInfo());
outArrsFF->at(k)->reduceNumber(reduce::Mean, tmpScalar);
scoreMinus += tmpScalar.e<double>(0);
}
delete outArrsFF;
// restore initial element value
elem = orig;
inArrsFF[i]->p<double>(j, orig);
// calculate numerical gradient
const double numericalGrad = (scorePlus - scoreMinus) / (2 * EPSILON);
@ -116,7 +117,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
throw std::runtime_error("");
}
// get analytical gradient
const double analyticGrad = outArrsBP->at(i)->e<double>(j);
if(std::isnan(analyticGrad) || std::isinf(analyticGrad)) {
printf("GradCheck::checkGrad: got wrong value for analytical gradient for input array # %i and its element at position %lld ! \n", i, j);
@ -124,13 +125,13 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
}
// printf("num = %.5f, ana = %.5f\n", numericalGrad, analyticGrad);
// calculate relative error
double relError;
if(numericalGrad == 0. && analyticGrad == 0.)
relError = 0.;
else
relError = math::nd4j_abs<double>(analyticGrad - numericalGrad) / (math::nd4j_abs<double>(analyticGrad) + math::nd4j_abs<double>(numericalGrad));
// verify result
if(relError > MAXRELERR || std::isnan(relError)) {
@ -144,7 +145,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
}
}
}
delete outArrsBP;
return true;
}

View File

@ -39,26 +39,23 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* A, const nd4j::N
nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector<int>& axes_0, const std::vector<int>& axes_1) {
std::vector<int> permutAt, permutBt;
std::vector<Nd4jLong> shapeAt, shapeBt;
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
NDArray* aPR = a->permute(permutAt);
NDArray* bPR = b->permute(permutBt);
// check whether reshape is necessary
if(!aPR->isSameShape(shapeAt))
aPR->reshapei( shapeAt);
if(!bPR->isSameShape(shapeBt))
bPR->reshapei( shapeBt);
NDArray aPR = a->permute(permutAt);
NDArray bPR = b->permute(permutBt);
NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
// check whether reshape is necessary
if(!aPR.isSameShape(shapeAt))
aPR.reshapei( shapeAt);
if(!bPR.isSameShape(shapeBt))
bPR.reshapei( shapeBt);
NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
c->reshapei(outShape);
delete aPR;
delete bPR;
return c;
}
@ -74,65 +71,67 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
// check whether permutation is required
if(!permutForC.empty())
cP = c->permute(permutForC);
cP = new NDArray(c->permute(permutForC));
auto aPR = a->permute(permutAt);
auto bPR = b->permute(permutBt);
// check whether reshape is necessary
if(!aPR->isSameShape(shapeAt))
aPR->reshapei(shapeAt);
if(!bPR->isSameShape(shapeBt))
bPR->reshapei(shapeBt);
if(!aPR.isSameShape(shapeAt))
aPR.reshapei(shapeAt);
if(!bPR.isSameShape(shapeBt))
bPR.reshapei(shapeBt);
if(!cP->isSameShape({aPR->sizeAt(0), bPR->sizeAt(1)}))
cPR = cP->reshape(cP->ordering(), {aPR->sizeAt(0), bPR->sizeAt(1)});
if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
mmul(aPR, bPR, cPR, 1.0, 0.0);
mmul(&aPR, &bPR, cPR, 1.0, 0.0);
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
cP->assign(cPR);
if(cPR != c)
delete cPR;
if(cP != c)
delete cP;
delete aPR;
delete bPR;
}
#ifndef __JAVACPP_HACK__
//////////////////////////////////////////////////////////////////////////
void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB, const std::vector<std::vector<Nd4jLong>>& modifC) {
NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
std::string whatToDoWithA, whatToDoWithB, whatToDoWithC; // "" - nothing; "p" - permutation; "r" - reshaping; "pr" - permutation+reshaping; "rp" - reshaping/permutation, and so on; if another string is produced - throw exception
for(const auto& arr : modifA)
whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
for(const auto& arr : modifB)
whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";
for(const auto& arr : modifC)
whatToDoWithC = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithC + "p" : whatToDoWithC + "r";
// first step for a array
if(!whatToDoWithA.empty())
aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]);
aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0]));
// first step for b array
if(!whatToDoWithB.empty())
bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]);
bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0]));
// rest steps for a array
for(int i = 1; i < whatToDoWithA.size(); ++i)
if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
// rest steps for b array
for(int i = 1; i < whatToDoWithB.size(); ++i)
if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]);
// now work with c array
std::vector<NDArray*> cArrs = {c};
if(!whatToDoWithC.empty()) {
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
for(int i = 0; i < cArrs.size()-1; ++i)
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? cArrs[i]->permute(modifC[i]) : cArrs[i]->reshape(c->ordering(), modifC[i]); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
for(int i = 0; i < cArrs.size()-1; ++i)
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
}
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
// check whether new buffer allocation was happened for c array
@ -152,27 +151,30 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
//////////////////////////////////////////////////////////////////////////
NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, const std::vector<std::vector<Nd4jLong>>& modifA, const std::vector<std::vector<Nd4jLong>>& modifB) {
NDArray *aPR(const_cast<NDArray*>(a)), *bPR(const_cast<NDArray*>(b));
std::string whatToDoWithA, whatToDoWithB; // "" - nothing; "p" - permutation only; "r" - reshaping only; "pr" - permutation+reshaping; "rp" - reshaping/permutation; another string - throw exception
for(const auto& arr : modifA)
whatToDoWithA = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithA + "p" : whatToDoWithA + "r"; // when 0 is present in arr then it is permutation array, otherwise - it is reshaping array
for(const auto& arr : modifB)
whatToDoWithB = (std::find(arr.begin(), arr.end(), 0) != arr.end()) ? whatToDoWithB + "p" : whatToDoWithB + "r";
// first step for a array
if(!whatToDoWithA.empty())
aPR = (whatToDoWithA[0] == 'p') ? a->permute(modifA[0]) : a->reshape(a->ordering(), modifA[0]);
aPR = (whatToDoWithA[0] == 'p') ? new NDArray(a->permute(modifA[0])) : new NDArray(a->reshape(a->ordering(), modifA[0]));
// first step for b array
if(!whatToDoWithB.empty())
bPR = (whatToDoWithB[0] == 'p') ? b->permute(modifB[0]) : b->reshape(b->ordering(), modifB[0]);
bPR = (whatToDoWithB[0] == 'p') ? new NDArray(b->permute(modifB[0])) : new NDArray(b->reshape(b->ordering(), modifB[0]));
// rest steps for a array
for(int i = 1; i < whatToDoWithA.size(); ++i)
if(whatToDoWithA[i] == 'p') aPR->permutei(modifA[i]); else aPR->reshapei(modifA[i]);
// rest steps for b array
for(int i = 1; i < whatToDoWithB.size(); ++i)
if(whatToDoWithB[i] == 'p') bPR->permutei(modifB[i]); else bPR->reshapei(modifB[i]);
NDArray* result = mmul(aPR, bPR, nullptr, 1.0, 0.0);
if(aPR != a)
delete aPR;
if(bPR != b)
@ -281,9 +283,9 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
nd4j_printf("NDArrayFactory::matmul static method: input shape of output array is wrong, actual is %s and expected is %s ! \n", ShapeUtils::shapeAsString(z).c_str(), ShapeUtils::shapeAsString(outShape).c_str());
throw std::invalid_argument("");
}
NDArray* xT(const_cast<NDArray*>(x)), *yT(const_cast<NDArray*>(y)), *zT(z);
if((transX && xRank > 1) || (transY && yRank > 1)) {
const int rank = xRank >= yRank ? xRank : yRank;
std::vector<int> permut(rank);
@ -291,25 +293,25 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
permut[i] = i;
permut[rank-2] = rank - 1;
permut[rank-1] = rank - 2;
if(transX)
xT = x->permute(permut);
xT = new NDArray(x->permute(permut));
if(transY)
yT = y->permute(permut);
yT = new NDArray(y->permute(permut));
}
if(xRank <= 2 && yRank <= 2) { // dot (1Dx1D), vector-matrix (1Dx2D), matrix-vector (2Dx1D), matrix-matrix (2Dx2D) product cases
if(xRank == 1 && yRank == 2) { // reduce vector-matrix to matrix-matrix case
xT = x->reshape(x->ordering(), {1, x->lengthOf()}); // please note x is not transposed in this case (since xRank=1)
zT = z->reshape(z->ordering(), {1, z->lengthOf()});
xT = new NDArray(x->reshape(x->ordering(), {1, x->lengthOf()})); // please note x is not transposed in this case (since xRank=1)
zT = new NDArray(z->reshape(z->ordering(), {1, z->lengthOf()}));
}
mmul(xT, yT, zT, 1., 0.);
}
else { // rest cases - batched mmul
const int batchRank = xRank - 2;
std::vector<int> dimsToExclude(batchRank);
for(int i = 0; i < batchRank; ++i)
@ -340,4 +342,4 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
}
#endif
#endif

View File

@ -473,19 +473,9 @@ bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool
// FIXME: get rid of memcpy here
memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank));
for (int i = 0; i < minRank; ++i)
if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i])
if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0)
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
// nullify zero axis
for (int e = 0; e < maxRank; e++)
if (maxShapeInfo[e+1] == 0)
tmpShapeInfo[e+1] = 0;
int delta = maxRank - minRank;
for (int e = minRank - 1; e >= 0; e--)
if (minShapeInfo[e + 1] == 0)
tmpShapeInfo[e + 1 + delta] = 0;
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
if (shape::isEmpty(max) || shape::isEmpty(min)) {

View File

@ -40,7 +40,7 @@ namespace nd4j {
#ifdef __CUDACC__
__host__
#endif
void Logger::printv(const char *format, std::vector<int>& vec) {
void Logger::printv(const char *format, const std::vector<int>& vec) {
printf("%s: {", format);
for(int e = 0; e < vec.size(); e++) {
auto v = vec[e];
@ -55,7 +55,7 @@ namespace nd4j {
#ifdef __CUDACC__
__host__
#endif
void Logger::printv(const char *format, std::vector<Nd4jLong>& vec) {
void Logger::printv(const char *format, const std::vector<Nd4jLong>& vec) {
printf("%s: {", format);
for(int e = 0; e < vec.size(); e++) {
auto v = vec[e];

View File

@ -55,8 +55,8 @@ namespace nd4j {
static void _CUDA_H info(const char *format, ...);
static void _CUDA_H printv(const char *format, std::vector<int>& vec);
static void _CUDA_H printv(const char *format, std::vector<Nd4jLong>& vec);
static void _CUDA_H printv(const char *format, const std::vector<int>& vec);
static void _CUDA_H printv(const char *format, const std::vector<Nd4jLong>& vec);
};
}

View File

@ -1023,23 +1023,6 @@ namespace shape {
*/
ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
/**
* insert dimension at shape[axis] position
* 1) for example: for given rank = 3, shape = {2,4,5}, axis = 1, dimension = 10 result is -> shape = {2,10,4,5}
* 2) for example: for given rank = 3, shape = {2,4,5}, axis = 3, dimension = 10 result is -> shape = {2,4,5,10}
* so be careful and provide shape buffer with enough (at least rank+1) length
* axis should be within [0, rank] range
*/
ND4J_EXPORT _CUDA_HD void insertDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis, const Nd4jLong dimension);
/**
* erase dimension at shape[axis] position
* 1) for example: for given rank = 3, shape = {2,4,5}, axis = 1, result is -> shape = {2,5}
* 2) for example: for given rank = 3, shape = {2,4,5}, axis = 2, result is -> shape = {2,4}
* axis should be within [0, rank-1] range
*/
ND4J_EXPORT _CUDA_HD void eraseDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis);
@ -4932,21 +4915,6 @@ INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffs
}
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD void insertDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis, const Nd4jLong dimension) {
for (int i = rank; i > axis; --i)
shape[i] = shape[i - 1];
shape[axis] = dimension;
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD void eraseDimension(const int rank, Nd4jLong *shape, const Nd4jLong axis) {
for (int i = axis; i < rank - 1; ++i)
shape[i] = shape[i + 1];
}
}

View File

@ -244,8 +244,9 @@ namespace functions {
auto xi = x + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
for (Nd4jLong i = 0; i < ulen; i++)
for (Nd4jLong i = 0; i < ulen; i++) {
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
}
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);

View File

@ -122,7 +122,7 @@ namespace functions {
tadLength = shape::length(tadOnlyShapeInfo);
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
numTads = shape::length(xShapeInfo) / tadLength;
numTads = shape::length(yShapeInfo) / tadLength;
xEWS = shape::elementWiseStride(xShapeInfo);
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
}

View File

@ -21,12 +21,165 @@
#include <ops/specials_cuda.h>
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int half = window>>1;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0) {
xLength = shape::length(xShapeInfo);
}
__syncthreads();
//for (int i = 0; i < length; i+= window)
/*
if window == 4;
iterations will be: 0; 4; 8; 12; 16; 20
if gridDim = 3;
on first iteration we'll have: 0; 4; 8;
on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
*/
int firstPosition;
int firstStep;
int secondPosition;
int secondStep;
int WARP_SIZE = 32;
int numWarps = (gridDim.x * blockDim.x) / 32;
int warpId = tid / WARP_SIZE;
int warpIdx = tid % WARP_SIZE;
if (half >= 128) {
firstPosition = blockIdx.x * window;
firstStep = gridDim.x * window;
secondPosition = threadIdx.x;
secondStep = blockDim.x;
} else if (half >= 32) {
firstPosition = warpId * window;
firstStep = numWarps * window;
secondPosition = warpIdx;
secondStep = WARP_SIZE;
} else {
firstPosition = tid * window;
firstStep = blockDim.x * gridDim.x * window;
secondPosition = 0;
secondStep = 1;
}
for (int i = firstPosition; i < length; i += firstStep) {
for (int j = secondPosition; j < half; j += secondStep) {
int it = (reverse) ? i + j + half : i + window - j - 1;
int ij = i+j;
if (it < length && ij < length ) {
int posIT = shape::getIndexOffset(it, yShapeInfo, xLength);
int posIJ = shape::getIndexOffset(ij, yShapeInfo, xLength);
Y v0 = y[posIJ];
Y v1 = y[posIT];
if(!descending == (v0 > v1)) {
y[posIJ] = v1;
y[posIT] = v0;
X xtemp = x[posIJ];
x[posIJ] = x[posIT];
x[posIT] = xtemp;
}
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int half = window>>1;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0) {
xLength = shape::length(xShapeInfo);
}
__syncthreads();
//for (int i = 0; i < length; i+= window)
/*
if window == 4;
iterations will be: 0; 4; 8; 12; 16; 20
if gridDim = 3;
on first iteration we'll have: 0; 4; 8;
on second iteration we'll have: 0 + (3 * 4) = 12; 4 + (3 * 4) = 16; 8 + (3 * 4) = 20
*/
int firstPosition;
int firstStep;
int secondPosition;
int secondStep;
int WARP_SIZE = 32;
int numWarps = (gridDim.x * blockDim.x) / 32;
int warpId = tid / WARP_SIZE;
int warpIdx = tid % WARP_SIZE;
if (half >= 128) {
firstPosition = blockIdx.x * window;
firstStep = gridDim.x * window;
secondPosition = threadIdx.x;
secondStep = blockDim.x;
} else if (half >= 32) {
firstPosition = warpId * window;
firstStep = numWarps * window;
secondPosition = warpIdx;
secondStep = WARP_SIZE;
} else {
firstPosition = tid * window;
firstStep = blockDim.x * gridDim.x * window;
secondPosition = 0;
secondStep = 1;
}
for (int i = firstPosition; i < length; i += firstStep) {
for (int j = secondPosition; j < half; j += secondStep) {
int it = (reverse) ? i + j + half : i + window - j - 1;
int ij = i+j;
if (it < length && ij < length ) {
int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);
X v0 = x[posIJ];
X v1 = x[posIT];
if(!descending == (v0 > v1)) {
x[posIJ] = v1;
x[posIT] = v0;
Y ytemp = y[posIJ];
y[posIJ] = y[posIT];
y[posIT] = ytemp;
}
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
__device__
void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
__global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
auto x = static_cast<T*>(vx);
int tid = threadIdx.x + blockDim.x * blockIdx.x;
@ -85,8 +238,8 @@ void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int
int it = (reverse) ? i + j + half : i + window - j - 1;
int ij = i+j;
if (it < length && ij < length ) {
int posIT = getDevicePosition(xShapeInfo,it, xLength);
int posIJ = getDevicePosition(xShapeInfo, ij, xLength);
int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);
shmem[threadIdx.x] = x[posIJ];
shmem[threadIdx.x + blockDim.x] = x[posIT];
@ -100,18 +253,22 @@ void bitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int
}
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
bitonicArbitraryStepKernel<T>(vx, xShapeInfo, window, length, reverse, descending);
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
__host__ void bitonicArbitraryStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending) {
execBitonicArbitraryStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, window, length, reverse, descending);
nd4j::DebugHelper::checkErrorCode(stream, "bitonicArbitrary(...) failed");
}
template <typename X, typename Y>
__host__ void bitonicArbitraryStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
bitonicArbitraryStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}
template <typename X, typename Y>
__host__ void bitonicArbitraryStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending) {
bitonicArbitraryStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, window, length, reverse, descending);
}
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicArbitraryStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int window, int length, int reverse, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);

View File

@ -21,9 +21,119 @@
#include <ops/specials_cuda.h>
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicSortStepKernelValue(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
unsigned int i, ixj; /* Sorting partners: i and ixj */
i = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0)
xLength = shape::length(xShapeInfo);
__syncthreads();
if (i >= length)
return;
ixj = i^j;
/* The threads with the lowest ids sort the array. */
if ((ixj)>i) {
int posI = shape::getIndexOffset(i, yShapeInfo, xLength);
int posIXJ = shape::getIndexOffset(ixj, yShapeInfo, xLength);
if ((i&k)==0) {
/* Sort ascending */
if (!descending == (y[posI]>y[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
} else if ((i&k)!=0) {
/* Sort descending */
if (!descending == (y[posI]<y[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void bitonicSortStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
unsigned int i, ixj; /* Sorting partners: i and ixj */
i = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ Nd4jLong xLength;
if (threadIdx.x == 0)
xLength = shape::length(xShapeInfo);
__syncthreads();
if (i >= length)
return;
ixj = i^j;
/* The threads with the lowest ids sort the array. */
if ((ixj)>i) {
int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);
if ((i&k)==0) {
/* Sort ascending */
if (!descending == (x[posI]>x[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
} else if ((i&k)!=0) {
/* Sort descending */
if (!descending == (x[posI]<x[posIXJ])) {
/* exchange(i,ixj); */
X temp = x[posI];
x[posI] = x[posIXJ];
x[posIXJ] = temp;
Y ytemp = y[posI];
y[posI] = y[posIXJ];
y[posIXJ] = ytemp;
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
__device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
__global__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
auto x = static_cast<T*>(vx);
@ -44,8 +154,8 @@ __device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int
/* The threads with the lowest ids sort the array. */
if ((ixj)>i) {
int posI = getDevicePosition(xShapeInfo, i, xLength);
int posIXJ = getDevicePosition(xShapeInfo, ixj, xLength);
int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);
if ((i&k)==0) {
/* Sort ascending */
@ -69,16 +179,23 @@ __device__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int
//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ void execBitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
bitonicSortStepKernel<T>(vx, xShapeInfo, j, k, length, descending);
__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
bitonicSortStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, j, k, length, descending);
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
__host__ void bitonicSortStepGeneric(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending) {
execBitonicSortStepKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, j, k, length, descending);
nd4j::DebugHelper::checkErrorCode(stream, "bitonicSortStep(...) failed");
template <typename X, typename Y>
__host__ void bitonicSortStepGenericKey(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
bitonicSortStepKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__host__ void bitonicSortStepGenericValue(dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending) {
bitonicSortStepKernelValue<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, j, k, length, descending);
}
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT bitonicSortStepGenericValue, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int j, int k, int length, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);

View File

@ -16,18 +16,89 @@
//
// @author raver119@gmail.com
// @author Yurii Shyrma, created on 28.11.2018
//
#include <ops/specials_cuda.h>
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
__global__ void execOesTadKernelKey(void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) {
auto x = static_cast<X*>(vx);
auto y = static_cast<Y*>(vy);
__shared__ int xLength;
__shared__ int xTadLength;
__shared__ int numTads;
if (threadIdx.x == 0) {
xLength = shape::length(xShapeInfo);
xTadLength = shape::length(tadShapeInfo);
numTads = xLength / xTadLength;
}
__syncthreads();
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
auto dx = x + tadOffsets[r];
auto dy = y + tadOffsets[r];
// this is general loop, we go uncached
int iterations = xTadLength;
for (int i = 0; i < iterations; i++) {
if (i % 2 == 0) {
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto top = 2 * tid + 1;
if (top < xTadLength) {
auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);
if (!descending == (dx[t0] > dx[t1])) {
X dt0 = dx[t0];
dx[t0] = dx[t1];
dx[t1] = dt0;
Y dy0 = dy[t0];
dy[t0] = dy[t1];
dy[t1] = dy0;
}
}
}
} else {
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto top = 2 * tid + 2;
if (top < xTadLength) {
auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);
if (!descending == (dx[t0] > dx[t1])) {
X dt0 = dx[t0];
dx[t0] = dx[t1];
dx[t1] = dt0;
Y dy0 = dy[t0];
dy[t0] = dy[t1];
dy[t1] = dy0;
}
}
}
}
__syncthreads();
}
}
}
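`execOesTadKernelKey` above is an odd-even (brick) transposition sort over each TAD, carrying the values array `vy` along with the keys. A minimal single-array host sketch of the same phase structure (a hypothetical helper, not libnd4j API):

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Odd-even transposition sort: even phases compare pairs (0,1),(2,3),...,
// odd phases compare (1,2),(3,4),... After n phases the array is sorted.
// Pairs within a phase are disjoint, which maps directly onto one
// __syncthreads()-separated round per phase in the kernel above.
void oddEvenSort(std::vector<int>& a, bool descending = false) {
    const std::size_t n = a.size();
    for (std::size_t phase = 0; phase < n; ++phase) {
        for (std::size_t i = phase % 2; i + 1 < n; i += 2) {
            const bool outOfOrder = descending ? (a[i] < a[i + 1])
                                               : (a[i] > a[i + 1]);
            if (outOfOrder) std::swap(a[i], a[i + 1]);
        }
    }
}
```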
//////////////////////////////////////////////////////////////////////////
template<typename T>
__device__
void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) {
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) {
auto x = static_cast<T*>(vx);
const int sharedSize = 32768;
@ -56,7 +127,7 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
int iterations = xTadLength;
if (cached) {
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto t0 = getDevicePosition(tadShapeInfo, tid, xTadLength);
auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
shmem[tid] = dx[t0];
}
@ -70,8 +141,8 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto top = 2 * tid + 1;
if (top < xTadLength) {
auto t0 = cached ? top - 1 : getDevicePosition(tadShapeInfo, top - 1, xTadLength);
auto t1 = cached ? top : getDevicePosition(tadShapeInfo, top, xTadLength);
auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);
if (!descending == (dx[t0] > dx[t1])) {
T dt0 = dx[t0];
@ -84,8 +155,8 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto top = 2 * tid + 2;
if (top < xTadLength) {
auto t0 = cached ? top - 1 : getDevicePosition(tadShapeInfo, top - 1, xTadLength);
auto t1 = cached ? top : getDevicePosition(tadShapeInfo, top, xTadLength);
auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);
if (!descending == (dx[t0] > dx[t1])) {
T dt0 = dx[t0];
@ -102,32 +173,34 @@ void oesTadKernel(void *vx, Nd4jLong *xShapeInfo,
if (cached) {
dx = x + tadOffsets[r];
for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
auto t0 = getDevicePosition(tadShapeInfo, tid, xTadLength);
auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
dx[t0] = shmem[tid];
}
}
}
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) {
oesTadKernel<T>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
__host__ void oesTadGeneric(dim3 &launchDims, cudaStream_t *stream,
void *vx, Nd4jLong *xShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
void *vx, Nd4jLong *xShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) {
execOesTadKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
nd4j::DebugHelper::checkErrorCode(stream, "oesTad(...) failed");
}
template <typename X, typename Y>
__host__ void oesTadGenericKey(dim3 &launchDims, cudaStream_t *stream,
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool descending) {
execOesTadKernelKey<X,Y><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending);
}
BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT oesTadGeneric, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES);
BUILD_DOUBLE_TEMPLATE(template void ND4J_EXPORT oesTadGenericKey, (dim3 &launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool descending), LIBND4J_TYPES, LIBND4J_TYPES);


@ -37,7 +37,7 @@ CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) {
auto output = OUTPUT_VARIABLE(0);
std::vector<int> sharedAxes = *block.getIArguments();
const int inputRank = input->rankOf();
const int alphaRank = alpha->rankOf();
const int numSharedAxes = sharedAxes.size(); // can be zero as well
@ -49,12 +49,12 @@ CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) {
//***** input validation *****//
std::vector<Nd4jLong> expectedAlphaShape(&inputShape[1], &inputShape[inputRank]);
REQUIRE_TRUE(inputRank > 1, 0, "PRELU OP: wrong rank of input array, expected rank should be > 1, but got %i instead !", inputRank);
for(int i = 0; i < numSharedAxes; ++i) {
if(sharedAxes[i] <= 0)
sharedAxes[i] += inputRank - 1;
REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, "PRELU OP: wrong axis value %i in sharedAxes at position %i, axis value must be within range [1, input_rank-1] !", sharedAxes[i], i);
expectedAlphaShape[sharedAxes[i] - 1] = 1;
}
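The loop above derives the alpha shape PReLU expects: the input shape with the batch axis dropped and every shared axis set to 1. A standalone sketch of that bookkeeping (a hypothetical helper mirroring the validation logic, not the op's actual code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Expected alpha shape: input shape minus the leading batch axis, with each
// shared axis (1-based within the input rank, non-positive values wrapping)
// set to 1 so it broadcasts.
std::vector<int64_t> expectedAlphaShape(const std::vector<int64_t>& inputShape,
                                        std::vector<int> sharedAxes) {
    const int inputRank = static_cast<int>(inputShape.size());
    std::vector<int64_t> expected(inputShape.begin() + 1, inputShape.end());
    for (int axis : sharedAxes) {
        if (axis <= 0) axis += inputRank - 1; // wrap non-positive axes, as the op does
        expected[axis - 1] = 1;               // axes are counted from 1
    }
    return expected;
}
```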
@ -65,14 +65,8 @@ CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) {
REQUIRE_TRUE(product == alphaLen, 0, "PRELU OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
// ***** end of validation ***** //
if(alphaShape != expectedAlphaShape)
alpha = alpha->reshape(alpha->ordering(), expectedAlphaShape);
helpers::prelu(block.launchContext(), *input, alphaShape != expectedAlphaShape ? alpha->reshape(alpha->ordering(), expectedAlphaShape) : *alpha, *output);
helpers::prelu(block.launchContext(), *input, *alpha, *output);
if(alphaShape != expectedAlphaShape)
delete alpha;
return Status::OK();
}
@ -90,12 +84,12 @@ CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto alpha = INPUT_VARIABLE(1);
auto dLdO = INPUT_VARIABLE(2);
auto dLdI = OUTPUT_VARIABLE(0);
auto dLdA = OUTPUT_VARIABLE(1);
std::vector<int> sharedAxes = *block.getIArguments();
const int inputRank = input->rankOf();
const int alphaRank = alpha->rankOf();
const int numSharedAxes = sharedAxes.size(); // can be zero as well
@ -105,19 +99,19 @@ CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) {
const std::vector<Nd4jLong> alphaShape = alpha->getShapeAsVector();
//***** input validation *****//
// temporary limitation imposed by Yurii
REQUIRE_TRUE(inputRank <= MAX_RANK/2, 0, "rank of input array should be <= MAX_RANK/2, but got %i instead!", inputRank);
REQUIRE_TRUE(input->lengthOf() / alpha->lengthOf() <= MAX_RANK*2, 0, "the length of input array should be no more than MAX_RANK*2 times the alpha array length, but got %lld and %lld correspondingly!", input->lengthOf(), alpha->lengthOf());
std::vector<Nd4jLong> expectedAlphaShape(&inputShape[1], &inputShape[inputRank]);
REQUIRE_TRUE(inputRank > 1, 0, "PRELU_BP OP: wrong rank of input array, expected rank should be > 1, but got %i instead !", inputRank);
for(int i = 0; i < numSharedAxes; ++i) {
if(sharedAxes[i] <= 0)
sharedAxes[i] += inputRank - 1;
REQUIRE_TRUE(1 <= sharedAxes[i] && sharedAxes[i] <= inputRank - 1, 0, "PRELU_BP OP: wrong axis value %i in sharedAxes at position %i, axis value must be within range [1, input_rank-1] !", sharedAxes[i], i);
expectedAlphaShape[sharedAxes[i] - 1] = 1;
}
@ -127,19 +121,20 @@ CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) {
REQUIRE_TRUE(product == alphaLen, 0, "PRELU_BP OP: wrong shape of alpha array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedAlphaShape).c_str(), ShapeUtils::shapeAsString(alphaShape).c_str());
// ***** end of validation ***** //
if(alphaShape != expectedAlphaShape) {
alpha = alpha->reshape(alpha->ordering(), expectedAlphaShape);
dLdA = dLdA->reshape(dLdA->ordering(), expectedAlphaShape);
alpha = new NDArray(alpha->reshape(alpha->ordering(), expectedAlphaShape));
dLdA = new NDArray(dLdA->reshape(dLdA->ordering(), expectedAlphaShape));
}
helpers::preluBP(block.launchContext(), *input, *alpha, *dLdO, *dLdI, *dLdA);
if(alphaShape != expectedAlphaShape) {
delete alpha;
delete dLdA;
}
return Status::OK();
}


@ -29,7 +29,6 @@ namespace nd4j {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
nd4j_printf("Comparing [%f] to [%f]\n", x->e<float>(0), y->e<float>(0));
if (x->e<float>(0) < y->e<float>(0))
return ND4J_STATUS_TRUE;
else


@ -31,7 +31,7 @@ namespace nd4j {
auto condition = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
if (z->isEmpty())
return ND4J_STATUS_OK;
return Status::OK();
if (block.width() == 3) {
auto x = INPUT_VARIABLE(1);
@ -44,12 +44,10 @@ namespace nd4j {
// FIXME: for perf it might be better to issue memcpy here, and fill only mismatched values from either X or Y
for (int e = 0; e < condition->lengthOf(); e++) {
if (y->isR()) {
auto r = !condition->e<bool>(e) ? y->e<double>(e)
: x->e<double>(e);
auto r = !condition->e<bool>(e) ? y->e<double>(e) : x->e<double>(e);
z->p(e, r);
} else {
auto r = !condition->e<bool>(e) ? y->e<Nd4jLong>(e)
: x->e<Nd4jLong>(e);
auto r = !condition->e<bool>(e) ? y->e<Nd4jLong>(e) : x->e<Nd4jLong>(e);
z->p(e, r);
}
}
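The branchy loop above is elementwise select: for each index, take from `x` where the condition holds and from `y` otherwise. A plain C++ sketch of those semantics (a hypothetical function, not the op itself):

```cpp
#include <cassert>
#include <vector>

// Three-input Where: out[e] = cond[e] ? x[e] : y[e], applied elementwise.
std::vector<double> whereSelect(const std::vector<bool>& cond,
                                const std::vector<double>& x,
                                const std::vector<double>& y) {
    std::vector<double> out(cond.size());
    for (std::size_t e = 0; e < cond.size(); ++e)
        out[e] = cond[e] ? x[e] : y[e];
    return out;
}
```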
@ -86,7 +84,7 @@ namespace nd4j {
helpers::_where(block.launchContext(), *condition, *output, block.workspace());
}
return ND4J_STATUS_OK;
return Status::OK();
}
DECLARE_SHAPE_FN(Where) {


@ -120,7 +120,7 @@ namespace nd4j {
}
}
return ND4J_STATUS_OK;
return Status::OK();
}
DECLARE_SHAPE_FN(where_np) {


@ -81,11 +81,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
ConvolutionUtils::conv2d(block, inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
delete inputReshaped;
delete outputReshaped;
delete weightsReshaped;
ConvolutionUtils::conv2d(block, &inputReshaped, &weightsReshaped, bias, &outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
return Status::OK();
}
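conv1d above delegates to conv2d by inserting a unit height axis into both input and weights. The shape bookkeeping alone, as a sketch (function names are illustrative, not libnd4j API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// [bS, iW, iC] -> [bS, 1, iW, iC]: add a height-1 axis so conv2d applies.
std::vector<int64_t> liftConv1dInput(const std::vector<int64_t>& s) {
    return {s[0], 1, s[1], s[2]};
}

// [kW, iC, oC] -> [1, kW, iC, oC]: the kernel gets height 1 as well.
std::vector<int64_t> liftConv1dWeights(const std::vector<int64_t>& w) {
    return {1, w[0], w[1], w[2]};
}
```

The conv2d call then runs with kH = 1, sH = 1, pH = 0, so the height axis is a no-op.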
@ -217,13 +213,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
ConvolutionUtils::conv2dBP(block, inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
delete inputReshaped;
delete gradIReshaped;
delete gradOReshaped;
delete weightsReshaped;
delete gradWReshaped;
ConvolutionUtils::conv2dBP(block, &inputReshaped, &weightsReshaped, bias, &gradOReshaped, &gradIReshaped, &gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
return Status::OK();
}


@ -34,7 +34,7 @@ using namespace mkldnn;
#endif
CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
@ -42,7 +42,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));// filter(kernel) width
@ -151,10 +151,10 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
std::vector<int> permutForOutput;
if(!isNCDHW)
input = input->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
else
if (isNCDHW)
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
else
input = new NDArray(input->permute({0,4,1,2,3}));
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
@ -164,9 +164,9 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
if(bias)
output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
if(!isNCDHW)
delete input;
return Status::OK();
}
@ -202,36 +202,36 @@ DECLARE_SHAPE_FN(conv3dnew) {
const int rank = 5;
REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of input array must be equal to %i, but got %i instead !", rank, inputShapeInfo);
REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM CONV3D OP: rank of weights array must be equal to %i, but got %i instead !", rank, weightsShapeInfo);
int indIOioC, indIiD, indWoC(4);
if(!isNCDHW) {
indIOioC = 4; indIiD = 1;
}
else {
indIOioC = 1; indIiD = 2;
}
int bS = inputShapeInfo[1]; // batch size
int iD = inputShapeInfo[indIiD+1]; // input depth
int iH = inputShapeInfo[indIiD+2]; // input height
int iW = inputShapeInfo[indIiD+3]; // input width
int iC = inputShapeInfo[indIOioC+1]; // input channels
int oC = weightsShapeInfo[indWoC+1]; // output channels
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
int oD, oH, oW; // output depth, height, width
ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
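`calcOutSizePool3D` computes each output extent with the standard convolution formula. A per-axis sketch (assumed to match libnd4j's SAME/VALID behavior; this is not the helper's actual signature):

```cpp
#include <cassert>

// Output extent for one spatial axis of a convolution.
// SAME: ceil(in / stride). VALID: floor((in + 2*pad - effK) / stride) + 1,
// where effK = (k - 1) * dilation + 1 is the dilated kernel extent.
int convOutSize(int in, int k, int stride, int pad, int dilation, bool sameMode) {
    if (sameMode)
        return (in + stride - 1) / stride;
    const int effK = (k - 1) * dilation + 1;
    return (in + 2 * pad - effK) / stride + 1;
}
```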
Nd4jLong* outputShapeInfo = nullptr;
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong);
outputShapeInfo[0] = rank;
outputShapeInfo[1] = bS;
if (isNCDHW) {
outputShapeInfo[2] = oC;
outputShapeInfo[3] = oD;
outputShapeInfo[4] = oH;
@ -242,7 +242,7 @@ DECLARE_SHAPE_FN(conv3dnew) {
outputShapeInfo[4] = oW;
outputShapeInfo[5] = oC;
}
ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo));
return SHAPELIST(CONSTANT(outputShapeInfo));
@ -251,12 +251,12 @@ DECLARE_SHAPE_FN(conv3dnew) {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
@ -291,12 +291,12 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
#ifdef HAVE_MKLDNN
if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, weights, bias, gradO, gradI, gradW, gradB})) {
std::vector<nd4j::MKLDNNStream>& streams = block.getMKLDNNStreams();
@ -447,35 +447,37 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
std::vector<int> gradOaxesForDot;
if(!isNDHWC) {
input = input->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = gradI->permute({0,4,1,2,3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradOaxesForDot = {0,1,2,3}; // bS, oD, oH, oW
input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
}
else
else {
gradOaxesForDot = {0,2,3,4}; // bS, oD, oH, oW
}
// ----- calculation of gradW and gradB ----- //
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
if(gradB) {
if(gradB->rankOf() == 2)
gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()});
//----- calculation of gradO -----//
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot); // sum over bS oD oH oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;
}
//----- calculation of gradI -----//
MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
if(!isNDHWC) {
delete input;
delete gradI;
}
return Status::OK();
}
@ -520,15 +522,15 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
if(!isNDHWC) {
indIOioC = 4; indIiD = 1;
}
else {
indIOioC = 1; indIiD = 2;
}
int bS = inputShapeInfo[1]; // batch size
int iD = inputShapeInfo[indIiD+1]; // input depth
int iH = inputShapeInfo[indIiD+2]; // input height
int iW = inputShapeInfo[indIiD+3]; // input width
int iC = inputShapeInfo[indIOioC+1]; // input channels
int oC = weightsShapeInfo[indWoC+1]; // output channels
int trueoD, trueoH, trueoW; // true output depth/height/width
@ -538,7 +540,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
auto gradIshapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace());
@ -547,7 +549,7 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
if(biasShapeInfo) {
auto gradBshapeInfo = ShapeBuilders::copyShapeInfoAndType(biasShapeInfo, gradOShapeInfo, false, block.getWorkspace());
return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo), CONSTANT(gradBshapeInfo));
}
return SHAPELIST(CONSTANT(gradIshapeInfo), CONSTANT(gradWshapeInfo));
}


@ -33,7 +33,7 @@ namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
@ -64,7 +64,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(!isNCHW)
output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
@ -77,14 +77,14 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 1, 0, 4, 5});
LaunchContext* ctx = block.launchContext();
helpers::col2im(*ctx, columns, *output, sH, sW, pH, pW, oH, oW, dH, dW); // [bS, oC, kH, kW, iH, iW] is de-convoluted to [bS, oC, oH, oW]
//----- add biases if required -----//
if(bias)
output->applyBroadcast(broadcast::Add, {1}, bias);
if(!isNCHW)
delete output;
return Status::OK();
}
DECLARE_TYPES(deconv2d) {
@ -135,7 +135,7 @@ DECLARE_SHAPE_FN(deconv2d) {
int oH, oW; // output height, width
ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
Nd4jLong outputShape[4];
outputShape[0] = bS;
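`calcOutSizeDeconv2D` inverts the forward convolution formula. A per-axis sketch of the usual transposed-convolution arithmetic (assumed behavior, not the helper's actual signature):

```cpp
#include <cassert>

// Transposed-convolution (deconv) output extent for one axis.
// SAME: in * stride. VALID: (in - 1) * stride + (k - 1) * dilation + 1 - 2 * pad.
int deconvOutSize(int in, int k, int stride, int pad, int dilation, bool sameMode) {
    if (sameMode)
        return in * stride;
    return (in - 1) * stride + (k - 1) * dilation + 1 - 2 * pad;
}
```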
@ -211,8 +211,9 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
// -----prepare permutation arrays and axes for dot product ----- //
std::vector<int> inputAxesForDot;
if(!isNCHW) {
gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
inputAxesForDot = {0, 1, 2}; // bS, iH, iW
}
else
@ -228,7 +229,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()});
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3}); // sum over bS, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;
@ -237,7 +238,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
if(!isNCHW)
delete gradO;
return ND4J_STATUS_OK;
return Status::OK();
}
DECLARE_SHAPE_FN(deconv2d_bp) {


@ -27,32 +27,32 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, oC, iC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) depth
int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) height
int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2)); // filter(kernel) width
int sD = INT_ARG(3); // strides depth
int sH = INT_ARG(4); // strides height
int sW = INT_ARG(5); // strides width
int pD = INT_ARG(6); // paddings depth
int pH = INT_ARG(7); // paddings height
int pW = INT_ARG(8); // paddings width
int dD = INT_ARG(9); // dilations depth
int dH = INT_ARG(10); // dilations height
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12);                                                   // 0-VALID, 1-SAME
int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
@ -64,7 +64,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
if(!isNCDHW)
output = output->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
@ -76,14 +76,14 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
// NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW]
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
//----- add biases if required -----//
if(bias)
output->applyBroadcast(broadcast::Add,{1}, bias);
if(!isNCDHW)
delete output;
return Status::OK();
}
@ -123,17 +123,17 @@ DECLARE_SHAPE_FN(deconv3d) {
int indIOioC, indIiD, indWoC(3);
if(!isNCDHW) {
indIOioC = 4; indIiD = 1;
}
else {
indIOioC = 1; indIiD = 2;
}
const int bS = inputShapeInfo[1]; // batch size
const int iD = inputShapeInfo[indIiD+1]; // input depth
const int iH = inputShapeInfo[indIiD+2]; // input height
const int iW = inputShapeInfo[indIiD+3]; // input width
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, oC, iC});
@ -143,7 +143,7 @@ DECLARE_SHAPE_FN(deconv3d) {
int oD, oH, oW; // output depth, height, width
ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
Nd4jLong* outputShapeInfo = nullptr;
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inputShapeInfo), Nd4jLong);
@ -161,7 +161,7 @@ DECLARE_SHAPE_FN(deconv3d) {
outputShapeInfo[4] = oW;
outputShapeInfo[5] = oC;
}
ShapeUtils::updateStridesAndType(outputShapeInfo, weightsShapeInfo, shape::order(inputShapeInfo));
return SHAPELIST(CONSTANT(outputShapeInfo));
@ -225,8 +225,9 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// -----prepare permutation arrays and axes for dot product ----- //
std::vector<int> inputAxesForDot;
if(!isNCDHW) {
gradO = gradO->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
inputAxesForDot = {0, 1, 2, 3}; // bS, iD, iH, iW
}
else
@ -240,7 +241,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()});
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;
@ -260,7 +261,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
->setAllowedInputTypes(3, {ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(deconv3d_bp) {
auto inputShapeInfo = inputShape->at(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
@ -292,15 +293,15 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
if(!isNCDHW) {
indIOioC = 4; indIiD = 1;
}
else {
indIOioC = 1; indIiD = 2;
}
const int bS = inputShapeInfo[1]; // batch size
const int iD = inputShapeInfo[indIiD+1]; // input depth
const int iH = inputShapeInfo[indIiD+2]; // input height
const int iW = inputShapeInfo[indIiD+3]; // input width
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels
int trueoD, trueoH, trueoW; // true output depth, height, width
@ -312,7 +313,7 @@ DECLARE_SHAPE_FN(deconv3d_bp) {
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM DECONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
auto gradIShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, gradOShapeInfo, false, block.getWorkspace());
auto gradWShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, gradOShapeInfo, false, block.getWorkspace());
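The NDHWC branches in these ops all hinge on one axis permutation: {0, 4, 1, 2, 3} for 5D arrays (and {0, 3, 1, 2} for 4D). A minimal standalone sketch of what that permutation does to a shape; the helper name and fixed-size signature are illustrative, not part of the library:

```cpp
#include <array>

// Apply a permutation to a 5D shape: out[i] = in[perm[i]].
// perm = {0, 4, 1, 2, 3} turns [bS, iD, iH, iW, iC] into [bS, iC, iD, iH, iW].
std::array<long long, 5> permuteShape(const std::array<long long, 5>& in,
                                      const std::array<int, 5>& perm) {
    std::array<long long, 5> out{};
    for (int i = 0; i < 5; ++i)
        out[i] = in[perm[i]];
    return out;
}
```

The permuted NDArray itself is a view over the same buffer, which is why the diff wraps it in `new NDArray(...)` and deletes it after use rather than reassigning the raw pointer.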


@@ -71,7 +71,7 @@ namespace ops {
int pad_top = 0, pad_left = 0;
int out_rows = 0, out_cols = 0;
- helpers::_dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
+ helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
REQUIRE_TRUE(out_rows > 0 && out_cols > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", out_rows, out_cols);
@@ -112,7 +112,7 @@ namespace ops {
newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(block.dataType());
return SHAPELIST(newShape);
}
int e = 1;
for (int cnt = 0;cnt < 4; cnt++)
rates[cnt] = INT_ARG(e++);
@@ -126,7 +126,7 @@ namespace ops {
int pad_top = 0, pad_left = 0;
int out_rows = 0, out_cols = 0;
- helpers::_dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
+ helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
std::array<Nd4jLong, 4> shape = {{batch_size, out_rows, out_cols, depth}};
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data());


@@ -59,21 +59,20 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
- if (!isNCHW) {
- input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+ if(!isNCHW) {
+ input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
if (isSameMode)
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
//output->printBuffer("output op");
- if (!isNCHW) {
+ if(!isNCHW) {
delete input;
delete output;
}
@@ -92,7 +91,7 @@ DECLARE_SYN(avgpool, avgpool2d);
}
DECLARE_SHAPE_FN(avgpool2d) {
auto inShape = inputShape->at(0);
auto shapeOf = shape::shapeOf(inShape);
@@ -177,27 +176,28 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCHW) {
- input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+ input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
// NDArray<T> columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace());
// NDArray<T>* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW]
// NDArray<T>* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1});
// NDArray<T>* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW});
// columns2d->addiColumnVector(gradOVector);
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
// *gradI /= kH*kW;
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0);
if(!isNCHW) {
@@ -205,16 +205,13 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
delete gradI;
delete gradO;
}
// delete columns;
// delete columns2d;
// delete gradOVector;
return Status::OK();
}
DECLARE_SHAPE_FN(avgpool2d_bp) {
REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "AVGPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]);
REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "AVGPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]);
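The shape functions above all go through `calcOutSizePool2D`/`calcOutSizePool3D`. A sketch of the standard per-dimension arithmetic, assuming the usual SAME/VALID conventions; the exact rounding inside ConvolutionUtils may differ in edge cases:

```cpp
// Pooling/conv output size along one dimension.
// SAME mode ignores explicit padding and returns ceil(inSize / stride);
// VALID mode uses the effective (dilated) kernel extent.
int poolOutSize(int inSize, int k, int s, int p, int d, bool sameMode) {
    if (sameMode)
        return (inSize + s - 1) / s;          // ceil(inSize / s)
    const int ek = (k - 1) * d + 1;           // effective kernel size with dilation
    return (inSize + 2 * p - ek) / s + 1;     // floor division
}
```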


@@ -30,10 +30,10 @@ namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
int kD = INT_ARG(0); // filter(kernel) depth
int kH = INT_ARG(1); // filter(kernel) height
int kW = INT_ARG(2); // filter(kernel) width
@@ -50,7 +50,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
int extraParam0 = INT_ARG(13);
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
@@ -61,21 +61,21 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
if(!isNCDHW) {
- input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- output = output->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
- }
+ input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+ output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
+ }
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
//T extraParams[] = {};
ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
if(!isNCDHW) {
delete input;
delete output;
}
return Status::OK();
}
@@ -103,22 +103,22 @@ DECLARE_SHAPE_FN(avgpool3dnew) {
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
auto inputShapeInfo = inputShape->at(0);
int idxID, idxIC;
if(isNCDHW) { idxID = 2; idxIC = 1;}
else { idxID = 1; idxIC = 4;}
int bS = inputShapeInfo[1]; // batch size
int iC = inputShapeInfo[idxIC+1]; // input channels
int iD = inputShapeInfo[idxID+1]; // input depth
int iH = inputShapeInfo[idxID+2]; // input height
int iW = inputShapeInfo[idxID+3]; // input width
int oD, oH, oW; // output depth, height, width
ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
Nd4jLong outputShape[5];
outputShape[0] = bS;
@@ -146,7 +146,7 @@ DECLARE_SHAPE_FN(avgpool3dnew) {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
@@ -164,10 +164,10 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
const int dH = INT_ARG(10); // dilations height
const int dW = INT_ARG(11); // dilations width
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging
const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC
REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
@@ -180,22 +180,22 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCDHW) {
- input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradI = gradI->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradO = gradO->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
+ input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+ gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+ gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
}
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
if(!isNCDHW) {
delete input;
delete gradI;
delete gradO;
}
}
return Status::OK();
}
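In SAME mode the ops above first compute the output size and then derive the padding from it via `calcPadding2D`/`calcPadding3D`. A per-dimension sketch of that derivation, assuming the common convention; the library may split an odd padding total between the two sides differently:

```cpp
// Padding needed on the "leading" side (top/left/front) so that a SAME-mode
// window grid of outSize steps covers inSize elements.
int samePad(int inSize, int outSize, int k, int s, int d) {
    const int ek = (k - 1) * d + 1;                    // effective kernel size
    const int total = (outSize - 1) * s + ek - inSize; // total padding, both sides
    return total > 0 ? total / 2 : 0;
}
```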


@@ -59,10 +59,10 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
const int iH = isNCHW ? input->sizeAt(2) : input->sizeAt(1);
const int iW = isNCHW ? input->sizeAt(3) : input->sizeAt(2);
- if (!isNCHW) {
- input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
- }
+ if(!isNCHW) {
+ input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+ }
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
@@ -71,8 +71,8 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
- if (!isNCHW) {
+ if(!isNCHW) {
delete input;
delete output;
}
@@ -92,7 +92,7 @@ DECLARE_SYN(maxpool, maxpool2d);
DECLARE_SHAPE_FN(maxpool2d) {
//NDArray<T> *x = block.getVariables().at(0)->getNDArray();
Nd4jLong* inShape = inputShape->at(0);
Nd4jLong* shapeOf = shape::shapeOf(inShape);
@@ -120,7 +120,7 @@ DECLARE_SHAPE_FN(maxpool2d) {
// calculate output Height/Width
int oH, oW;
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
// allocate memory for new shape
Nd4jLong newShape[4];
@@ -175,27 +175,27 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCHW) {
- input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+ input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
if(isSameMode) // SAME
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
// NDArray<T> columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace());
// NDArray<T>* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW]
// input->template applyTransform<simdOps::Im2col<T>>(columns, std::vector<T>({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, (T)0.f, (T)0.f}).data());
// NDArray<T>* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW});
// NDArray<T>* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1});
// columns2d->template applyTransform<simdOps::IsMax<T>>(std::vector<T>({(T)1., (T)1.}).data());
// columns2d->muliColumnVector(gradOVector);
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.);
if(!isNCHW) {
@@ -203,17 +203,14 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
delete gradI;
delete gradO;
}
// delete columns;
// delete columns2d;
// delete gradOVector;
return Status::OK();
}
DECLARE_SYN(MaxPool2D_bp, maxpool2d_bp);
DECLARE_SYN(MaxPool_bp, maxpool2d_bp);
DECLARE_SHAPE_FN(maxpool2d_bp) {
REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "MAXPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]);
REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "MAXPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]);
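For orientation, the pooling the op delegates to `ConvolutionUtils::pooling2d` with `PoolingType::MAX_POOL` is just a sliding-window maximum. A naive one-dimensional, single-channel sketch (illustrative only; the real routine operates on NCHW NDArrays and handles padding and dilation):

```cpp
#include <algorithm>
#include <vector>

// Sliding-window max over a 1D signal with kernel k and stride s (VALID mode).
std::vector<float> maxPool1d(const std::vector<float>& in, int k, int s) {
    std::vector<float> out;
    for (int start = 0; start + k <= (int)in.size(); start += s)
        out.push_back(*std::max_element(in.begin() + start, in.begin() + start + k));
    return out;
}
```

The backward pass (`pooling2dBP`) routes each output gradient back to the argmax position of its window, which is why the commented-out legacy path used `IsMax` before `Col2Im`.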


@@ -30,10 +30,10 @@ namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
int kD = INT_ARG(0); // filter(kernel) depth
int kH = INT_ARG(1); // filter(kernel) height
int kW = INT_ARG(2); // filter(kernel) width
@@ -48,9 +48,9 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
@@ -59,24 +59,24 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
if(!isNCDHW) {
- input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- output = output->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
- }
+ input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+ output = new NDArray(output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
+ }
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
if(!isNCDHW) {
delete input;
delete output;
}
return Status::OK();
}
@@ -102,25 +102,25 @@ DECLARE_SHAPE_FN(maxpool3dnew) {
int dW = INT_ARG(11); // dilations width
int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13);
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
Nd4jLong* inputShapeInfo = inputShape->at(0);
int idxID, idxIC;
if(isNCDHW) { idxID = 2; idxIC = 1;}
else { idxID = 1; idxIC = 4;}
int bS = inputShapeInfo[1]; // batch size
int iC = inputShapeInfo[idxIC+1]; // input channels
int iD = inputShapeInfo[idxID+1]; // input depth
int iH = inputShapeInfo[idxID+2]; // input height
int iW = inputShapeInfo[idxID+3]; // input width
int oD, oH, oW; // output depth, height, width
ConvolutionUtils::calcOutSizePool3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, isSameMode);
Nd4jLong outputShape[5];
@@ -167,9 +167,9 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
const int dW = INT_ARG(11); // dilations width
const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
// int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases
int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW
REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
@@ -182,21 +182,21 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCDHW) {
- input = input->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradI = gradI->permute({0, 4, 1, 2, 3}); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
- gradO = gradO->permute({0, 4, 1, 2, 3}); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
+ input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+ gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
+ gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
}
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
// NDArray<T> columnsWrongShape(input->ordering(), {bS, iC, oD, oH, oW, kD, kH, kW}, input->getWorkspace());
// NDArray<T>* columns = columnsWrongShape.permute({0, 1, 5, 6, 7, 2, 3, 4}); // [bS, iC, oD, oH, oW, kD, kH, kW] -> [bS, iC, kD, kH, kW, oD, oH, oW]
// ConvolutionUtils<T>::vol2col(*input, *columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
// NDArray<T>* columns2d = columnsWrongShape.reshape('c', {bS*iC*oD*oH*oW, kD*kH*kW});
// NDArray<T>* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1});
// T extraParams[] = {(T)1., (T)1.};
// columns2d->template applyTransform<simdOps::IsMax<T>>(extraParams);
// columns2d->muliColumnVector(gradOVector);
@@ -211,10 +211,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
delete gradI;
delete gradO;
}
// delete columns;
// delete columns2d;
// delete gradOVector;
return Status::OK();
}


@@ -52,11 +52,11 @@ namespace nd4j {
int oY = 0;
int oX = 0;
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // 1-NHWC, 0-NCHW
- if (!isNCHW) {
- input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- output = output->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+ if(!isNCHW) {
+ input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ output = new NDArray(output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
const auto inY = static_cast<int>(input->sizeAt(2));
@@ -70,7 +70,7 @@ namespace nd4j {
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
- if (!isNCHW) {
+ if(!isNCHW) {
delete input;
delete output;
}
@@ -175,40 +175,40 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCHW) {
- input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
- gradO = gradO->permute({0, 3, 1, 2}); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
+ input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ gradI = new NDArray(gradI->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+ gradO = new NDArray(gradO->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
}
// if(isSameMode) // SAME
// ConvolutionUtils<T>::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
// NDArray<T> columnsWrongShape(input->ordering(), {bS, iC, oH, oW, kH, kW}, input->getWorkspace());
// NDArray<T>* columns = columnsWrongShape.permute({0, 1, 4, 5, 2, 3}); // [bS, iC, oH, oW, kH, kW] -> [bS, iC, kH, kW, oH, oW]
// NDArray<T>* gradOVector = gradO->reshape('c', {(int) gradO->lengthOf(), 1});
// NDArray<T>* columns2d = columnsWrongShape.reshape('c', {bS*iC*oH*oW, kH*kW});
// NDArray<T> pNorm(columns2d->getShapeInfo(), block.getWorkspace());
// input->template applyTransform<simdOps::Im2col<T>>(columns, std::vector<T>({(T)kH, (T)kW, (T)sH, (T)sW, (T)pH, (T)pW, (T)dH, (T)dW, (T)0.f, (T)0.f}).data());
// columns2d->template applyTransform<simdOps::Abs<T>>(&pNorm);
// pNorm.template applyTransform<simdOps::Pow<T>>(&pNorm, std::vector<T>({(T)pnorm}).data());
// NDArray<T>* denomVec = pNorm.sum({1});
// denomVec->template applyTransform<simdOps::Pow<T>>(std::vector<T>({(T)1. - (T)1. / pnorm}).data());
// denomVec->template applyScalar<simdOps::Max<T>>(eps); // in case of 0
// denomVec->template applyPairwiseTransform<simdOps::ReverseDivide<T>>(gradOVector, denomVec, nullptr);
// if(pnorm != 2) {
// T extraParams[] = {(T)1. - (T)2. / pnorm};
// pNorm.template applyTransform<simdOps::Pow<T>>(std::vector<T>({(T)1. - (T)2. / pnorm}).data());
// *columns2d *= pNorm;
// }
// columns2d->muliColumnVector(denomVec);
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm);
if(!isNCHW) {
@@ -216,16 +216,12 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
delete gradI;
delete gradO;
}
// delete columns;
// delete columns2d;
// delete gradOVector;
// delete denomVec;
return Status::OK();
}
DECLARE_SHAPE_FN(pnormpool2d_bp) {
REQUIRE_TRUE(inputShape->at(0)[0] == 4, 0, "PNORMPOOL2D_BP op: input array must be 4D, but got %i instead!", inputShape->at(0)[0]);
REQUIRE_TRUE(inputShape->at(1)[0] == 4, 0, "PNORMPOOL2D_BP op: output's gradient array (next epsilon) must be 4D, but got %i instead!", inputShape->at(1)[0]);
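For context, `PNORM_POOL` aggregates each window with a p-norm, with `extraParam0` carrying p. A per-window sketch of the forward value (illustrative helper; the library applies this per pooling window of the NCHW tensor):

```cpp
#include <cmath>
#include <vector>

// p-norm of one pooling window: (sum_i |x_i|^p)^(1/p).
double pnormPool(const std::vector<double>& window, int p) {
    double sum = 0.0;
    for (double x : window)
        sum += std::pow(std::fabs(x), p);
    return std::pow(sum, 1.0 / p);
}
```

With p = 2 this is the Euclidean norm of the window, which is why the backward pass above guards the denominator against zero (`Max` with `eps`) before dividing.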


@@ -29,7 +29,7 @@ namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
auto logits = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
auto labels = INPUT_VARIABLE(2);
@@ -37,17 +37,17 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights"
double labelsSmoothing = T_ARG(0);
// input validation
REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str());
// only 4 possible reduction modes exist
REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
// smoothing is possible for rank of logits/labels > 1
REQUIRE_TRUE(labels->rankOf() > 1 || (labels->rankOf() == 1 && labelsSmoothing == 0.), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: smoothing is not possible when rank of labels/ logits = 1 !");
if(!output->isScalar()) {
// weights array can be single scalar or has the same shape as output, and must be broadcastable to output shape
REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == output->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", weights->rankOf(), output->rankOf());
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str());
}
@@ -59,8 +59,8 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
if(labelsSmoothing != 0.) {
newLabels = new NDArray(cLabels);
*newLabels = (1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1);
}
}
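For reference, the smoothing step above blends the labels with a uniform distribution over the classes: `newLabels = (1 - s) * labels + s / numClasses`. A minimal standalone C++ sketch of the same formula for one row of labels (`smoothLabels` is a hypothetical helper name, not part of the nd4j API):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// newLabels = (1 - s) * labels + s / numClasses, applied to one row of labels.
// Each probability is pulled slightly toward the uniform distribution.
std::vector<double> smoothLabels(const std::vector<double>& labels, double s) {
    const double uniform = s / static_cast<double>(labels.size());
    std::vector<double> out(labels.size());
    for (std::size_t i = 0; i < labels.size(); ++i)
        out[i] = (1.0 - s) * labels[i] + uniform;
    return out;
}
```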
// main formula: result = - sum_i(labels_i * log(softmax_i)) - sum over last dimension
// softmax_i = exp(logits_i) / sum_j(exp(logits_j))
// so result = sum_i( labels_i * (log(sum_j(exp(logits_j))) - logits_i) )
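The code below implements this formula with shifted logits (the log-sum-exp trick) so that `exp` never overflows. A self-contained single-row sketch in plain C++ (not the nd4j API):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax cross-entropy over one row:
// E = sum_i( labels_i * (log(sum_j exp(logits_j - max)) - (logits_i - max)) )
double softmaxCrossEntropy(const std::vector<double>& labels,
                           const std::vector<double>& logits) {
    const double mx = *std::max_element(logits.begin(), logits.end());
    double sumExp = 0.0;
    for (double l : logits) sumExp += std::exp(l - mx);   // never overflows
    const double logSumExp = std::log(sumExp);
    double e = 0.0;
    for (std::size_t i = 0; i < labels.size(); ++i)
        e += labels[i] * (logSumExp - (logits[i] - mx));
    return e;
}
```

Shifting by the max leaves the result unchanged (it cancels inside the log-sum) but keeps every `exp` argument at or below zero, even for logits in the thousands.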
@@ -73,24 +73,24 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true);
NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log);
NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions);
// broadcast/tile weights to E's shape if necessary
auto weightsBroad = weights;
if(!weights->isScalar() && !weights->isSameShape(&E)) {
if(E.rankOf() == 1 && weights->isVector() && weights->rankOf() > 1)
weightsBroad = new NDArray(weights->reshape(weights->ordering(), {weights->lengthOf()}));
else
weightsBroad = new NDArray(weights->tileToShape(E.getShapeInfo()));
}
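The broadcast step above ends with weights whose shape matches E, so the multiply is element-wise. A simplified sketch of that weight application (hypothetical helper; a scalar weight is modeled as a length-1 vector and a shorter vector is tiled, a simplification of `tileToShape`):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Multiply each loss element by its (broadcast) weight.
// w of length 1 acts as a scalar; a shorter w is tiled across E.
std::vector<double> applyWeights(const std::vector<double>& E,
                                 const std::vector<double>& w) {
    std::vector<double> out(E.size());
    for (std::size_t i = 0; i < E.size(); ++i)
        out[i] = E[i] * w[i % w.size()];
    return out;
}
```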
// multiply E by weights
E *= *weightsBroad;
switch (reductionMode) {
case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels.
output->assign(&E);
break;
case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array
E.reduceNumber(reduce::Sum, *output);
break;
@@ -99,12 +99,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
double sum;
if (weights->isScalar())
sum = weights->e<double>(0) * E.lengthOf();
else
sum = weightsBroad->reduceNumber(reduce::Sum).e<double>(0);
if (sum == 0.)
*output = 0.;
else
output->assign(E.reduceNumber(reduce::Sum) / sum);
break;
}
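The reduction modes handled in this switch can be summarized in a standalone sketch (hypothetical helper, not the op itself). Mode 0 ("none") returns E unreduced and is omitted; `weights` here is assumed already broadcast to E's length, matching the `weightsBroad` step above:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Modes: 1 = weighted_sum, 2 = weighted_mean (sum E / sum weights, 0 if the
// weight sum is 0), 3 = weighted_sum_by_nonzero_weights (divide by the count
// of nonzero weights instead of their sum).
double reduceLoss(const std::vector<double>& E,
                  const std::vector<double>& weights, int mode) {
    double sumE = 0.0;
    for (double e : E) sumE += e;
    if (mode == 1) return sumE;
    double sumW = 0.0;
    std::size_t nonZero = 0;
    for (double w : weights) { sumW += w; if (w != 0.0) ++nonZero; }
    if (mode == 2) return sumW == 0.0 ? 0.0 : sumE / sumW;
    return nonZero == 0 ? 0.0 : sumE / nonZero;
}
```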
@@ -132,15 +132,15 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
if(newLabels != cLabels)
delete newLabels;
delete cLabels;
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(softmax_cross_entropy_loss) {
getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS})
@@ -149,12 +149,12 @@ DECLARE_TYPES(softmax_cross_entropy_loss) {
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(softmax_cross_entropy_loss) {
auto logitsShapeInfo = inputShape->at(0);
auto weightsShapeInfo = inputShape->at(1);
auto labelsShapeInfo = inputShape->at(2);
// labels and logits must have the same shapes
REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));
@@ -165,14 +165,14 @@ DECLARE_SHAPE_FN(softmax_cross_entropy_loss) {
else { // in this case output has the shape as labels and logits minus last dimension
std::vector<int> dimensions = {-1};
outShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, true, block.getWorkspace());
// weights array can be a single scalar or have the same rank as output, and must be broadcastable to output
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(outShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as output array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(outShapeInfo));
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, outShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(outShapeInfo).c_str());
}
return SHAPELIST(outShapeInfo);
}
@@ -185,15 +185,15 @@ DECLARE_SHAPE_FN(softmax_cross_entropy_loss) {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
auto logits = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
auto labels = INPUT_VARIABLE(2);
auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits
auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights
auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels
auto labelsSmoothing = T_ARG(0);
int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights"
@@ -203,13 +203,13 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
std::vector<int> dimensions = {-1};
// input validation
REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str());
// only 4 possible reduction modes exist
REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(logits->ordering(), dimensions, logits->getShapeInfo(), false, false, block.getWorkspace());
// weights array can be a single scalar or have the same rank as loss, and must be broadcastable to loss shape
REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", weights->rankOf(), shape::rank(lossShapeInfo));
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(weights->getShapeInfo(), lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str());
// smoothing is possible for rank of logits/labels > 1
@@ -221,14 +221,14 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
auto newLabels = cLabels;
if(labelsSmoothing != 0.) {
newLabels = new NDArray(labels->getShapeInfo(), dLdl->dataType(), false, block.launchContext());
newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1));
}
NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimensions, true)).transform(transform::Exp);
softmax /= softmax.reduceAlongDims(reduce::Sum, dimensions, true);
// dEdp = softmax * sum_i(labels_i) - labels
dLdp->assign(softmax * newLabels->reduceAlongDims(reduce::Sum, dimensions, true) - *newLabels);
// dEdl = -log(softmax)
dLdl->assign(-softmax.transform(transform::Log)* (1.f - labelsSmoothing));
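The dEdp formula above can be checked with a single-row sketch in plain C++ (hypothetical helper, not the nd4j API). When the labels sum to 1 it reduces to the familiar softmax minus labels:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// dL/dlogits for softmax cross-entropy on one row:
// grad_i = softmax_i * sum_j(labels_j) - labels_i
std::vector<double> crossEntropyGradLogits(const std::vector<double>& labels,
                                           const std::vector<double>& logits) {
    const double mx = *std::max_element(logits.begin(), logits.end());
    double denom = 0.0;
    for (double l : logits) denom += std::exp(l - mx);   // softmax denominator
    double labelSum = 0.0;
    for (double l : labels) labelSum += l;
    std::vector<double> grad(logits.size());
    for (std::size_t i = 0; i < logits.size(); ++i)
        grad[i] = (std::exp(logits[i] - mx) / denom) * labelSum - labels[i];
    return grad;
}
```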
@@ -236,11 +236,11 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true);
NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log);
NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions);
// broadcast/tile weights to E's shape if necessary
auto weightsBroad = weights;
if(!weights->isScalar() && !weights->isSameShape(&E))
weightsBroad = new NDArray(weights->tileToShape(E.getShapeInfo()));
dimensions = ShapeUtils::evalDimsToExclude(dLdp->rankOf(), dimensions);
@@ -344,18 +344,18 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
if(weightsBroad != weights)
delete weightsBroad;
if(newLabels != cLabels)
delete newLabels;
delete cLabels;
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(softmax_cross_entropy_loss_grad) {
getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS})
@@ -367,27 +367,27 @@ DECLARE_TYPES(softmax_cross_entropy_loss_grad) {
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(softmax_cross_entropy_loss_grad) {
auto logitsShapeInfo = inputShape->at(0);
auto weightsShapeInfo = inputShape->at(1);
auto labelsShapeInfo = inputShape->at(2);
std::vector<int> dimensions = {-1};
// labels and logits must have the same shapes
REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
auto lossShapeInfo = ShapeUtils::evalReduceShapeInfo(shape::order(logitsShapeInfo), dimensions, logitsShapeInfo, false, false, block.getWorkspace());
// weights array can be a single scalar or have the same rank as loss, and must be broadcastable to loss
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as loss array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(lossShapeInfo));
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, lossShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and loss arrays should be broadcastable, but got weights = %s and loss = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(lossShapeInfo).c_str());
auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));
auto dLdpShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(logitsShapeInfo), shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo)));
auto dLdwShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(weightsShapeInfo), shape::shapeOf(weightsShapeInfo), shape::rank(weightsShapeInfo)));
auto dLdlShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo)));
return SHAPELIST(dLdpShapeInfo, dLdwShapeInfo, dLdlShapeInfo);
}


@@ -74,7 +74,7 @@ namespace ops {
}
if(mask != nullptr){
NDArray reshapedMask;
if(weights->rankOf() == 4){
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
}else{
@@ -87,8 +87,7 @@ namespace ops {
// before going through the softmax, we effectively push all masked positions to zero after softmax.
//
// we are using 1e9 to mean effectively infinity
*weights += (reshapedMask - 1) * 1e9;
}
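The `(mask - 1) * 1e9` trick above adds 0 to kept positions (mask == 1) and -1e9 to masked ones (mask == 0), so after softmax the masked probabilities are effectively zero. A self-contained sketch of the effect (hypothetical helper, plain C++):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Additive attention mask: masked logits get -1e9, then a stable softmax.
std::vector<double> maskedSoftmax(std::vector<double> logits,
                                  const std::vector<double>& mask) {
    for (std::size_t i = 0; i < logits.size(); ++i)
        logits[i] += (mask[i] - 1.0) * 1e9;   // 0 for kept, -1e9 for masked
    const double mx = *std::max_element(logits.begin(), logits.end());
    double denom = 0.0;
    for (double l : logits) denom += std::exp(l - mx);
    for (double& l : logits) l = std::exp(l - mx) / denom;
    return logits;
}
```

exp(-1e9) underflows to exactly 0, so masked positions contribute nothing to the normalization, which is why 1e9 serves as "effectively infinity".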
nd4j::ops::softmax softmax;
@@ -175,14 +174,13 @@ namespace ops {
preSoftmax /= factor;
if(mask != nullptr){
NDArray reshapedMask;
if(preSoftmax.rankOf() == 4){
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), 1, mask->sizeAt(1), 1});
}else{
reshapedMask = mask->reshape(mask->ordering(), {mask->sizeAt(0), mask->sizeAt(1), 1});
}
preSoftmax += (reshapedMask - 1) * 1e9;
}
NDArray weights('c', weightShape, values->dataType(), block.launchContext());


@@ -70,7 +70,7 @@ namespace nd4j {
float beta = T_ARG(2);
int depth = INT_ARG(0);
helpers::lrnBP(block, *input, *gradO, *gradI, depth, bias, alpha, beta);
return Status::OK();
}


@@ -98,9 +98,9 @@ namespace ops {
auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
// Apply Attention
NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
nd4j::ops::dot_product_attention attention;
attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {});
// Project attention results
attnResults.permutei({0, 3, 1, 2});
@@ -111,11 +111,9 @@ namespace ops {
mmul.execute({&attnResults, Wo},{&projRes}, {}, {}, {});
projRes.reshapei(projRes.ordering(), {miniBatchSize, queryCount, outSize});
projRes.permutei({0, 2, 1});
// FIXME: bad for performance
output->assign(projRes);
return Status::OK();
}
@@ -227,9 +225,9 @@ namespace ops {
auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
// Apply Attention
NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
nd4j::ops::dot_product_attention attention;
attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults}, {}, {normalization, 0}, {});
// Project attention results
attnResults.permutei({0, 3, 1, 2});
@@ -237,31 +235,25 @@ namespace ops {
// dLdWo
auto epsPerm = eps->permute({0, 2, 1});
auto epsPostReshape = epsPerm.reshape(eps->ordering(), {miniBatchSize * queryCount, outSize});
nd4j::ops::matmul_bp matmulBp;
NDArray dLdPreWo(attnResults.shapeInfo(), false, block.launchContext());
matmulBp.execute({&attnResults, Wo, &epsPostReshape}, {&dLdPreWo, dLdWo}, {}, {}, {});
// dLdAttn
dLdPreWo.reshapei({miniBatchSize, queryCount, numHeads, projectedValues.sizeAt(2)});
dLdPreWo.permutei({0, 2, 3, 1});
nd4j::ops::dot_product_attention_bp attentionBp;
NDArray dLdProjectedQueries(projectedQueries.shapeInfo(), false, block.launchContext());
NDArray dLdProjectedKeys(projectedKeys.shapeInfo(), false, block.launchContext());
NDArray dLdProjectedValues(projectedValues.shapeInfo(), false, block.launchContext());
attentionBp.execute({&projectedQueries, &projectedKeys, &projectedValues, &dLdPreWo, mask},{&dLdProjectedQueries, &dLdProjectedKeys, &dLdProjectedValues}, {}, {normalization}, {});
AttentionHelper::multiHeadProjectBp(queries, Wq, &dLdProjectedQueries, dLdq, dLdWq, block.launchContext());
AttentionHelper::multiHeadProjectBp(keys, Wk, &dLdProjectedKeys, dLdk, dLdWk, block.launchContext());
AttentionHelper::multiHeadProjectBp(values, Wv, &dLdProjectedValues, dLdv, dLdWv, block.launchContext());
return Status::OK();
}


@@ -45,13 +45,13 @@ CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) {
int arrLen = a->lengthOf();
// FIXME: this stuff should be single op call. No sense rolling over couple of arrays twice
for(int i = 0; i < arrLen; ++i) {
REQUIRE_TRUE(a->e<float>(i) > 0.f, 0, "BETAINC op: all elements of a array must be > 0!");
REQUIRE_TRUE(b->e<float>(i) > 0.f, 0, "BETAINC op: all elements of b array must be > 0!");
REQUIRE_TRUE(0.f <= x->e<float>(i) && x->e<float>(i) <= 1.f, 0, "BETAINC op: all elements of x array must be within [0, 1] range!");
}
helpers::betaInc(block.launchContext(), *a, *b, *x, *output);
return Status::OK();
}


@@ -48,10 +48,7 @@ namespace nd4j {
//nd4j_debug("Reshaping to: [%i, %i]\n", -1, (int) bias->lengthOf());
auto tArr = input->reshape(input->ordering(), shape);
auto zArr = z->reshape(z->ordering(), shape);
tArr.addRowVector(bias, &zArr);
}
STORE_RESULT(*z);
@@ -87,13 +84,12 @@ namespace nd4j {
// cnn case
if (input->rankOf() == 4) {
auto epsilonNext2d = epsilonNext->permute({1, 0, 2, 3});
epsilonNext2d.reshapei('c', {(int) bias->lengthOf(), -1});
auto sum = epsilonNext2d.reduceAlongDimension(reduce::Sum, {1});
gradB->assign(sum);
} else if (input->rankOf() == 2) {
// regular fully-connected case
auto sum = epsilonNext->reduceAlongDimension(reduce::Sum, {0});


@@ -0,0 +1,56 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_check_numerics)
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(check_numerics, 2, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto message = INPUT_VARIABLE(1);
auto output = OUTPUT_VARIABLE(0);
auto allFinite = input->reduceNumber(reduce::BoolOps::IsFinite);
REQUIRE_TRUE(allFinite.e<bool>(0), 0, "CheckNumerics: %s", message->e<std::string>(0).c_str());
if (!block.isInplace())
output->assign(input);
return Status::OK();
}
DECLARE_SHAPE_FN(check_numerics) {
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0))));
}
DECLARE_TYPES(check_numerics) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, nd4j::DataType::UTF8)
->setAllowedOutputTypes({ALL_FLOATS});
}
}
}
#endif


@@ -56,7 +56,7 @@ namespace nd4j {
}
DECLARE_SHAPE_FN(crop_and_resize) {
auto in = inputShape->at(1);
Nd4jLong outputShape[4];
@@ -77,8 +77,13 @@ namespace nd4j {
}
DECLARE_TYPES(crop_and_resize) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
// ->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(1, {FLOAT32}) // as TF
->setAllowedInputTypes(2, {ALL_INTS})
->setAllowedInputTypes(3, {ALL_INTS})
->setAllowedOutputTypes({FLOAT32}); // as TF
// ->setAllowedOutputTypes({ALL_FLOATS});
}
}
}


@@ -47,9 +47,9 @@ namespace ops {
auto o = OUTPUT_VARIABLE(0);
if (a->lengthOf() == 3) {
helpers::cross(block.launchContext(), a, b, o);
} else {
helpers::crossBatched(block.launchContext(), a, b, o);
}
return Status::OK();
