From 455a5d112d2189f08f038752cf15b47ab8001f84 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Mon, 20 Apr 2020 03:27:13 +0300 Subject: [PATCH] Fixes for codegen generated classes and build improvements (#367) * Input format extended * Deleted redundant code * Added weights format to conv2d config * Refactoring * dl4j base test functionality * Different tests base class per module * Check base class for dl4j-graph subproject tests * Check if test classes extend BaseDL4JTest * Use nd4j-common-tests as transient dependency * Enums and tests added * Added codegenerated methods * Use namespace methods * Replace DifferentialFunctionFactory with codegenerated classes * Fixed linspace * Namespaces regenerated * Namespaces used instead of factory * Regenerated base classes * Input format extended * Added weights format to conv2d config * Refactoring * dl4j base test functionality * Different tests base class per module * Check base class for dl4j-graph subproject tests * Check if test classes extend BaseDL4JTest * Use nd4j-common-tests as transient dependency * Enums and tests added * Added codegenerated methods * Use namespace methods * Replace DifferentialFunctionFactory with codegenerated classes * Fixed linspace * Namespaces regenerated * Regenerated base classes * Regenerated namespaces * Generate nd4j namespaces * INDArrays accepting constructors * Generated some ops * Some fixes * SameDiff ops regenerated * Regenerated nd4j ops * externalErrors moved * Compilation fixes * SquaredDifference - strict number of args * Deprecated code cleanup. Proper base class for tests. 
* Extend test classes with BaseND4JTest * Extend test classes with BaseDL4JTest * Legacy code * DL4J cleanup * Exclude test utils from base class check * Tests fixed * Arbiter tests fix * Test dependency scope fix + pom.xml formatting Signed-off-by: Alex Black * Significant number of fixes Signed-off-by: Alex Black * Another round of fixes Signed-off-by: Alex Black * Another round of fixes Signed-off-by: Alex Black * Few additional fixes Signed-off-by: Alex Black * DataVec missing test scope dependencies Signed-off-by: Alex Black Co-authored-by: Alex Black --- arbiter/arbiter-core/pom.xml | 11 +- .../optimize/AssertTestsExtendBaseClass.java | 49 + .../arbiter/AssertTestsExtendBaseClass.java | 50 + .../server/AssertTestsExtendBaseClass.java | 50 + .../server/MnistDataSetIteratorFactory.java | 3 +- .../server/TestDataFactoryProviderMnist.java | 3 +- arbiter/arbiter-ui/pom.xml | 7 + .../optimize/AssertTestsExtendBaseClass.java | 50 + .../arbiter/optimize/TestBasic.java | 3 +- datavec/datavec-api/pom.xml | 11 +- .../api/util/ndarray/RecordConverter.java | 31 +- .../api/AssertTestsExtendBaseClass.java | 57 + .../impl/CSVLineSequenceRecordReaderTest.java | 3 +- .../CSVMultiSequenceRecordReaderTest.java | 4 +- .../CSVNLinesSequenceRecordReaderTest.java | 3 +- .../reader/impl/CSVRecordReaderTest.java | 3 +- .../impl/CSVSequenceRecordReaderTest.java | 3 +- ...VariableSlidingWindowRecordReaderTest.java | 3 +- .../impl/FileBatchRecordReaderTest.java | 3 +- .../reader/impl/FileRecordReaderTest.java | 3 +- .../impl/JacksonLineRecordReaderTest.java | 3 +- .../reader/impl/JacksonRecordReaderTest.java | 3 +- .../reader/impl/LibSvmRecordReaderTest.java | 3 +- .../records/reader/impl/LineReaderTest.java | 3 +- .../reader/impl/RegexRecordReaderTest.java | 3 +- .../reader/impl/SVMLightRecordReaderTest.java | 3 +- .../impl/TestCollectionRecordReaders.java | 3 +- .../impl/TestConcatenatingRecordReader.java | 3 +- .../reader/impl/TestSerialization.java | 3 +- 
.../TransformProcessRecordReaderTests.java | 3 +- .../writer/impl/CSVRecordWriterTest.java | 3 +- .../writer/impl/LibSvmRecordWriterTest.java | 3 +- .../writer/impl/SVMLightRecordWriterTest.java | 3 +- .../datavec/api/split/InputSplitTests.java | 3 +- .../split/NumberedFileInputSplitTests.java | 3 +- .../api/split/TestStreamInputSplit.java | 3 +- .../datavec/api/split/TransformSplitTest.java | 3 +- .../api/split/parittion/PartitionerTests.java | 3 +- .../api/transform/TestTransformProcess.java | 3 +- .../transform/condition/TestConditions.java | 3 +- .../api/transform/filter/TestFilters.java | 3 +- .../datavec/api/transform/join/TestJoin.java | 3 +- .../transform/ops/AggregableMultiOpTest.java | 3 +- .../transform/ops/AggregatorImplsTest.java | 3 +- .../api/transform/ops/DispatchOpTest.java | 3 +- .../transform/reduce/TestMultiOpReduce.java | 3 +- .../api/transform/reduce/TestReductions.java | 3 +- .../api/transform/schema/TestJsonYaml.java | 3 +- .../transform/schema/TestSchemaMethods.java | 3 +- .../TestReduceSequenceByWindowFunction.java | 3 +- .../transform/sequence/TestSequenceSplit.java | 3 +- .../sequence/TestWindowFunctions.java | 3 +- .../serde/TestCustomTransformJsonYaml.java | 3 +- .../transform/serde/TestYamlJsonSerde.java | 3 +- .../transform/stringreduce/TestReduce.java | 3 +- .../transform/RegressionTestJson.java | 3 +- .../api/transform/transform/TestJsonYaml.java | 3 +- .../transform/transform/TestTransforms.java | 3 +- .../TestNDArrayWritableTransforms.java | 3 +- .../transform/ndarray/TestYamlJsonSerde.java | 3 +- .../parse/ParseDoubleTransformTest.java | 3 +- .../org/datavec/api/transform/ui/TestUI.java | 3 +- .../api/util/ClassPathResourceTest.java | 3 +- .../datavec/api/util/TimeSeriesUtilsTest.java | 3 +- .../api/writable/RecordConverterTest.java | 7 +- .../TestNDArrayWritableAndSerialization.java | 3 +- .../datavec/api/writable/WritableTest.java | 5 +- datavec/datavec-arrow/pom.xml | 6 + .../org/datavec/arrow/ArrowConverterTest.java | 3 +- 
.../arrow/AssertTestsExtendBaseClass.java | 50 + .../org/datavec/arrow/RecordMapperTest.java | 3 +- ...rowWritableRecordTimeSeriesBatchTests.java | 3 +- .../datavec-data/datavec-data-audio/pom.xml | 7 + .../audio/AssertTestsExtendBaseClass.java | 55 + .../org/datavec/audio/AudioReaderTest.java | 3 +- .../audio/TestFastFourierTransform.java | 3 +- .../datavec-data/datavec-data-codec/pom.xml | 7 + .../reader/AssertTestsExtendBaseClass.java | 50 + .../datavec-data/datavec-data-image/pom.xml | 7 + .../image/AssertTestsExtendBaseClass.java | 50 + datavec/datavec-data/datavec-data-nlp/pom.xml | 7 + .../nlp/AssertTestsExtendBaseClass.java | 50 + datavec/datavec-data/datavec-geo/pom.xml | 9 +- .../transform/AssertTestsExtendBaseClass.java | 49 + datavec/datavec-data/datavec-hadoop/pom.xml | 7 + .../hadoop/AssertTestsExtendBaseClass.java | 48 + datavec/datavec-excel/pom.xml | 7 + .../poi/excel/AssertTestsExtendBaseClass.java | 50 + datavec/datavec-jdbc/pom.xml | 7 + .../reader/AssertTestsExtendBaseClass.java | 49 + datavec/datavec-local/pom.xml | 7 + .../AssertTestsExtendBaseClass.java | 50 + .../transforms/analysis/TestAnalyzeLocal.java | 3 +- datavec/datavec-python/pom.xml | 7 + .../python/AssertTestsExtendBaseClass.java | 50 + .../datavec-spark-inference-client/pom.xml | 7 + .../client/AssertTestsExtendBaseClass.java | 49 + .../datavec-spark-inference-model/pom.xml | 7 + .../spark/transform/CSVSparkTransform.java | 3 +- .../transform/AssertTestsExtendBaseClass.java | 50 + .../datavec-spark-inference-server/pom.xml | 7 + .../transform/AssertTestsExtendBaseClass.java | 50 + datavec/datavec-spark/pom.xml | 6 + .../spark/AssertTestsExtendBaseClass.java | 50 + .../spark/transform/NormalizationTests.java | 6 +- .../deeplearning4j-common-tests/pom.xml | 5 + deeplearning4j/deeplearning4j-core/pom.xml | 14 - .../AssertTestsExtendBaseClass.java | 54 +- .../CompareTrainingImplementations.java | 2 +- .../RecordReaderMultiDataSetIterator.java | 2 +- 
deeplearning4j/deeplearning4j-graph/pom.xml | 1 - .../graph/AssertTestsExtendedBaseClass.java | 49 + .../tokenizer/AssertTestsExtendBaseClass.java | 52 + .../test/java/AssertTestsExtendBaseClass.java | 53 + .../test/java/AssertTestsExtendBaseClass.java | 49 + .../AssertTestsExtendBaseClass.java | 49 + .../deeplearning4j-nlp/pom.xml | 125 +- .../AssertTestsExtendBaseClass.java | 49 + .../nn/conf/layers/LocallyConnected1D.java | 5 +- .../nn/conf/layers/LocallyConnected2D.java | 5 +- .../layers/samediff/SameDiffGraphVertex.java | 3 +- .../nn/layers/samediff/SameDiffLayer.java | 3 +- .../remote/AssertTestsExtendBaseClass.java | 48 + .../functions/DifferentialFunction.java | 15 +- .../DifferentialFunctionFactory.java | 2659 ----------------- .../nd4j/autodiff/samediff/SDVariable.java | 59 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 45 +- .../samediff/config/OutputConfig.java | 4 +- .../nd4j/autodiff/samediff/ops/SDBaseOps.java | 110 +- .../nd4j/autodiff/samediff/ops/SDImage.java | 8 +- .../nd4j/autodiff/samediff/ops/SDMath.java | 674 ++++- .../org/nd4j/autodiff/samediff/ops/SDNN.java | 66 +- .../org/nd4j/autodiff/samediff/ops/SDOps.java | 27 +- .../samediff/transform/OpPredicate.java | 1 - .../org/nd4j/autodiff/util/SameDiffUtils.java | 139 + .../org/nd4j/autodiff/util/TrainingUtils.java | 70 - .../src/main/java/org/nd4j/enums/PadMode.java | 29 + .../java/org/nd4j/enums/WeightsFormat.java | 29 + .../linalg/api/ops/BaseBroadcastBoolOp.java | 8 +- .../nd4j/linalg/api/ops/BaseBroadcastOp.java | 9 +- .../linalg/api/ops/BaseIndexAccumulation.java | 6 +- .../org/nd4j/linalg/api/ops/BaseReduceOp.java | 7 +- .../nd4j/linalg/api/ops/BaseScalarBoolOp.java | 3 +- .../org/nd4j/linalg/api/ops/BaseScalarOp.java | 3 +- .../nd4j/linalg/api/ops/BaseTransformOp.java | 11 +- .../java/org/nd4j/linalg/api/ops/NoOp.java | 5 + .../api/ops/impl/broadcast/BiasAdd.java | 2 +- .../impl/controlflow/compat/BaseCompatOp.java | 5 + .../ops/impl/controlflow/compat/Merge.java | 5 + 
.../impl/controlflow/compat/StopGradient.java | 2 +- .../ops/impl/controlflow/compat/Switch.java | 5 + .../linalg/api/ops/impl/indexaccum/IAMax.java | 2 +- .../linalg/api/ops/impl/indexaccum/IAMin.java | 2 +- .../linalg/api/ops/impl/indexaccum/IMax.java | 2 +- .../linalg/api/ops/impl/indexaccum/IMin.java | 2 +- .../ops/impl/layers/convolution/Conv2D.java | 6 +- .../ops/impl/layers/convolution/DeConv3D.java | 3 +- .../ops/impl/layers/convolution/Im2col.java | 2 +- .../impl/layers/convolution/Upsampling2d.java | 2 +- .../convolution/config/Conv2DConfig.java | 6 +- .../ops/impl/loss/AbsoluteDifferenceLoss.java | 4 +- .../api/ops/impl/loss/CosineDistanceLoss.java | 4 +- .../linalg/api/ops/impl/loss/HingeLoss.java | 4 +- .../linalg/api/ops/impl/loss/HuberLoss.java | 4 +- .../nd4j/linalg/api/ops/impl/loss/L2Loss.java | 2 +- .../linalg/api/ops/impl/loss/LogLoss.java | 4 +- .../api/ops/impl/loss/LogPoissonLoss.java | 10 +- .../loss/MeanPairwiseSquaredErrorLoss.java | 4 +- .../ops/impl/loss/MeanSquaredErrorLoss.java | 4 +- .../impl/loss/SigmoidCrossEntropyLoss.java | 4 +- .../impl/loss/SoftmaxCrossEntropyLoss.java | 4 +- .../SoftmaxCrossEntropyWithLogitsLoss.java | 5 +- ...arseSoftmaxCrossEntropyLossWithLogits.java | 6 +- .../nd4j/linalg/api/ops/impl/reduce/Mmul.java | 2 +- .../linalg/api/ops/impl/reduce/Moments.java | 6 +- .../api/ops/impl/reduce/TensorMmul.java | 18 +- .../linalg/api/ops/impl/reduce/bool/All.java | 2 +- .../linalg/api/ops/impl/reduce/bool/Any.java | 2 +- .../api/ops/impl/reduce/bool/IsInf.java | 2 +- .../api/ops/impl/reduce/bool/IsNaN.java | 2 +- .../api/ops/impl/reduce/custom/LogSumExp.java | 5 +- .../api/ops/impl/reduce/floating/AMean.java | 3 +- .../api/ops/impl/reduce/floating/Entropy.java | 11 +- .../ops/impl/reduce/floating/LogEntropy.java | 4 +- .../api/ops/impl/reduce/floating/Mean.java | 3 +- .../api/ops/impl/reduce/floating/Norm1.java | 3 +- .../api/ops/impl/reduce/floating/Norm2.java | 3 +- .../api/ops/impl/reduce/floating/NormMax.java | 3 +- 
.../impl/reduce/floating/ShannonEntropy.java | 7 +- .../ops/impl/reduce/floating/SquaredNorm.java | 3 +- .../ops/impl/reduce/longer/CountNonZero.java | 2 +- .../api/ops/impl/reduce/longer/CountZero.java | 2 +- .../linalg/api/ops/impl/reduce/same/AMax.java | 3 +- .../linalg/api/ops/impl/reduce/same/AMin.java | 3 +- .../linalg/api/ops/impl/reduce/same/ASum.java | 3 +- .../linalg/api/ops/impl/reduce/same/Max.java | 3 +- .../linalg/api/ops/impl/reduce/same/Min.java | 3 +- .../linalg/api/ops/impl/reduce/same/Prod.java | 3 +- .../linalg/api/ops/impl/reduce/same/Sum.java | 3 +- .../api/ops/impl/reduce3/CosineDistance.java | 4 +- .../ops/impl/reduce3/CosineSimilarity.java | 12 +- .../nd4j/linalg/api/ops/impl/reduce3/Dot.java | 3 +- .../ops/impl/reduce3/EuclideanDistance.java | 5 +- .../api/ops/impl/reduce3/JaccardDistance.java | 17 +- .../ops/impl/reduce3/ManhattanDistance.java | 5 +- .../linalg/api/ops/impl/scalar/LeakyReLU.java | 3 +- .../linalg/api/ops/impl/scalar/PRelu.java | 3 +- .../nd4j/linalg/api/ops/impl/scalar/Pow.java | 5 +- .../api/ops/impl/scalar/RectifiedLinear.java | 3 +- .../linalg/api/ops/impl/scalar/Relu6.java | 3 +- .../linalg/api/ops/impl/scalar/ScalarAdd.java | 6 +- .../impl/scalar/ScalarReverseDivision.java | 5 +- .../impl/scalar/ScalarReverseSubtraction.java | 5 +- .../linalg/api/ops/impl/scalar/ScalarSet.java | 2 +- .../nd4j/linalg/api/ops/impl/scalar/Step.java | 2 +- .../api/ops/impl/scatter/ScatterAdd.java | 4 +- .../api/ops/impl/scatter/ScatterDiv.java | 10 +- .../api/ops/impl/scatter/ScatterMax.java | 6 +- .../api/ops/impl/scatter/ScatterMin.java | 6 +- .../api/ops/impl/scatter/ScatterMul.java | 8 +- .../api/ops/impl/scatter/ScatterSub.java | 4 +- .../api/ops/impl/scatter/ScatterUpdate.java | 8 +- .../linalg/api/ops/impl/shape/Linspace.java | 2 +- .../linalg/api/ops/impl/shape/Permute.java | 4 +- .../linalg/api/ops/impl/shape/Reshape.java | 4 +- .../api/ops/impl/shape/SequenceMask.java | 2 +- .../linalg/api/ops/impl/shape/ShapeN.java | 2 +- 
.../nd4j/linalg/api/ops/impl/shape/Slice.java | 5 +- .../nd4j/linalg/api/ops/impl/shape/Stack.java | 2 +- .../api/ops/impl/shape/StridedSlice.java | 9 +- .../nd4j/linalg/api/ops/impl/shape/Tile.java | 5 +- .../impl/summarystats/StandardDeviation.java | 3 +- .../api/ops/impl/summarystats/Variance.java | 3 +- .../linalg/api/ops/impl/transforms/Angle.java | 2 +- .../linalg/api/ops/impl/transforms/Pad.java | 27 + .../api/ops/impl/transforms/any/IsMax.java | 2 +- .../ops/impl/transforms/bool/BooleanNot.java | 2 +- .../ops/impl/transforms/bool/IsFinite.java | 2 +- .../api/ops/impl/transforms/bool/IsInf.java | 2 +- .../api/ops/impl/transforms/bool/IsNaN.java | 2 +- .../ops/impl/transforms/clip/ClipByNorm.java | 2 +- .../ops/impl/transforms/clip/ClipByValue.java | 4 +- .../api/ops/impl/transforms/custom/ATan2.java | 12 +- .../ops/impl/transforms/custom/Assign.java | 2 +- .../ops/impl/transforms/custom/CumProd.java | 3 +- .../ops/impl/transforms/custom/CumSum.java | 3 +- .../custom/DotProductAttention.java | 3 +- .../transforms/custom/DynamicPartition.java | 3 +- .../impl/transforms/custom/DynamicStitch.java | 2 +- .../transforms/custom/InvertPermutation.java | 5 +- .../ops/impl/transforms/custom/LayerNorm.java | 10 +- .../impl/transforms/custom/LogSoftMax.java | 7 +- .../impl/transforms/custom/MatrixSetDiag.java | 4 +- .../custom/MultiHeadDotProductAttention.java | 2 +- .../api/ops/impl/transforms/custom/Pow.java | 4 +- .../ops/impl/transforms/custom/Reverse.java | 4 +- .../transforms/custom/ReverseSequence.java | 4 +- .../ops/impl/transforms/custom/SoftMax.java | 4 +- .../impl/transforms/custom/Standardize.java | 3 +- .../impl/transforms/custom/ThresholdRelu.java | 3 +- .../api/ops/impl/transforms/custom/Trace.java | 10 +- .../transforms/custom/segment/SegmentMax.java | 3 +- .../custom/segment/SegmentMean.java | 3 +- .../transforms/custom/segment/SegmentMin.java | 3 +- .../custom/segment/SegmentProd.java | 3 +- .../transforms/custom/segment/SegmentSum.java | 3 +- 
.../api/ops/impl/transforms/dtype/Cast.java | 2 +- .../ops/impl/transforms/floating/RSqrt.java | 2 +- .../transforms/gradient/SELUDerivative.java | 10 +- .../transforms/gradient/TanhDerivative.java | 2 +- .../transforms/pairwise/arithmetic/AddOp.java | 12 +- .../transforms/pairwise/arithmetic/DivOp.java | 13 +- .../pairwise/arithmetic/FModOp.java | 3 +- .../pairwise/arithmetic/FloorDivOp.java | 8 +- .../pairwise/arithmetic/FloorModOp.java | 8 +- .../pairwise/arithmetic/MergeAddOp.java | 2 +- .../transforms/pairwise/arithmetic/ModOp.java | 15 +- .../transforms/pairwise/arithmetic/MulOp.java | 14 +- .../pairwise/arithmetic/RDivOp.java | 13 +- .../pairwise/arithmetic/RSubOp.java | 17 +- .../pairwise/arithmetic/RealDivOp.java | 3 +- .../arithmetic/SquaredDifferenceOp.java | 18 +- .../transforms/pairwise/arithmetic/SubOp.java | 14 +- .../pairwise/arithmetic/TruncateDivOp.java | 4 +- .../impl/transforms/pairwise/bool/Not.java | 2 +- .../api/ops/impl/transforms/same/AMax.java | 3 +- .../api/ops/impl/transforms/same/AMin.java | 3 +- .../api/ops/impl/transforms/same/Abs.java | 2 +- .../api/ops/impl/transforms/same/Ceil.java | 2 +- .../api/ops/impl/transforms/same/Cube.java | 3 +- .../api/ops/impl/transforms/same/Max.java | 6 +- .../api/ops/impl/transforms/same/Min.java | 9 +- .../ops/impl/transforms/same/Negative.java | 2 +- .../ops/impl/transforms/same/Reciprocal.java | 2 +- .../api/ops/impl/transforms/same/Round.java | 2 +- .../api/ops/impl/transforms/same/Square.java | 6 +- .../segment/UnsortedSegmentMax.java | 3 +- .../segment/UnsortedSegmentMean.java | 3 +- .../segment/UnsortedSegmentMin.java | 3 +- .../segment/UnsortedSegmentProd.java | 3 +- .../segment/UnsortedSegmentSqrtN.java | 3 +- .../segment/UnsortedSegmentSum.java | 3 +- .../api/ops/impl/transforms/strict/ACos.java | 7 +- .../api/ops/impl/transforms/strict/ASinh.java | 4 +- .../api/ops/impl/transforms/strict/ATan.java | 5 +- .../api/ops/impl/transforms/strict/Cos.java | 2 +- 
.../api/ops/impl/transforms/strict/Cosh.java | 2 +- .../api/ops/impl/transforms/strict/ELU.java | 3 +- .../api/ops/impl/transforms/strict/Exp.java | 2 +- .../api/ops/impl/transforms/strict/Expm1.java | 2 +- .../api/ops/impl/transforms/strict/GELU.java | 2 +- .../impl/transforms/strict/HardSigmoid.java | 3 +- .../ops/impl/transforms/strict/HardTanh.java | 3 +- .../api/ops/impl/transforms/strict/Log.java | 6 +- .../api/ops/impl/transforms/strict/Log1p.java | 3 +- .../impl/transforms/strict/LogSigmoid.java | 7 +- .../api/ops/impl/transforms/strict/Mish.java | 5 +- .../impl/transforms/strict/PreciseGELU.java | 6 +- .../impl/transforms/strict/RationalTanh.java | 7 +- .../impl/transforms/strict/RectifiedTanh.java | 7 +- .../api/ops/impl/transforms/strict/SELU.java | 3 +- .../ops/impl/transforms/strict/Sigmoid.java | 4 +- .../api/ops/impl/transforms/strict/Sin.java | 5 +- .../api/ops/impl/transforms/strict/Sinh.java | 5 +- .../ops/impl/transforms/strict/SoftPlus.java | 5 +- .../ops/impl/transforms/strict/SoftSign.java | 7 +- .../api/ops/impl/transforms/strict/Swish.java | 2 +- .../transforms/strict/SwishDerivative.java | 4 +- .../api/ops/impl/transforms/strict/Tan.java | 2 +- .../api/ops/impl/transforms/strict/Tanh.java | 4 +- .../org/nd4j/linalg/factory/ops/NDBase.java | 53 +- .../org/nd4j/linalg/factory/ops/NDImage.java | 4 +- .../org/nd4j/linalg/factory/ops/NDMath.java | 316 +- .../org/nd4j/linalg/factory/ops/NDNN.java | 31 +- .../nd4j-tests-tensorflow/pom.xml | 13 +- nd4j/nd4j-backends/nd4j-tests/pom.xml | 23 +- .../org/nd4j/AssertTestsExtendBaseClass.java | 68 +- .../java/org/nd4j/autodiff/TestSessions.java | 6 +- .../opvalidation/MiscOpValidation.java | 22 +- .../opvalidation/ReductionOpValidation.java | 4 +- .../opvalidation/ShapeOpValidation.java | 14 +- .../opvalidation/TransformOpValidation.java | 25 +- .../samediff/FlatBufferSerdeTest.java | 4 +- .../autodiff/samediff/NameScopeTests.java | 2 +- .../nd4j/autodiff/samediff/SameDiffTests.java | 119 +- 
.../samediff/SameDiffTrainingTest.java | 6 +- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 66 + nd4j/nd4j-common-tests/pom.xml | 16 + .../org/nd4j/AbstractAssertTestsClass.java | 82 + nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml | 9 +- .../pom.xml | 6 +- .../nd4j-parameter-server-status/pom.xml | 6 +- nd4j/nd4j-remote/nd4j-json-server/pom.xml | 342 +-- nd4j/nd4j-serde/nd4j-aeron/pom.xml | 329 +- nd4j/nd4j-serde/nd4j-arrow/pom.xml | 19 +- .../nd4j-camel-routes/nd4j-kafka/pom.xml | 3 +- nd4j/nd4j-serde/nd4j-gson/pom.xml | 28 +- nd4j/nd4j-serde/nd4j-kryo/pom.xml | 36 +- 358 files changed, 4531 insertions(+), 3919 deletions(-) create mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java create mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java create mode 100644 arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java create mode 100644 arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-data-codec/src/test/java/org/datavec/codec/reader/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java create mode 100644 
datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java create mode 100644 
deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java create mode 100644 nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml index 6ce3c9c1f..76251f4cd 100644 --- a/arbiter/arbiter-core/pom.xml +++ b/arbiter/arbiter-core/pom.xml @@ -14,7 +14,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + arbiter org.deeplearning4j @@ -33,10 +34,10 @@ nd4j-api ${nd4j.version} - - com.google.code.findbugs - * - + + com.google.code.findbugs + * + diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..75a64d05f --- /dev/null +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter.optimize; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter.optimize"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..b8d200350 --- /dev/null +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..b305b123b --- /dev/null +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter.server; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter.server"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java index dbf05d34f..57bef758d 100644 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.server; import lombok.Data; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; @@ -27,7 +28,7 @@ import 
java.io.IOException; * Created by agibsonccc on 3/13/17. */ @Data -public class MnistDataSetIteratorFactory implements DataSetIteratorFactory { +public class MnistDataSetIteratorFactory extends BaseDL4JTest implements DataSetIteratorFactory { /** * @return */ diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java index e1a7f820e..c4a75ffb4 100644 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java @@ -17,13 +17,14 @@ package org.deeplearning4j.arbiter.server; import lombok.AllArgsConstructor; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; @AllArgsConstructor -public class TestDataFactoryProviderMnist implements DataSetIteratorFactory { +public class TestDataFactoryProviderMnist extends BaseDL4JTest implements DataSetIteratorFactory { private int batchSize; private int terminationIter; diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml index 2067a3fc7..7392392db 100644 --- a/arbiter/arbiter-ui/pom.xml +++ b/arbiter/arbiter-ui/pom.xml @@ -54,6 +54,13 @@ ${dl4j.version} + + org.deeplearning4j + deeplearning4j-common-tests + ${dl4j.version} + test + + ch.qos.logback logback-classic diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..fee20847c --- 
/dev/null +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter.optimize; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.deeplearning4j.BaseDL4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter.optimize"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java index 804c6f974..ddf73e455 100644 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.MultiLayerSpace; @@ -70,7 +71,7 @@ import java.util.concurrent.TimeUnit; /** * Created by Alex on 19/07/2017. 
*/ -public class TestBasic { +public class TestBasic extends BaseDL4JTest { @Test @Ignore diff --git a/datavec/datavec-api/pom.xml b/datavec/datavec-api/pom.xml index 10ed3517a..3c3eec86e 100644 --- a/datavec/datavec-api/pom.xml +++ b/datavec/datavec-api/pom.xml @@ -15,7 +15,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + datavec-parent org.datavec @@ -79,6 +80,14 @@ ${nd4j.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + + + ch.qos.logback logback-classic diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java index c55d4d3bb..92a1f737b 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java @@ -47,9 +47,8 @@ public class RecordConverter { * * @return the array */ - @Deprecated - public static INDArray toArray(Collection record, int size) { - return toArray(record); + public static INDArray toArray(DataType dataType, Collection record, int size) { + return toArray(dataType, record); } /** @@ -78,13 +77,23 @@ public class RecordConverter { /** * Convert a set of records in to a matrix + * As per {@link #toMatrix(DataType, List)} but hardcoded to Float datatype * @param records the records ot convert * @return the matrix for the records */ public static INDArray toMatrix(List> records) { + return toMatrix(DataType.FLOAT, records); + } + + /** + * Convert a set of records in to a matrix + * @param records the records ot convert + * @return the matrix for the records + */ + public static INDArray toMatrix(DataType dataType, List> records) { List toStack = new ArrayList<>(); for(List l : records){ - toStack.add(toArray(l)); + toStack.add(toArray(dataType, l)); } return Nd4j.vstack(toStack); @@ -92,10 +101,20 @@ public 
class RecordConverter { /** * Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables. + * As per {@link #toArray(DataType, Collection)} but hardcoded to Float datatype * @param record the record to convert * @return the array */ - public static INDArray toArray(Collection record) { + public static INDArray toArray(Collection record){ + return toArray(DataType.FLOAT, record); + } + + /** + * Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables. + * @param record the record to convert + * @return the array + */ + public static INDArray toArray(DataType dataType, Collection record) { List l; if(record instanceof List){ l = (List)record; @@ -124,7 +143,7 @@ public class RecordConverter { } } - INDArray arr = Nd4j.create(1, length); + INDArray arr = Nd4j.create(dataType, 1, length); int k = 0; for (Writable w : record ) { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java b/datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..43c606123 --- /dev/null +++ b/datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java @@ -0,0 +1,57 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api; + +import lombok.extern.slf4j.Slf4j; +import org.datavec.api.transform.serde.testClasses.CustomCondition; +import org.datavec.api.transform.serde.testClasses.CustomFilter; +import org.datavec.api.transform.serde.testClasses.CustomTransform; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + Set> res = new HashSet<>(); + res.add(CustomCondition.class); + res.add(CustomFilter.class); + res.add(CustomTransform.class); + return res; + } + + @Override + protected String getPackageName() { + return "org.datavec.api"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java index 70a7ffa7b..84d9b259f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java @@ -25,6 +25,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import 
java.io.File; import java.nio.charset.StandardCharsets; @@ -34,7 +35,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class CSVLineSequenceRecordReaderTest { +public class CSVLineSequenceRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java index 888d8b523..c293d4544 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java @@ -26,6 +26,8 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.BaseCompatOp; import java.io.File; import java.nio.charset.StandardCharsets; @@ -37,7 +39,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -public class CSVMultiSequenceRecordReaderTest { +public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java index 9b84fddc3..9f297d83b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java @@ -24,6 +24,7 @@ import 
org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.util.ArrayList; @@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 19/09/2016. */ -public class CSVNLinesSequenceRecordReaderTest { +public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest { @Test public void testCSVNLinesSequenceRecordReader() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java index 534cc986e..471dc07c4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java @@ -31,6 +31,7 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -44,7 +45,7 @@ import java.util.NoSuchElementException; import static org.junit.Assert.*; -public class CSVRecordReaderTest { +public class CSVRecordReaderTest extends BaseND4JTest { @Test public void testNext() throws Exception { CSVRecordReader reader = new CSVRecordReader(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java index fbbd992d1..e0763bbbc 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java +++ 
b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java @@ -26,6 +26,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -39,7 +40,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class CSVSequenceRecordReaderTest { +public class CSVSequenceRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder tempDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java index 8e60acad9..fe0c94c4c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java @@ -22,6 +22,7 @@ import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordRea import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.util.LinkedList; @@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals; * * @author Justin Long (crockpotveggies) */ -public class CSVVariableSlidingWindowRecordReaderTest { +public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest { @Test public void testCSVVariableSlidingWindowRecordReader() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java index 
c67e32192..d6f03d815 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.api.loader.FileBatch; import java.io.File; @@ -36,7 +37,7 @@ import java.util.List; import static org.junit.Assert.*; -public class FileBatchRecordReaderTest { +public class FileBatchRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java index 533f5be66..6bf66880f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java @@ -23,6 +23,7 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.net.URI; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertFalse; /** * Created by nyghtowl on 11/14/15. 
*/ -public class FileRecordReaderTest { +public class FileRecordReaderTest extends BaseND4JTest { @Test public void testReset() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java index bfeadef36..2f91579f0 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -39,7 +40,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class JacksonLineRecordReaderTest { +public class JacksonLineRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java index c95de48e7..f1fa8d6b2 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java @@ -30,6 +30,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -48,7 +49,7 @@ 
import static org.junit.Assert.assertFalse; /** * Created by Alex on 11/04/2016. */ -public class JacksonRecordReaderTest { +public class JacksonRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java index 75871a6b7..5e8ca6546 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.IOException; @@ -44,7 +45,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class LibSvmRecordReaderTest { +public class LibSvmRecordReaderTest extends BaseND4JTest { @Test public void testBasicRecord() throws IOException, InterruptedException { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java index 5027357eb..17a41f4d4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java @@ -29,6 +29,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,7 +49,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc 
on 11/17/14. */ -public class LineReaderTest { +public class LineReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java index 778d14424..539b0f351 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java @@ -32,6 +32,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -45,7 +46,7 @@ import static org.junit.Assert.assertFalse; /** * Created by Alex on 12/04/2016. */ -public class RegexRecordReaderTest { +public class RegexRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java index 92f8c57e4..25d2959ce 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.IOException; @@ -42,7 +43,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class SVMLightRecordReaderTest { +public 
class SVMLightRecordReaderTest extends BaseND4JTest { @Test public void testBasicRecord() throws IOException, InterruptedException { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java index a06d56400..fa68c4a1f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java @@ -23,6 +23,7 @@ import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordRe import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -33,7 +34,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 21/05/2016. */ -public class TestCollectionRecordReaders { +public class TestCollectionRecordReaders extends BaseND4JTest { @Test public void testCollectionSequenceRecordReader() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java index 172d884a3..266ad2edc 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java @@ -20,11 +20,12 @@ import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import static org.junit.Assert.assertEquals; -public class 
TestConcatenatingRecordReader { +public class TestConcatenatingRecordReader extends BaseND4JTest { @Test public void test() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java index c249737a3..91fc22886 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java @@ -34,6 +34,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals; * Note however that not all are used/usable with spark (such as Collection[Sequence]RecordReader * and the rest are generally used without being initialized on a particular dataset */ -public class TestSerialization { +public class TestSerialization extends BaseND4JTest { @Test public void testRR() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java index 5daad01b3..ff3ceb9be 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import 
org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.util.ArrayList; @@ -39,7 +40,7 @@ import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 3/21/17. */ -public class TransformProcessRecordReaderTests { +public class TransformProcessRecordReaderTests extends BaseND4JTest { @Test public void simpleTransformTest() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java index 5a165b0ac..c3a8f4181 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.io.File; import java.util.ArrayList; @@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals; /** * @author raver119@gmail.com */ -public class CSVRecordWriterTest { +public class CSVRecordWriterTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java index 0c7d70b09..91996056d 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class LibSvmRecordWriterTest { +public class LibSvmRecordWriterTest extends BaseND4JTest { @Test public void testBasic() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java index 7b9b8c203..f057c7d45 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java @@ -25,6 +25,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.writable.*; import org.datavec.api.writable.NDArrayWritable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -47,7 +48,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class SVMLightRecordWriterTest { +public class SVMLightRecordWriterTest extends BaseND4JTest { @Test public void testBasic() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java index e8ce37bd3..59e1feee8 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java @@ -16,6 +16,7 @@ package org.datavec.api.split; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.guava.io.Files; import org.datavec.api.io.filters.BalancedPathFilter; import 
org.datavec.api.io.filters.RandomPathFilter; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals; * * @author saudet */ -public class InputSplitTests { +public class InputSplitTests extends BaseND4JTest { @Test public void testSample() throws URISyntaxException { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java index 797f546dd..f8be04d47 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java @@ -17,13 +17,14 @@ package org.datavec.api.split; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.net.URI; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class NumberedFileInputSplitTests { +public class NumberedFileInputSplitTests extends BaseND4JTest { @Test public void testNumberedFileInputSplitBasic() { String baseString = "/path/to/files/prefix%d.suffix"; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java index c618c625d..94119015c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.function.Function; import java.io.File; @@ -40,7 +41,7 @@ import java.util.Random; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; -public class TestStreamInputSplit { +public class TestStreamInputSplit extends BaseND4JTest { @Rule public 
TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java index 457f07097..ea6b9fea4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java @@ -17,6 +17,7 @@ package org.datavec.api.split; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.net.URI; import java.net.URISyntaxException; @@ -28,7 +29,7 @@ import static org.junit.Assert.assertArrayEquals; /** * @author Ede Meijer */ -public class TransformSplitTest { +public class TransformSplitTest extends BaseND4JTest { @Test public void testTransform() throws URISyntaxException { Collection inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv")); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java index c9fb57eb9..f27f7527f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java @@ -16,6 +16,7 @@ package org.datavec.api.split.parittion; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.guava.io.Files; import org.datavec.api.conf.Configuration; import org.datavec.api.split.FileSplit; @@ -31,7 +32,7 @@ import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -public class PartitionerTests { +public class PartitionerTests extends BaseND4JTest { @Test public void testRecordsPerFilePartition() { Partitioner partitioner = new NumberOfRecordsPartitioner(); diff --git 
a/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java index eeb4be27a..efb9f2b6e 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java @@ -26,12 +26,13 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; import static org.junit.Assert.assertEquals; -public class TestTransformProcess { +public class TestTransformProcess extends BaseND4JTest { @Test public void testExecution(){ diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java index da4e53398..0c69959d6 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java @@ -24,6 +24,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.TestTransforms; import org.datavec.api.writable.*; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -33,7 +34,7 @@ import static org.junit.Assert.assertTrue; /** * Created by Alex on 24/03/2016. 
*/ -public class TestConditions { +public class TestConditions extends BaseND4JTest { @Test public void testIntegerCondition() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java index 4d96b5b6e..314ee72ff 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -37,7 +38,7 @@ import static org.junit.Assert.assertTrue; /** * Created by Alex on 21/03/2016. */ -public class TestFilters { +public class TestFilters extends BaseND4JTest { @Test diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java index e6ae74185..1d113c6ff 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java @@ -23,6 +23,7 @@ import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -33,7 +34,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 18/04/2016. 
*/ -public class TestJoin { +public class TestJoin extends BaseND4JTest { @Test public void testJoin() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java index 059cb618c..57ec54e8a 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java @@ -18,6 +18,7 @@ package org.datavec.api.transform.ops; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.io.Serializable; import java.util.*; @@ -27,7 +28,7 @@ import static org.junit.Assert.assertTrue; /** * Created by huitseeker on 5/14/17. */ -public class AggregableMultiOpTest { +public class AggregableMultiOpTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java index 487926c7a..c722dada4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java @@ -19,6 +19,7 @@ package org.datavec.api.transform.ops; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -30,7 +31,7 @@ import static org.junit.Assert.assertTrue; /** * Created by huitseeker on 5/14/17. 
*/ -public class AggregatorImplsTest { +public class AggregatorImplsTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java index 076e4412d..a636e7239 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java @@ -18,6 +18,7 @@ package org.datavec.api.transform.ops; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertTrue; /** * Created by huitseeker on 5/14/17. */ -public class DispatchOpTest { +public class DispatchOpTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java index 1b0a20430..9aef39aa4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java @@ -29,6 +29,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -38,7 +39,7 @@ import static org.junit.Assert.fail; /** * Created by Alex on 
21/03/2016. */ -public class TestMultiOpReduce { +public class TestMultiOpReduce extends BaseND4JTest { @Test public void testMultiOpReducerDouble() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java index f9debfe2c..dc6443630 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java @@ -21,13 +21,14 @@ import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; -public class TestReductions { +public class TestReductions extends BaseND4JTest { @Test public void testGeographicMidPointReduction(){ diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java index dff90f8b9..8e33b742c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java @@ -19,13 +19,14 @@ package org.datavec.api.transform.schema; import org.datavec.api.transform.metadata.ColumnMetaData; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.assertEquals; /** * Created by Alex on 18/07/2016. 
*/ -public class TestJsonYaml { +public class TestJsonYaml extends BaseND4JTest { @Test public void testToFromJsonYaml() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java index 870c10680..6cbcafff4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java @@ -18,13 +18,14 @@ package org.datavec.api.transform.schema; import org.datavec.api.transform.ColumnType; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.assertEquals; /** * Created by Alex on 04/09/2016. */ -public class TestSchemaMethods { +public class TestSchemaMethods extends BaseND4JTest { @Test public void testNumberedColumnAdding() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java index 0f48eeff3..56c8d3f1e 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java @@ -30,6 +30,7 @@ import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -41,7 +42,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 16/04/2016. 
*/ -public class TestReduceSequenceByWindowFunction { +public class TestReduceSequenceByWindowFunction extends BaseND4JTest { @Test public void testReduceSequenceByWindowFunction() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java index 6695599a5..98dd49587 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -35,7 +36,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 19/04/2016. */ -public class TestSequenceSplit { +public class TestSequenceSplit extends BaseND4JTest { @Test public void testSequenceSplitTimeSeparation() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java index 99fe9227d..cc12adc53 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java @@ -26,6 +26,7 @@ import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -37,7 +38,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 16/04/2016. 
*/ -public class TestWindowFunctions { +public class TestWindowFunctions extends BaseND4JTest { @Test public void testTimeWindowFunction() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java index 03731f6d6..1da9f48e5 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java @@ -23,13 +23,14 @@ import org.datavec.api.transform.serde.testClasses.CustomCondition; import org.datavec.api.transform.serde.testClasses.CustomFilter; import org.datavec.api.transform.serde.testClasses.CustomTransform; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.assertEquals; /** * Created by Alex on 11/01/2017. */ -public class TestCustomTransformJsonYaml { +public class TestCustomTransformJsonYaml extends BaseND4JTest { @Test public void testCustomTransform() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java index d09995009..dd6e0941a 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java @@ -61,6 +61,7 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; import java.util.concurrent.TimeUnit; @@ -70,7 +71,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 20/07/2016. 
*/ -public class TestYamlJsonSerde { +public class TestYamlJsonSerde extends BaseND4JTest { public static YamlSerializer y = new YamlSerializer(); public static JsonSerializer j = new JsonSerializer(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java index e17650a87..ac69e3397 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java @@ -21,6 +21,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 21/03/2016. */ -public class TestReduce { +public class TestReduce extends BaseND4JTest { @Test public void testReducerDouble() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java index 38ec1fda9..daa5c15c8 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java @@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -58,7 +59,7 @@ import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; -public class RegressionTestJson { +public class RegressionTestJson extends BaseND4JTest { @Test public void 
regressionTestJson100a() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java index 9f647d365..00c4b745f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java @@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; import java.util.concurrent.TimeUnit; @@ -56,7 +57,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 18/07/2016. */ -public class TestJsonYaml { +public class TestJsonYaml extends BaseND4JTest { @Test public void testToFromJsonYaml() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java index 600ee0b25..1d440913b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java @@ -56,6 +56,7 @@ import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Assert; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -72,7 +73,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 21/03/2016. */ -public class TestTransforms { +public class TestTransforms extends BaseND4JTest { public static Schema getSchema(ColumnType type, String... 
colNames) { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java index 78d929e65..c6dad8359 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java @@ -26,6 +26,7 @@ import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -39,7 +40,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 02/06/2017. */ -public class TestNDArrayWritableTransforms { +public class TestNDArrayWritableTransforms extends BaseND4JTest { @Test public void testNDArrayWritableBasic() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java index 7eb3efdef..394457443 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java @@ -27,6 +27,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.YamlSerializer; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.Arrays; import java.util.List; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 20/07/2016. 
*/ -public class TestYamlJsonSerde { +public class TestYamlJsonSerde extends BaseND4JTest { public static YamlSerializer y = new YamlSerializer(); public static JsonSerializer j = new JsonSerializer(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java index 8ec5233c7..4c2c718ae 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java @@ -20,6 +20,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -30,7 +31,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 10/22/16. */ -public class ParseDoubleTransformTest { +public class ParseDoubleTransformTest extends BaseND4JTest { @Test public void testDoubleTransform() { List record = new ArrayList<>(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java index 0dffb6dab..64f6a4422 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java @@ -35,6 +35,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import java.io.File; import java.util.ArrayList; @@ -46,7 +47,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 25/03/2016. 
*/ -public class TestUI { +public class TestUI extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java index 9b95bbfb4..b68ae43ee 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java @@ -18,6 +18,7 @@ package org.datavec.api.util; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.io.BufferedReader; import java.io.File; @@ -33,7 +34,7 @@ import static org.hamcrest.core.IsEqual.equalTo; /** * @author raver119@gmail.com */ -public class ClassPathResourceTest { +public class ClassPathResourceTest extends BaseND4JTest { private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java index 1545938f6..d47ec60d7 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java @@ -20,6 +20,7 @@ import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; @@ -27,7 +28,7 @@ import java.util.List; import static org.junit.Assert.assertArrayEquals; -public class TimeSeriesUtilsTest { +public class TimeSeriesUtilsTest extends BaseND4JTest { @Test public void testTimeSeriesCreation() { diff --git 
a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java index 6dfacdd93..dbc62ed93 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java @@ -16,6 +16,7 @@ package org.datavec.api.writable; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.guava.collect.Lists; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; @@ -31,7 +32,7 @@ import java.util.TimeZone; import static org.junit.Assert.assertEquals; -public class RecordConverterTest { +public class RecordConverterTest extends BaseND4JTest { @Test public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() { INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); @@ -86,7 +87,7 @@ public class RecordConverterTest { new IntWritable(1)); INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT); - INDArray act = RecordConverter.toArray(l); + INDArray act = RecordConverter.toArray(DataType.FLOAT, l); assertEquals(exp, act); } @@ -101,7 +102,7 @@ public class RecordConverterTest { {1,2,3,4,5}, {6,7,8,9,10}}).castTo(DataType.FLOAT); - INDArray act = RecordConverter.toMatrix(Arrays.asList(l1,l2)); + INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2)); assertEquals(exp, act); } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java index 81c2f2d73..9242927e1 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java +++ 
b/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java @@ -18,6 +18,7 @@ package org.datavec.api.writable; import org.datavec.api.transform.metadata.NDArrayMetaData; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -28,7 +29,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 02/06/2017. */ -public class TestNDArrayWritableAndSerialization { +public class TestNDArrayWritableAndSerialization extends BaseND4JTest { @Test public void testIsValid() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java index 93d7ed31b..bd636e62b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java @@ -18,6 +18,7 @@ package org.datavec.api.writable; import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,9 +32,7 @@ import java.util.List; import static org.junit.Assert.*; -public class WritableTest { - - +public class WritableTest extends BaseND4JTest { @Test public void testWritableEqualityReflexive() { diff --git a/datavec/datavec-arrow/pom.xml b/datavec/datavec-arrow/pom.xml index 04420a5e9..60409bc53 100644 --- a/datavec/datavec-arrow/pom.xml +++ b/datavec/datavec-arrow/pom.xml @@ -49,6 +49,12 @@ arrow-format ${arrow.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java index 1d8fddc0e..edd036f0a 100644 --- 
a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java @@ -40,6 +40,7 @@ import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; @@ -56,7 +57,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -public class ArrowConverterTest { +public class ArrowConverterTest extends BaseND4JTest { private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..f2cf7ce09 --- /dev/null +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.arrow; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.BaseND4JTest; +import org.nd4j.AbstractAssertTestsClass; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.arrow"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java index 390bfdcd9..59ba5a546 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java @@ -31,6 +31,7 @@ import org.datavec.api.writable.Writable; import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordWriter; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.primitives.Triple; import java.io.File; @@ -41,7 +42,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class RecordMapperTest { +public class RecordMapperTest extends BaseND4JTest { @Test public void testMultiWrite() throws Exception { diff --git 
a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java index 6951561cd..e49a9fcc4 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable; import org.datavec.arrow.ArrowConverter; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -35,7 +36,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -public class ArrowWritableRecordTimeSeriesBatchTests { +public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); diff --git a/datavec/datavec-data/datavec-data-audio/pom.xml b/datavec/datavec-data/datavec-data-audio/pom.xml index 1f99eab7c..3b9674cd9 100644 --- a/datavec/datavec-data/datavec-data-audio/pom.xml +++ b/datavec/datavec-data/datavec-data-audio/pom.xml @@ -57,6 +57,13 @@ with-dependencies + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + + - + datavec-parent org.datavec @@ -31,6 +32,12 @@ datavec-api ${project.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + com.maxmind.geoip2 geoip2 diff --git a/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..7d4a6836c --- /dev/null +++ b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java @@ 
-0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api.transform; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.api.transform"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-data/datavec-hadoop/pom.xml b/datavec/datavec-data/datavec-hadoop/pom.xml index 7b74ead38..a6c72b968 100644 --- a/datavec/datavec-data/datavec-hadoop/pom.xml +++ b/datavec/datavec-data/datavec-hadoop/pom.xml @@ -60,6 +60,13 @@ + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..2aaf25041 --- /dev/null +++ b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.hadoop; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.hadoop"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-excel/pom.xml b/datavec/datavec-excel/pom.xml index 00fc890d8..49dc26db8 100644 --- a/datavec/datavec-excel/pom.xml +++ b/datavec/datavec-excel/pom.xml @@ -51,6 +51,13 @@ poi-ooxml ${poi.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..1b61f7f6c --- /dev/null +++ b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.poi.excel; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.poi.excel"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-jdbc/pom.xml b/datavec/datavec-jdbc/pom.xml index bfafd25d0..6ef9b0441 100644 --- a/datavec/datavec-jdbc/pom.xml +++ b/datavec/datavec-jdbc/pom.xml @@ -58,6 +58,13 @@ ${derby.version} test + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..1db810b7b --- /dev/null +++ 
b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api.records.reader; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.api.records.reader"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-local/pom.xml b/datavec/datavec-local/pom.xml index 5c2c6f4ac..3adc0e011 100644 --- a/datavec/datavec-local/pom.xml +++ b/datavec/datavec-local/pom.xml @@ -81,6 +81,13 @@ test + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + + diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..991b8466d --- /dev/null +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.local.transforms; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.local.transforms"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java index ba7048547..1a46789ad 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java @@ -28,6 +28,7 @@ import org.datavec.local.transforms.AnalyzeLocal; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.io.ClassPathResource; @@ -63,7 +64,7 @@ public class TestAnalyzeLocal { list.add(rr.next()); } - INDArray arr = RecordConverter.toMatrix(list); + INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, list); INDArray mean = arr.mean(0); INDArray std = 
arr.std(0); diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml index 55cf6c5da..526b8238a 100644 --- a/datavec/datavec-python/pom.xml +++ b/datavec/datavec-python/pom.xml @@ -64,6 +64,13 @@ nd4j-native-api ${project.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java b/datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..83aa2fe5a --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.python"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml index 10cca8e5a..57e50d127 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml @@ -51,6 +51,13 @@ datavec-spark-inference-model ${project.parent.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..3bff86e98 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.transform.client; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.transform.client"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml index 470340dc1..bac20d42e 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml @@ -45,6 +45,13 @@ datavec-local ${project.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java 
b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java index 67d7fe44a..f76e9885f 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java @@ -33,6 +33,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody; import org.datavec.spark.transform.model.BatchCSVRecord; import org.datavec.spark.transform.model.SequenceBatchCSVRecord; import org.datavec.spark.transform.model.SingleCSVRecord; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.serde.base64.Nd4jBase64; @@ -91,7 +92,7 @@ public class CSVSparkTransform { transformProcess.getInitialSchema(),record.getValues()), transformProcess.getInitialSchema()); List finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); - INDArray convert = RecordConverter.toArray(finalRecord); + INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord); return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..4c6f529b9 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.spark.transform; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.spark.transform"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml index 47951b1aa..0c05f327b 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml @@ -164,6 +164,13 @@ spark-core_2.11 ${spark.version} + + + org.nd4j + 
nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..4c6f529b9 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.spark.transform; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.spark.transform"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml index 50194c91e..345b774c3 100644 --- a/datavec/datavec-spark/pom.xml +++ b/datavec/datavec-spark/pom.xml @@ -130,6 +130,12 @@ test + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..9539251e6 --- /dev/null +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.spark; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.spark"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java index 5352ec10d..fcc20d661 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java @@ -104,7 +104,7 @@ public class NormalizationTests extends BaseSparkTest { } - INDArray arr = RecordConverter.toMatrix(data); + INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, data); Schema schema = builder.build(); JavaRDD> rdd = sc.parallelize(data); @@ -127,9 +127,9 @@ public class NormalizationTests extends BaseSparkTest { zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes)); INDArray zeroMeanUnitVarianceDataFrame = - RecordConverter.toMatrix(Normalization.zeromeanUnitVariance(schema, rdd).collect()); + RecordConverter.toMatrix(DataType.DOUBLE, 
Normalization.zeromeanUnitVariance(schema, rdd).collect()); INDArray zeroMeanUnitVarianceDataFrameZeroToOne = - RecordConverter.toMatrix(Normalization.normalize(schema, rdd).collect()); + RecordConverter.toMatrix(DataType.DOUBLE, Normalization.normalize(schema, rdd).collect()); assertEquals(standardScalered, zeroMeanUnitVarianceDataFrame); assertTrue(zeroToOnes.equalsWithEps(zeroMeanUnitVarianceDataFrameZeroToOne, 1e-1)); diff --git a/deeplearning4j/deeplearning4j-common-tests/pom.xml b/deeplearning4j/deeplearning4j-common-tests/pom.xml index 23e030df3..5a4ba921d 100644 --- a/deeplearning4j/deeplearning4j-common-tests/pom.xml +++ b/deeplearning4j/deeplearning4j-common-tests/pom.xml @@ -37,6 +37,11 @@ nd4j-api ${project.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + ch.qos.logback logback-classic diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index 496bb6b1b..90c88d4c3 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -164,20 +164,6 @@ oshi-core ${oshi.version} - - - - org.reflections - reflections - ${reflections.version} - test - - - com.google.code.findbugs - * - - - diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java index 5f0567094..34d4db39e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -17,15 +17,8 @@ package org.deeplearning4j; import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.reflections.Reflections; -import org.reflections.scanners.MethodAnnotationsScanner; -import org.reflections.util.ClasspathHelper; -import org.reflections.util.ConfigurationBuilder; - -import 
java.lang.reflect.Method; import java.util.*; - -import static org.junit.Assert.assertEquals; +import org.nd4j.AbstractAssertTestsClass; /** * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) @@ -33,45 +26,24 @@ import static org.junit.Assert.assertEquals; * Other than a small set of exceptions, all tests must extend this * * @author Alex Black + * @author Alexander Stoyakin */ @Slf4j -public class AssertTestsExtendBaseClass extends BaseDL4JTest { +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { @Override - public long getTimeoutMilliseconds() { - return 240000L; + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; } - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - private static final Set> exclusions = new HashSet<>(); + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } - @Test - public void checkTestClasses(){ - - Reflections reflections = new Reflections(new ConfigurationBuilder() - .setUrls(ClasspathHelper.forPackage("org.deeplearning4j")) - .setScanners(new MethodAnnotationsScanner())); - Set methods = reflections.getMethodsAnnotatedWith(Test.class); - Set> s = new HashSet<>(); - for(Method m : methods){ - s.add(m.getDeclaringClass()); - } - - List> l = new ArrayList<>(s); - Collections.sort(l, new Comparator>() { - @Override - public int compare(Class aClass, Class t1) { - return aClass.getName().compareTo(t1.getName()); - } - }); - - int count = 0; - for(Class c : l){ - if(!BaseDL4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){ - log.error("Test {} does not extend BaseDL4JTest (directly or indirectly). 
All tests must extend this class for proper memory tracking and timeouts", c); - count++; - } - } - assertEquals("Number of tests not extending BaseDL4JTest", 0, count); + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index cf75700f8..9bcb97b7d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -96,7 +96,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { SDVariable z1 = a0.mmul(w1).add("prediction", b1); SDVariable a1 = sd.nn().softmax("softmax", z1); - SDVariable diff = sd.f().squaredDifference(a1, label); + SDVariable diff = sd.math().squaredDifference(a1, label); SDVariable lossMse = diff.mean(); lossMse.markAsLoss(); diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java index 2f47c2c8b..92d5d579e 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java @@ -494,7 +494,7 @@ public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, S List c = list.get(i); if (details.entireReader) { //Convert entire reader contents, without 
modification - INDArray converted = RecordConverter.toArray(c); + INDArray converted = RecordConverter.toArray(Nd4j.defaultFloatingPointType(), c); putExample(arr, converted, i); } else if (details.oneHot) { //Convert a single column to a one-hot representation diff --git a/deeplearning4j/deeplearning4j-graph/pom.xml b/deeplearning4j/deeplearning4j-graph/pom.xml index ebc6740d9..645b4eca2 100644 --- a/deeplearning4j/deeplearning4j-graph/pom.xml +++ b/deeplearning4j/deeplearning4j-graph/pom.xml @@ -57,7 +57,6 @@ ${project.version} test - diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java new file mode 100644 index 000000000..7341a3a2c --- /dev/null +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.graph; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendedBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.graph"; + } + + @Override + protected Class getBaseClass() {return BaseDL4JTest.class; } +} + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..d7c03956f --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java @@ -0,0 +1,52 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.text.tokenization.tokenizer; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} + + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..c767c3e72 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java @@ -0,0 +1,53 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import com.atilika.kuromoji.TestUtils; +import com.atilika.kuromoji.ipadic.RandomizedInputTest; +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + exclusions.add(TestUtils.class); + exclusions.add(RandomizedInputTest.class); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..ccf95a8ea --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..85b0c39a9 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ + +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml index f27fd7a94..668c728ae 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml @@ -14,76 +14,77 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - - 4.0.0 - - org.deeplearning4j - deeplearning4j-nlp-parent - 1.0.0-SNAPSHOT - + + 4.0.0 + + org.deeplearning4j + deeplearning4j-nlp-parent + 1.0.0-SNAPSHOT + - deeplearning4j-nlp + deeplearning4j-nlp - - - org.nd4j - nd4j-native-api - ${nd4j.version} - + + + org.nd4j + nd4j-native-api + ${nd4j.version} + - - commons-lang - commons-lang - 2.6 - - - org.deeplearning4j - deeplearning4j-core - ${project.version} - + + commons-lang + commons-lang + 2.6 + + + org.deeplearning4j + 
deeplearning4j-core + ${project.version} + - - org.threadly - threadly - ${threadly.version} - + + org.threadly + threadly + ${threadly.version} + - - junit - junit - test - + + junit + junit + test + - - org.mockito - mockito-core - ${mockito.version} - test - + + org.mockito + mockito-core + ${mockito.version} + test + - - ch.qos.logback - logback-classic - test - - - org.apache.commons - commons-lang3 - ${commonslang.version} - - - com.github.vinhkhuc - jfasttext - 0.4 - + + ch.qos.logback + logback-classic + test + + + org.apache.commons + commons-lang3 + ${commonslang.version} + + + com.github.vinhkhuc + jfasttext + 0.4 + - - org.deeplearning4j - deeplearning4j-common-tests - ${project.version} - test - - + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..6fb3b0316 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} + diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index 60ecbf057..8fedee7b0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -31,6 +31,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -179,10 +180,10 @@ public class LocallyConnected1D extends SameDiffLayer { //NCW format. 
if(cm == ConvolutionMode.Same) { layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), 0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), PadMode.CONSTANT, 0); } else { layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), 0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), PadMode.CONSTANT, 0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index 5044017a0..6fad9ec69 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -32,6 +32,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -185,10 +186,10 @@ public class LocallyConnected2D extends SameDiffLayer { //NCHW format if(cm == ConvolutionMode.Same){ layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), 0.0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), PadMode.CONSTANT, 0.0); } else { layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], 
padding[1]}})), 0.0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), PadMode.CONSTANT, 0.0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java index 712265e05..bcca695df 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java @@ -34,6 +34,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder; import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -295,7 +296,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex { } //Define the function for external errors: - fn = sameDiff.f().externalErrors(layerOutput); + fn = SameDiffUtils.externalErrors(sameDiff, null, layerOutput); fn.outputVariable(); this.outputKey = outputVar.name(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java index 64c2ea25e..ed355fdaf 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java @@ -29,6 +29,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder; import 
org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -321,7 +322,7 @@ public class SameDiffLayer extends AbstractLayer { } //Define the function for external errors: - fn = sameDiff.f().externalErrors(layerOutput); + fn = SameDiffUtils.externalErrors(sameDiff, null,layerOutput); fn.outputVariable(); this.outputKey = outputVar.name(); diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..1d8c3d578 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.remote; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.remote"; + } + + @Override + protected Class getBaseClass() { return BaseDL4JTest.class; } +} + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 8a629bc66..e21b2d270 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -475,6 +475,11 @@ public abstract class DifferentialFunction { return outputVariables()[0]; } + public List outputs(){ + SDVariable[] out = outputVariables(); + return out == null ? 
null : Arrays.asList(out); + } + public String[] outputVariablesNames(){ SDVariable[] outputVars = outputVariables(); @@ -502,14 +507,6 @@ public abstract class DifferentialFunction { */ public abstract List doDiff(List f1); - /** - * Shortcut for the {@link DifferentialFunctionFactory} - * @return - */ - public DifferentialFunctionFactory f() { - return sameDiff.f(); - } - /** * Return the arguments for a given function @@ -576,7 +573,7 @@ public abstract class DifferentialFunction { copied = true; } - SDVariable gradVar = f().add(grad, vals.get(i)); + SDVariable gradVar = var.getSameDiff().math.add(grad, vals.get(i)); vals.set(i, gradVar); sameDiff.setGradientForVariableName(var.name(), gradVar); } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java deleted file mode 100644 index 093e3099b..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ /dev/null @@ -1,2659 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.autodiff.functions; - -import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import lombok.Data; -import lombok.NonNull; -import lombok.val; -import org.apache.commons.lang3.ArrayUtils; -import org.nd4j.autodiff.loss.LossReduce; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.enums.DataFormat; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.blas.params.MMulTranspose; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.NoOp; -import org.nd4j.linalg.api.ops.custom.*; -import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; -import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; -import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; -import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; -import org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex; -import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; -import org.nd4j.linalg.api.ops.impl.layers.convolution.*; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; -import 
org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; -import org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss; -import org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss; -import org.nd4j.linalg.api.ops.impl.loss.HingeLoss; -import org.nd4j.linalg.api.ops.impl.loss.HuberLoss; -import org.nd4j.linalg.api.ops.impl.loss.L2Loss; -import org.nd4j.linalg.api.ops.impl.loss.LogLoss; -import org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss; -import org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss; -import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss; -import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss; -import org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits; -import org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp; -import 
org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp; -import org.nd4j.linalg.api.ops.impl.reduce.Mmul; -import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; -import org.nd4j.linalg.api.ops.impl.reduce.Moments; -import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments; -import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul; -import org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction; -import org.nd4j.linalg.api.ops.impl.reduce.bool.All; -import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; -import org.nd4j.linalg.api.ops.impl.reduce.bp.*; -import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul; -import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp; -import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy; -import org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2; -import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax; -import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy; -import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm; -import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero; -import org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero; -import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; -import org.nd4j.linalg.api.ops.impl.reduce.same.AMax; -import org.nd4j.linalg.api.ops.impl.reduce.same.AMin; -import org.nd4j.linalg.api.ops.impl.reduce.same.ASum; -import org.nd4j.linalg.api.ops.impl.reduce.same.Max; -import org.nd4j.linalg.api.ops.impl.reduce.same.Min; -import org.nd4j.linalg.api.ops.impl.reduce.same.Prod; -import org.nd4j.linalg.api.ops.impl.reduce.same.Sum; -import 
org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; -import org.nd4j.linalg.api.ops.impl.reduce3.Dot; -import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; -import org.nd4j.linalg.api.ops.impl.scalar.*; -import org.nd4j.linalg.api.ops.impl.scalar.Pow; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterMax; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; -import org.nd4j.linalg.api.ops.impl.shape.*; -import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; -import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; -import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; -import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation; -import org.nd4j.linalg.api.ops.impl.summarystats.Variance; -import org.nd4j.linalg.api.ops.impl.transforms.Pad; -import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer; -import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; -import org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite; -import 
org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf; -import org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN; -import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; -import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm; -import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue; -import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace; -import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; -import org.nd4j.linalg.api.ops.impl.transforms.custom.*; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum; -import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast; -import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; -import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp; -import 
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; -import org.nd4j.linalg.api.ops.impl.transforms.same.Abs; -import org.nd4j.linalg.api.ops.impl.transforms.same.Ceil; -import org.nd4j.linalg.api.ops.impl.transforms.same.Cube; -import org.nd4j.linalg.api.ops.impl.transforms.same.Floor; -import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; -import org.nd4j.linalg.api.ops.impl.transforms.same.Negative; -import org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal; -import org.nd4j.linalg.api.ops.impl.transforms.same.Round; -import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; -import org.nd4j.linalg.api.ops.impl.transforms.same.Square; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean; -import 
org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; -import org.nd4j.linalg.api.ops.impl.transforms.strict.*; -import org.nd4j.linalg.api.ops.random.custom.*; -import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; -import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution; -import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; -import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; -import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution; -import org.nd4j.linalg.api.ops.random.impl.Range; -import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution; -import org.nd4j.linalg.api.ops.random.impl.UniformDistribution; -import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.indexing.conditions.Condition; -import org.nd4j.linalg.util.ArrayUtil; - -/** - * - */ -@Data -public class DifferentialFunctionFactory { - - protected 
SameDiff sameDiff; - private static Map methodNames; - - /** - * @param sameDiff - */ - public DifferentialFunctionFactory(SameDiff sameDiff) { - if (sameDiff != null) { - this.sameDiff = sameDiff; - if (methodNames == null) { - methodNames = new HashMap<>(); - Method[] methods = getClass().getDeclaredMethods(); - for (Method method : methods) - methodNames.put(method.getName().toLowerCase(), method); - } - } else { - throw new IllegalArgumentException("Input not null value."); - } - - - } - - public SameDiff sameDiff() { - return sameDiff; - } - - - public SDVariable invoke(String name, Object[] args) { - try { - return (SDVariable) methodNames.get(name).invoke(this, args); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public ExternalErrorsFunction externalErrors(SDVariable... inputs) { - return externalErrors(null, inputs); - } - - public ExternalErrorsFunction externalErrors(Map externalGradients, SDVariable... inputs) { - Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" + - " be specified when using external errors: got %s", inputs); - ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff(), Arrays.asList(inputs), externalGradients); - fn.outputVariable(); - return fn; - } - - public SDVariable zerosLike(SDVariable input) { - return zerosLike(null, input); - } - - public SDVariable zerosLike(String name, SDVariable input) { - validateDifferentialFunctionsameDiff(input); - return new ZerosLike(name, sameDiff(), input).outputVariable(); - } - - public SDVariable zerosLike(String name, SDVariable input, DataType dataType) { - validateDifferentialFunctionsameDiff(input); - return new ZerosLike(name, sameDiff(), input, dataType).outputVariable(); - } - - public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) { - return create(name, shape, 'c', initialize, dataType); - } - - public SDVariable create(String name, SDVariable shape, char order, 
boolean initialize, DataType dataType) { - validateDifferentialFunctionsameDiff(shape); - return new Create(name, sameDiff(), shape, order, initialize, dataType).outputVariable(); - } - - public SDVariable onesLike(String name, SDVariable input, DataType dataType) { - validateDifferentialFunctionsameDiff(input); - return new OnesLike(name, sameDiff(), input, dataType).outputVariable(); - } - - public SDVariable linspace(SDVariable lower, SDVariable upper, SDVariable count, DataType dt) { - return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sameDiff(), lower, upper, count, dt).outputVariable(); - } - - public SDVariable range(double from, double to, double step, DataType dataType) { - return new Range(sameDiff(), from, to, step, dataType).outputVariable(); - } - - public SDVariable range(SDVariable from, SDVariable to, SDVariable step, DataType dataType) { - return new Range(sameDiff(), from, to, step, dataType).outputVariable(); - } - - public SDVariable[] listdiff(SDVariable x, SDVariable y){ - return new ListDiff(sameDiff(), x, y).outputVariables(); - } - - public SDVariable cast(SDVariable toCast, DataType toType){ - return new Cast(sameDiff(), toCast, toType).outputVariable(); - } - - public SDVariable[] meshgrid(boolean cartesian, SDVariable... inputs) { - return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables(); - } - - public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) { - return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable(); - } - - public SDVariable randomUniform(double min, double max, long... shape) { - return new UniformDistribution(sameDiff(), min, max, shape).outputVariable(); - } - - public SDVariable randomNormal(double mean, double std, SDVariable shape) { - return new RandomNormal(sameDiff(), shape, mean, std).outputVariable(); - } - - public SDVariable randomNormal(double mean, double std, long... 
shape) { - return new GaussianDistribution(sameDiff(), mean, std, shape).outputVariable(); - } - - public SDVariable randomBernoulli(double p, SDVariable shape) { - return new RandomBernoulli(sameDiff(), shape, p).outputVariable(); - } - - public SDVariable randomBernoulli(double p, long... shape) { - return new BernoulliDistribution(sameDiff(), p, shape).outputVariable(); - } - - public SDVariable randomBinomial(int nTrials, double p, long... shape) { - return new BinomialDistribution(sameDiff(), nTrials, p, shape).outputVariable(); - } - - public SDVariable randomLogNormal(double mean, double stdev, long... shape) { - return new LogNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable(); - } - - public SDVariable randomNormalTruncated(double mean, double stdev, long... shape) { - return new TruncatedNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable(); - } - - public SDVariable randomGamma(SDVariable shape, SDVariable alpha, SDVariable beta, int... seeds) { - return new RandomGamma(sameDiff(), shape, alpha, beta, seeds).outputVariable(); - } - - public SDVariable randomPoisson(SDVariable shape, SDVariable rate, int... seeds) { - return new RandomPoisson(sameDiff(), shape, rate, seeds).outputVariable(); - } - - public SDVariable randomShuffle(SDVariable values, int... seeds) { - return new RandomShuffle(sameDiff(), values, seeds).outputVariable(); - } - - /** - * Exponential distribution: P(x) = lambda * exp(-lambda * x) - * - * @param lambda Must be > 0 - * @param shape Shape of the output - */ - public SDVariable randomExponential(double lambda, SDVariable shape) { - return new RandomExponential(sameDiff(), shape, lambda).outputVariable(); - } - - - public SDVariable pad(SDVariable input, SDVariable padding, Pad.Mode mode, double padValue){ - return new Pad(sameDiff(), input, padding, mode, padValue).outputVariable(); - } - - /** - * Local response normalization operation. 
- * - * @param input the inputs to lrn - * @param lrnConfig the configuration - * @return - */ - public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) { - LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder() - .inputFunctions(new SDVariable[]{input}) - .sameDiff(sameDiff()) - .config(lrnConfig) - .build(); - - return lrn.outputVariable(); - } - - /** - * Conv1d operation. - * - * @param input the inputs to conv1d - * @param weights conv1d weights - * @param conv1DConfig the configuration - * @return - */ - public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { - Conv1D conv1D = Conv1D.sameDiffBuilder() - .inputFunctions(new SDVariable[]{input, weights}) - .sameDiff(sameDiff()) - .config(conv1DConfig) - .build(); - - return conv1D.outputVariable(); - } - - /** - * Conv1d operation. - * - * @param input the inputs to conv1d - * @param weights conv1d weights - * @param bias conv1d bias - * @param conv1DConfig the configuration - * @return - */ - public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) { - - SDVariable[] args; - - if(bias == null){ - args = new SDVariable[]{input, weights}; - } else { - args = new SDVariable[]{input, weights, bias}; - } - - Conv1D conv1D = Conv1D.sameDiffBuilder() - .inputFunctions(args) - .sameDiff(sameDiff()) - .config(conv1DConfig) - .build(); - - return conv1D.outputVariable(); - } - - /** - * Conv2d operation. 
- * - * @param inputs the inputs to conv2d - * @param conv2DConfig the configuration - * @return - */ - public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - Conv2D conv2D = Conv2D.sameDiffBuilder() - .inputFunctions(inputs) - .sameDiff(sameDiff()) - .config(conv2DConfig) - .build(); - - return conv2D.outputVariable(); - } - - public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) { - return new Upsampling2d(sameDiff(), input, nchw, scaleH, scaleW).outputVariable(); - } - - public SDVariable upsampling2dBp(SDVariable input, SDVariable gradient, boolean nchw, int scaleH, int scaleW) { - return new Upsampling2dDerivative(sameDiff(), input, gradient, nchw, scaleH, scaleW).outputVariable(); - } - - - /** - * Average pooling 2d operation. - * - * @param input the inputs to pooling - * @param pooling2DConfig the configuration - * @return - */ - public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder() - .input(input) - .sameDiff(sameDiff()) - .config(pooling2DConfig) - .build(); - - return avgPooling2D.outputVariable(); - } - - /** - * Max pooling 2d operation. - * - * @param input the inputs to pooling - * @param pooling2DConfig the configuration - * @return - */ - public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder() - .input(input) - .sameDiff(sameDiff()) - .config(pooling2DConfig) - .build(); - - return maxPooling2D.outputVariable(); - } - - /** - * Avg pooling 3d operation. - * - * @param input the inputs to pooling - * @param pooling3DConfig the configuration - * @return - */ - public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { - pooling3DConfig.setType(Pooling3D.Pooling3DType.AVG); - return new AvgPooling3D(sameDiff(), input, pooling3DConfig).outputVariable(); - } - - - /** - * Max pooling 3d operation. 
- * - * @param input the inputs to pooling - * @param pooling3DConfig the configuration - * @return - */ - public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { - pooling3DConfig.setType(Pooling3D.Pooling3DType.MAX); - return new MaxPooling3D(sameDiff(), input, pooling3DConfig).outputVariable(); - } - - - /** - * Separable Conv2d operation. - * - * @param inputs the inputs to conv2d - * @param conv2DConfig the configuration - * @return - */ - public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - SConv2D sconv2D = SConv2D.sameDiffSBuilder() - .inputFunctions(inputs) - .sameDiff(sameDiff()) - .conv2DConfig(conv2DConfig) - .build(); - - return sconv2D.outputVariable(); - } - - - /** - * Depth-wise Conv2d operation. This is just separable convolution with - * only the depth-wise weights specified. - * - * @param inputs the inputs to conv2d - * @param depthConv2DConfig the configuration - * @return - */ - public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { - SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder() - .inputFunctions(inputs) - .sameDiff(sameDiff()) - .conv2DConfig(depthConv2DConfig) - .build(); - - return depthWiseConv2D.outputVariable(); - } - - - /** - * Deconv2d operation. 
- * - * @param inputs the inputs to conv2d - * @param deconv2DConfig the configuration - * @return - */ - public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { - DeConv2D deconv2D = DeConv2D.sameDiffBuilder() - .inputs(inputs) - .sameDiff(sameDiff()) - .config(deconv2DConfig) - .build(); - - return deconv2D.outputVariable(); - } - - public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { - DeConv3D d = new DeConv3D(sameDiff(), input, weights, bias, config); - return d.outputVariable(); - } - - public SDVariable[] deconv3dDerivative(SDVariable input, SDVariable weights, SDVariable bias, SDVariable grad, DeConv3DConfig config) { - DeConv3DDerivative d = new DeConv3DDerivative(sameDiff(), input, weights, bias, grad, config); - return d.outputVariables(); - } - - /** - * Conv3d operation. - * - * @param inputs the inputs to conv3d - * @param conv3DConfig the configuration - * @return - */ - public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) { - Conv3D conv3D = Conv3D.sameDiffBuilder() - .inputFunctions(inputs) - .config(conv3DConfig) - .sameDiff(sameDiff()) - .build(); - - val outputVars = conv3D.outputVariables(); - return outputVars[0]; - } - - - /** - * Batch norm operation. - */ - public SDVariable batchNorm(SDVariable input, SDVariable mean, - SDVariable variance, SDVariable gamma, - SDVariable beta, - boolean applyGamma, boolean applyBeta, - double epsilon, int... 
axis) { - BatchNorm batchNorm = BatchNorm.builder() - .inputFunctions(new SDVariable[]{input, mean, variance, gamma, beta}) - .applyGamma(applyGamma) - .applyBeta(applyBeta) - .epsilon(epsilon) - .sameDiff(sameDiff()) - .axis(axis) - .build(); - - val outputVars = batchNorm.outputVariables(); - return outputVars[0]; - } - - public SDVariable im2Col(SDVariable input, Conv2DConfig config) { - return new Im2col(sameDiff(), input, config).outputVariable(); - } - - public SDVariable im2ColBp(SDVariable im2colInput, SDVariable gradientAtOutput, Conv2DConfig config) { - return new Im2colBp(sameDiff(), im2colInput, gradientAtOutput, config).outputVariable(); - } - - public SDVariable col2Im(SDVariable input, Conv2DConfig config) { - return new Col2Im(sameDiff(), input, config).outputVariable(); - } - - public SDVariable extractImagePatches(SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode){ - return new ExtractImagePatches(sameDiff(), input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH, rW}, sameMode).outputVariable(); - } - - public SDVariable[] moments(SDVariable input, int... 
axes) { - return new Moments(sameDiff(), input, axes).outputVariables(); - } - - public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) { - return new NormalizeMoments(sameDiff(), counts, means, variances, shift).outputVariables(); - } - - - public SDVariable tile(@NonNull SDVariable iX, @NonNull int[] repeat) { - return new Tile(sameDiff(), iX, repeat).outputVariable(); - } - - public SDVariable tileBp(@NonNull SDVariable in, @NonNull SDVariable grad, @NonNull int[] repeat){ - return new TileBp(sameDiff, in, grad, repeat).outputVariable(); - } - - public SDVariable tile(@NonNull SDVariable iX, @NonNull SDVariable repeat) { - return new Tile(sameDiff(), iX, repeat).outputVariable(); - } - - public SDVariable tileBp(@NonNull SDVariable in, @NonNull SDVariable repeat, @NonNull SDVariable grad){ - return new TileBp(sameDiff, in, repeat, grad).outputVariable(); - } - - public SDVariable dropout(SDVariable input, double p) { - return new DropOutInverted(sameDiff(), input, p).outputVariable(); - } - - - public SDVariable sum(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Sum(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable sumBp(SDVariable i_x, SDVariable grad, boolean keepDims, int... dimensions) { - return new SumBp(sameDiff(), i_x, grad, keepDims, dimensions).outputVariable(); - } - - - public SDVariable prod(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Prod(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable prodBp(SDVariable preReduceInput, SDVariable grad, boolean keepDims, int... dimensions) { - return new ProdBp(sameDiff(), preReduceInput, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable mean(SDVariable in, boolean keepDims, int... 
dimensions) { - return new Mean(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable meanBp(SDVariable in, SDVariable grad, boolean keepDims, int... dimensions) { - return new MeanBp(sameDiff(), in, grad, keepDims, dimensions).outputVariable(); - } - - - public SDVariable std(SDVariable i_x, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new StandardDeviation(sameDiff(), i_x, biasCorrected, keepDims, dimensions).outputVariable(); - } - - public SDVariable stdBp(SDVariable stdInput, SDVariable gradient, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new StandardDeviationBp(sameDiff(), stdInput, gradient, biasCorrected, keepDims, dimensions).outputVariable(); - } - - - public SDVariable variance(SDVariable i_x, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new Variance(sameDiff(), i_x, biasCorrected, keepDims, dimensions).outputVariable(); - } - - public SDVariable varianceBp(SDVariable stdInput, SDVariable gradient, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new VarianceBp(sameDiff(), stdInput, gradient, biasCorrected, keepDims, dimensions).outputVariable(); - } - - public SDVariable standardize(SDVariable i_x, int... dimensions) { - return new Standardize(sameDiff(), i_x, dimensions).outputVariable(); - } - - public SDVariable standardizeBp(SDVariable stdInput, SDVariable gradient, int... dimensions) { - return new StandardizeBp(sameDiff(), stdInput, gradient, dimensions).outputVariable(); - } - - public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) { - return new LayerNorm(sameDiff(), input, gain, bias, channelsFirst, dimensions).outputVariable(); - } - - public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, boolean channelsFirst, int... 
dimensions) { - return new LayerNormBp(sameDiff(), input, gain, bias, gradient, channelsFirst, dimensions).outputVariables(); - } - - public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { - return new LayerNorm(sameDiff(), input, gain, channelsFirst, dimensions).outputVariable(); - } - - public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) { - return new LayerNormBp(sameDiff(), input, gain, gradient, channelsFirst, dimensions).outputVariables(); - } - - public SDVariable squaredNorm(SDVariable input, boolean keepDims, int... dimensions) { - return new SquaredNorm(sameDiff(), input, keepDims, dimensions).outputVariable(); - } - - public SDVariable squaredNormBp(SDVariable preReduceInput, SDVariable gradient, boolean keepDims, int... dimensions) { - return new SquaredNormBp(sameDiff(), preReduceInput, gradient, keepDims, dimensions).outputVariable(); - } - - public SDVariable entropy(SDVariable in, int... dimensions) { - return new Entropy(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable logEntropy(SDVariable in, int... dimensions) { - return new LogEntropy(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable shannonEntropy(SDVariable in, int... dimensions){ - return new ShannonEntropy(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable countNonZero(SDVariable input, int... dimensions) { - return new CountNonZero(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable countZero(SDVariable input, int... 
dimensions) { - return new CountZero(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable zeroFraction(SDVariable input) { - return new ZeroFraction(sameDiff(), input).outputVariable(); - } - - public SDVariable scalarMax(SDVariable in, Number num) { - return new ScalarMax(sameDiff(), in, num).outputVariable(); - } - - public SDVariable scalarMin(SDVariable in, Number num) { - return new ScalarMin(sameDiff(), in, num).outputVariable(); - } - - public SDVariable scalarSet(SDVariable in, Number num) { - return new ScalarSet(sameDiff(), in, num).outputVariable(); - } - - public SDVariable scalarFloorMod(SDVariable in, Number num) { - return new ScalarFMod(sameDiff(), in, num).outputVariable(); - } - - public SDVariable max(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Max(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable max(SDVariable first, SDVariable second) { - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sameDiff(), first, second) - .outputVariable(); - } - - public SDVariable maxBp(SDVariable i_x, SDVariable grad, boolean keepDims, int... dimensions) { - return new MaxBp(sameDiff(), i_x, grad, keepDims, dimensions).outputVariable(); - } - - - public SDVariable min(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Min(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable minBp(SDVariable i_x, SDVariable grad, boolean keepDims, int... dimensions) { - return new MinBp(sameDiff(), i_x, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable min(SDVariable first, SDVariable second) { - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sameDiff(), first, second) - .outputVariable(); - } - - public SDVariable amax(SDVariable in, int... dimensions) { - return new AMax(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable amin(SDVariable in, int... 
dimensions) { - return new AMin(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable amean(SDVariable in, int... dimensions) { - return new AMean(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable asum(SDVariable in, int... dimensions) { - return new ASum(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) { - return new IMax(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { - return new IMin(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { - return new IAMax(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) { - return new IAMin(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return new FirstIndex(sameDiff(), in, condition, keepDims, dimensions).outputVariable(); - } - - public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return new LastIndex(sameDiff(), in, condition, keepDims, dimensions).outputVariable(); - } - - /** - * Returns a count of the number of elements that satisfy the condition - * - * @param in Input - * @param condition Condition - * @return Number of elements that the condition is satisfied for - */ - public SDVariable matchConditionCount(SDVariable in, Condition condition, boolean keepDims, int... 
dimensions) { - return new MatchCondition(sameDiff(), in, condition, keepDims, dimensions).outputVariable(); - } - - /** - * Returns a boolean mask of equal shape to the input, where the condition is satisfied - * - * @param in Input - * @param condition Condition - * @return Boolean mask - */ - public SDVariable matchCondition(SDVariable in, Condition condition) { - return new MatchConditionTransform(sameDiff(), in, condition).outputVariable(); - } - - public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return new CumSum(sameDiff(), in, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable cumsumBp(SDVariable in, SDVariable grad, boolean exclusive, boolean reverse, int... axis) { - return new CumSumBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return new CumProd(sameDiff(), in, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable cumprodBp(SDVariable in, SDVariable grad, boolean exclusive, boolean reverse, int... axis) { - return new CumProdBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { - return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable(); - } - - public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad, boolean nchw) { - return new BiasAddGrad(sameDiff(), input, bias, grad, nchw).outputVariables(); - } - - public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Norm1(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable norm1Bp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int... 
dimensions) { - return new Norm1Bp(sameDiff(), preReduceIn, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable norm2(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Norm2(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable norm2Bp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int... dimensions) { - return new Norm2Bp(sameDiff(), preReduceIn, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable normmax(SDVariable i_x, boolean keepDims, int... dimensions) { - return new NormMax(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable normmaxBp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int... dimensions) { - return new NormMaxBp(sameDiff(), preReduceIn, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable reductionShape(SDVariable shape, SDVariable axis, boolean keepDim){ - return new ReductionShape(sameDiff(), shape, axis, keepDim).outputVariable(); - } - - /** - * Add 1s as required to the array make an array possible to be broadcast with the original (pre-reduce) array. - *

- * Example: if doing [a,b,c].sum(1), result is [a,c]. To 'undo' this in a way that can be auto-broadcast, - * we want to expand as required - i.e., [a,c] -> [a,1,c] which can be auto-broadcast with the original [a,b,c]. - * This is typically only used with reduction operations backprop. - * - * @param origRank Rank of the original array, before the reduction was executed - * @param reduceDims Dimensions that the original array was reduced from - * @param toExpand Array to add 1s to the shape to (such that it can be - * @return Reshaped array. - */ - public SDVariable reductionBroadcastableWithOrigShape(int origRank, int[] reduceDims, SDVariable toExpand) { - if (Shape.isWholeArray(origRank, reduceDims)) { - //Output is [1,1] which is already broadcastable - return toExpand; - } else if (origRank == 2 && reduceDims.length == 1) { - //In this case: [a,b] -> [1,b] or [a,b] -> [a,1] - //both are already broadcastable - return toExpand; - } else { - //Example: [a,b,c].sum(1) -> [a,c]... want [a,1,c] - for (int d : reduceDims) { - toExpand = sameDiff().expandDims(toExpand, d); - } - return toExpand; - } - } - - public SDVariable reductionBroadcastableWithOrigShape(SDVariable origInput, SDVariable axis, SDVariable toExpand) { - SDVariable shape = origInput.shape(); - SDVariable reduceShape = reductionShape(shape, axis, true); - SDVariable reshaped = toExpand.reshape(reduceShape); - return reshaped; - } - - - public SDVariable gradientBackwardsMarker(SDVariable iX) { - return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.name() + "-pairgrad", 1.0)).outputVariable(); - } - - public SDVariable abs(SDVariable iX) { - return new Abs(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable neg(SDVariable iX) { - return new Negative(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable cos(SDVariable iX) { - return new Cos(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable sin(SDVariable iX) { - return new 
Sin(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable tan(SDVariable iX) { - return new Tan(sameDiff(), iX, false).outputVariable(); - - } - - - public SDVariable permute(SDVariable iX, int... dimensions) { - return new Permute(sameDiff(), iX, dimensions).outputVariable(); - } - - public SDVariable permute(SDVariable in, SDVariable dimensions) { - return new Permute(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable noop(SDVariable input) { - return new NoOp(sameDiff(), input).outputVariable(); - } - - public SDVariable identity(SDVariable input) { - return new Identity(sameDiff(), input).outputVariable(); - } - - public SDVariable all(SDVariable input, int... dimensions) { - return new All(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable any(SDVariable input, int... dimensions) { - return new Any(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable invertPermutation(SDVariable input, boolean inPlace) { - return new InvertPermutation(sameDiff(), input, inPlace).outputVariable(); - } - - public SDVariable transpose(SDVariable iX) { - return new Transpose(sameDiff(), iX).outputVariable(); - } - - - public SDVariable acos(SDVariable iX) { - return new ACos(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable asin(SDVariable iX) { - return new ASin(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable atan(SDVariable iX) { - return new ATan(sameDiff(), iX, false).outputVariable(); - - } - - public SDVariable atan2(SDVariable y, SDVariable x) { - return new ATan2(sameDiff(), y, x).outputVariable(); - } - - - public SDVariable cosh(SDVariable iX) { - return new Cosh(sameDiff(), iX, false).outputVariable(); - - } - - - public SDVariable sinh(SDVariable iX) { - return new Sinh(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable tanh(SDVariable iX) { - return new Tanh(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable 
tanhRational(SDVariable in) { - return new RationalTanh(sameDiff(), in, false).outputVariable(); - } - - public SDVariable tanhRectified(SDVariable in) { - return new RectifiedTanh(sameDiff(), in, false).outputVariable(); - } - - public SDVariable tanhDerivative(SDVariable iX, SDVariable wrt) { - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable(); - } - - public SDVariable tanhRationalBp(SDVariable in, SDVariable epsilon) { - return new RationalTanhBp(sameDiff(), in, epsilon).outputVariable(); - } - - public SDVariable tanhRectifiedBp(SDVariable in, SDVariable epsilon) { - return new RectifiedTanhBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * Use {@link #tanhRationalBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable tanhRationalDerivative(SDVariable in) { - return new RationalTanhDerivative(sameDiff(), in, false).outputVariable(); - } - - /** - * Use {@link #tanhRectifiedBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable tanhRectifiedDerivative(SDVariable in) { - return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable(); - } - - public SDVariable step(SDVariable in, double cutoff) { - return new Step(sameDiff(), in, false, cutoff).outputVariable(); - } - - - public SDVariable acosh(SDVariable iX) { - return new ACosh(sameDiff(), iX).outputVariable(); - } - - - public SDVariable asinh(SDVariable iX) { - return new ASinh(sameDiff(), iX).outputVariable(); - } - - - public SDVariable atanh(SDVariable iX) { - return new ATanh(sameDiff(), iX).outputVariable(); - } - - - public SDVariable exp(SDVariable iX) { - return new Exp(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable expm1(SDVariable iX) { - return new Expm1(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable rsqrt(SDVariable iX) { - return new RSqrt(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable log(SDVariable iX) { - return new 
Log(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable log(SDVariable in, double base) { - return new LogX(sameDiff(), in, base).outputVariable(); - } - - public SDVariable log1p(SDVariable iX) { - return new Log1p(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable isFinite(SDVariable ix) { - return new IsFinite(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable isInfinite(SDVariable ix) { - return new IsInf(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable isNaN(SDVariable ix) { - return new IsNaN(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable isMax(SDVariable ix) { - return new IsMax(sameDiff(), ix).outputVariable(); - } - - public SDVariable replaceWhere(SDVariable to, SDVariable from, Condition condition) { - return new CompareAndReplace(sameDiff(), to, from, condition).outputVariable(); - } - - public SDVariable replaceWhere(SDVariable to, Number set, Condition condition) { - return new CompareAndSet(sameDiff(), to, set, condition).outputVariable(); - } - - public SDVariable round(SDVariable ix) { - return new Round(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable or(SDVariable iX, SDVariable i_y) { - return new Or(sameDiff(), iX, i_y).outputVariable(); - } - - public SDVariable and(SDVariable ix, SDVariable iy) { - return new And(sameDiff(), ix, iy).outputVariable(); - } - - public SDVariable xor(SDVariable ix, SDVariable iy) { - return new Xor(sameDiff(), ix, iy).outputVariable(); - } - - public SDVariable shift(SDVariable ix, SDVariable shift) { - return new ShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable rshift(SDVariable ix, SDVariable shift) { - return new RShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable rotl(SDVariable ix, SDVariable shift) { - return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable rotr(SDVariable ix, SDVariable shift) { - return new 
CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) { - return new BitsHammingDistance(sameDiff(), x, y).outputVariable(); - } - - public SDVariable bitwiseAnd(SDVariable x, SDVariable y){ - return new BitwiseAnd(sameDiff(), x, y).outputVariable(); - } - - public SDVariable bitwiseOr(SDVariable x, SDVariable y){ - return new BitwiseOr(sameDiff(), x, y).outputVariable(); - } - - public SDVariable bitwiseXor(SDVariable x, SDVariable y){ - return new BitwiseXor(sameDiff(), x, y).outputVariable(); - } - - public SDVariable eq(SDVariable iX, SDVariable i_y) { - return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); - } - - - public SDVariable neq(SDVariable iX, double i_y) { - return new ScalarNotEquals(sameDiff(), iX, i_y).outputVariable(); - } - - - public SDVariable neqi(SDVariable iX, double i_y) { - return new ScalarNotEquals(sameDiff(), iX, i_y, true).outputVariable(); - } - - - public SDVariable neqi(SDVariable iX, SDVariable i_y) { - return new NotEqualTo(sameDiff(), new SDVariable[]{iX, i_y}, true).outputVariable(); - } - - public SDVariable neq(SDVariable iX, SDVariable i_y) { - return new NotEqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); - } - - public SDVariable pow(SDVariable iX, double i_y) { - return new Pow(sameDiff(), iX, false, i_y).outputVariable(); - } - - public SDVariable pow(SDVariable x, SDVariable y){ - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sameDiff(), x, y).outputVariable(); - } - - public SDVariable sqrt(SDVariable iX) { - return new Sqrt(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable square(SDVariable iX) { - return new Square(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable cube(SDVariable iX) { - return new Cube(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable cubeBp(SDVariable in, SDVariable epsilon) { - return new 
CubeBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #cubeBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable cubeDerivative(SDVariable iX) { - return new CubeDerivative(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable floor(SDVariable iX) { - return new Floor(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable floorDiv(SDVariable x, SDVariable y) { - return new FloorDivOp(sameDiff(), x, y).outputVariable(); - } - - public List floorDivBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new FloorDivBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - public SDVariable floorMod(SDVariable x, SDVariable y) { - return new FloorModOp(sameDiff(), x, y).outputVariable(); - } - - public List floorModBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new FloorModBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - public SDVariable ceil(SDVariable x) { - return new Ceil(sameDiff(), x).outputVariable(); - } - - public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { - return new ClipByValue(sameDiff(), x, clipValueMin, clipValueMax).outputVariable(); - } - - public SDVariable clipByNorm(SDVariable x, double clipValue) { - return new ClipByNorm(sameDiff(), x, clipValue).outputVariable(); - } - - public SDVariable clipByNorm(SDVariable x, double clipValue, int... 
dimensions) { - return new ClipByNorm(sameDiff(), x, clipValue, dimensions).outputVariable(); - } - - public SDVariable relu(SDVariable iX, double cutoff) { - return new RectifiedLinear(sameDiff(), iX, false, cutoff).outputVariable(); - } - - public SDVariable reluDerivative(SDVariable input, SDVariable grad){ - return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable(); - } - - public SDVariable thresholdRelu(SDVariable in, SDVariable epsilon, double cutoff){ - return new ThresholdRelu(sameDiff(), in, cutoff).outputVariable(); - } - - public SDVariable thresholdReluBp(SDVariable in, SDVariable epsilon, double cutoff){ - return new ThresholdReluBp(sameDiff(), in, epsilon, cutoff).outputVariable(); - } - - public SDVariable relu6(SDVariable iX, double cutoff) { - return new Relu6(sameDiff(), iX, false, cutoff).outputVariable(); - } - - public SDVariable relu6Derivative(SDVariable iX, SDVariable wrt, double cutoff) { - return new Relu6Derivative(sameDiff(), iX, wrt, cutoff).outputVariable(); - } - - public SDVariable softmax(SDVariable iX) { - return new SoftMax(sameDiff(), new SDVariable[]{iX}).outputVariable(); - } - - public SDVariable softmax(SDVariable iX, int dimension) { - return new SoftMax(sameDiff(), new SDVariable[]{iX}, dimension).outputVariable(); - } - - - public SDVariable hardTanh(SDVariable iX) { - return new HardTanh(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable hardTanhBp(SDVariable in, SDVariable epsilon) { - return new HardTanhBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #hardTanhBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable hardTanhDerivative(SDVariable iX) { - return new HardTanhDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable hardSigmoid(SDVariable in) { - return new HardSigmoid(sameDiff(), in, false).outputVariable(); - } - - public SDVariable hardSigmoidBp(SDVariable in, SDVariable epsilon){ - return new 
HardSigmoidBp(sameDiff(), in, epsilon).outputVariable(); - } - - public SDVariable sigmoid(SDVariable iX) { - return new Sigmoid(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable sigmoidDerivative(SDVariable iX, SDVariable wrt) { - return new SigmoidDerivative(sameDiff(), iX, wrt).outputVariable(); - } - - - public SDVariable logSigmoid(SDVariable iX) { - return new LogSigmoid(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable powDerivative(SDVariable iX, double pow) { - return new PowDerivative(sameDiff(), iX, false, pow).outputVariable(); - } - - public SDVariable[] powBp(SDVariable x, SDVariable pow, SDVariable gradient) { - return new PowBp(sameDiff(), x, pow, gradient).outputVariables(); - } - - public SDVariable mishDerivative(SDVariable iX) { - return new MishDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable swish(SDVariable iX) { - return new Swish(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable swishDerivative(SDVariable iX) { - return new SwishDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable gelu(SDVariable iX, boolean precise) { - if (precise) - return new PreciseGELU(sameDiff(), iX, false, precise).outputVariable(); - else - return new GELU(sameDiff(), iX, false, precise).outputVariable(); - } - - public SDVariable geluDerivative(SDVariable iX, boolean precise) { - if (precise) - return new PreciseGELUDerivative(sameDiff(), iX, false, precise).outputVariable(); - else - return new GELUDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable sign(SDVariable iX) { - return new Sign(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable expandDims(SDVariable iX, int axis) { - return new ExpandDims(sameDiff(), new SDVariable[]{iX}, axis).outputVariable(); - } - - public SDVariable squeeze(SDVariable iX, int... 
axis) { - return new Squeeze(sameDiff(), iX, axis).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, DataType dataType) { - return new ConfusionMatrix(sameDiff(), labels, pred, dataType).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses) { - return new ConfusionMatrix(sameDiff(), labels, pred, numClasses).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { - return new ConfusionMatrix(sameDiff(), labels, pred, weights).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { - return new ConfusionMatrix(sameDiff(), labels, pred, numClasses, weights).outputVariable(); - } - - public SDVariable matrixDeterminant(SDVariable in){ - return new MatrixDeterminant(sameDiff(), in, false).outputVariable(); - } - - public SDVariable matrixInverse(SDVariable in){ - return new MatrixInverse(sameDiff(), in, false).outputVariable(); - } - - public SDVariable onehot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { - return new OneHot(sameDiff(), indices, depth, axis, on, off, dataType).outputVariable(); - } - - public SDVariable onehot(SDVariable indices, int depth) { - return new OneHot(sameDiff(), indices, depth).outputVariable(); - } - - public SDVariable reciprocal(SDVariable a) { - return new Reciprocal(sameDiff(), a).outputVariable(); - } - - - public SDVariable repeat(SDVariable iX, int axis) { - return new Repeat(sameDiff(), new SDVariable[]{iX}, axis).outputVariable(); - - } - - public SDVariable stack(SDVariable[] values, int axis) { - return new Stack(sameDiff(), values, axis).outputVariable(); - } - - public SDVariable parallel_stack(SDVariable[] values) { - return new ParallelStack(sameDiff(), values).outputVariable(); - } - - public SDVariable[] unstack(SDVariable value, int 
axis) { - return new Unstack(sameDiff(), value, axis).outputVariables(); - } - - public SDVariable[] unstack(SDVariable value, int axis, int num) { - return new Unstack(sameDiff(), value, axis, num).outputVariables(); - } - - public SDVariable assign(SDVariable x, SDVariable y) { - return new Assign(sameDiff(), x, y).outputVariable(); - } - - public SDVariable assign(SDVariable x, Number num) { - return new ScalarSet(sameDiff(), x, num).outputVariable(); - } - - - public SDVariable softsign(SDVariable iX) { - return new SoftSign(sameDiff(), iX, false).outputVariable(); - - } - - public SDVariable softsignBp(SDVariable in, SDVariable epsilon) { - return new SoftSignBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #softsignBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable softsignDerivative(SDVariable iX) { - return new SoftSignDerivative(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable softplus(SDVariable iX) { - return new SoftPlus(sameDiff(), iX, false).outputVariable(); - - } - - - public SDVariable elu(SDVariable iX) { - return new ELU(sameDiff(), iX).outputVariable(); - - } - - public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) { - return new EluBp(sameDiff(), in, epsilon, alpha).outputVariable(); - } - - - public SDVariable leakyRelu(SDVariable iX, double alpha) { - return new LeakyReLU(sameDiff(), iX, false, alpha).outputVariable(); - - } - - public SDVariable leakyReluBp(SDVariable in, SDVariable epsilon, double cutoff) { - return new LeakyReLUBp(sameDiff(), in, epsilon, cutoff).outputVariable(); - } - - /** - * @deprecated Use {@link #leakyReluBp(SDVariable, SDVariable, double)} - */ - @Deprecated - public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) { - return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable(); - } - - public SDVariable prelu(SDVariable x, SDVariable alpha, int... 
sharedAxes){ - return new PRelu(sameDiff(), x, alpha, sharedAxes).outputVariable(); - } - - public SDVariable[] preluBp(SDVariable in, SDVariable alpha, SDVariable epsilon, int... sharedAxes){ - return new PReluBp(sameDiff(), in, alpha, epsilon, sharedAxes).outputVariables(); - } - - public SDVariable reshape(SDVariable iX, int[] shape) { - return new Reshape(sameDiff(), iX, ArrayUtil.toLongArray(shape)).outputVariable(); - } - - public SDVariable reshape(SDVariable iX, long[] shape) { - return new Reshape(sameDiff(), iX, shape).outputVariable(); - } - - public SDVariable reshape(SDVariable iX, SDVariable shape) { - return new Reshape(sameDiff(), iX, shape).outputVariable(); - } - - public SDVariable reverse(SDVariable x, int... dimensions) { - return new Reverse(sameDiff(), x, dimensions).outputVariable(); - } - - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seq_dim, int batch_dim) { - return new ReverseSequence(sameDiff(), x, seq_lengths, seq_dim, batch_dim).outputVariable(); - } - - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) { - return new ReverseSequence(sameDiff(), x, seq_lengths).outputVariable(); - } - - public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) { - return new SequenceMask(sameDiff(), lengths, maxLen, dataType).outputVariable(); - } - - public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) { - return new SequenceMask(sameDiff(), lengths, maxLen, dataType).outputVariable(); - } - - public SDVariable sequenceMask(SDVariable lengths, DataType dataType) { - return new SequenceMask(sameDiff(), lengths, dataType).outputVariable(); - } - - public SDVariable concat(int dimension, SDVariable... 
inputs) { - return new Concat(sameDiff(), dimension, inputs).outputVariable(); - } - - public SDVariable fill(SDVariable shape, DataType dataType, double value) { - return new Fill(sameDiff(), shape, dataType, value).outputVariable(); - } - - public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) { - return new Dot(sameDiff(), x, y, dimensions).outputVariable(); - } - - public SDVariable[] dotBp(SDVariable in1, SDVariable in2, SDVariable grad, boolean keepDims, int... dimensions) { - return new DotBp(sameDiff(), in1, in2, grad, keepDims, dimensions).outputVariables(); - } - - public SDVariable cosineSimilarity(SDVariable iX, SDVariable i_y, int... dimensions) { - return new CosineSimilarity(sameDiff(), iX, i_y, dimensions).outputVariable(); - } - - public SDVariable cosineDistance(SDVariable ix, SDVariable iy, int... dimensions) { - return new CosineDistance(sameDiff(), ix, iy, dimensions).outputVariable(); - } - - - public SDVariable euclideanDistance(SDVariable iX, SDVariable i_y, int... dimensions) { - return new EuclideanDistance(sameDiff(), iX, i_y, dimensions).outputVariable(); - } - - - public SDVariable manhattanDistance(SDVariable iX, SDVariable i_y, int... dimensions) { - return new ManhattanDistance(sameDiff(), iX, i_y, dimensions).outputVariable(); - } - - public SDVariable hammingDistance(SDVariable ix, SDVariable iy, int... dimensions) { - return new HammingDistance(sameDiff(), ix, iy, dimensions).outputVariable(); - } - - public SDVariable jaccardDistance(SDVariable ix, SDVariable iy, int... 
dimensions) { - return new JaccardDistance(sameDiff(), ix, iy, dimensions).outputVariable(); - } - - public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, SDVariable weights) { - return new WeightedCrossEntropyLoss(sameDiff(), targets, inputs, weights).outputVariable(); - } - - public SDVariable lossL2(SDVariable var){ - return new L2Loss(sameDiff(), var).outputVariable(); - } - - public SDVariable lossAbsoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new AbsoluteDifferenceLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossAbsoluteDifferenceBP(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new AbsoluteDifferenceLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossCosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, int dimension){ - return new CosineDistanceLoss(sameDiff(), lossReduce, predictions, weights, label, dimension).outputVariable(); - } - - public SDVariable[] lossCosineDistanceBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, int dimension){ - return new CosineDistanceLossBp(sameDiff(), lossReduce, predictions, weights, label, dimension).outputVariables(); - } - - public SDVariable lossHinge(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new HingeLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossHingeBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new HingeLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossHuber(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double 
delta){ - return new HuberLoss(sameDiff(), lossReduce, predictions, weights, label, delta).outputVariable(); - } - - public SDVariable[] lossHuberBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double delta){ - return new HuberLossBp(sameDiff(), lossReduce, predictions, weights, label, delta).outputVariables(); - } - - public SDVariable lossLog(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double epsilon){ - return new LogLoss(sameDiff(), lossReduce, predictions, weights, label, epsilon).outputVariable(); - } - - public SDVariable[] lossLogBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double epsilon){ - return new LogLossBp(sameDiff(), lossReduce, predictions, weights, label, epsilon).outputVariables(); - } - - public SDVariable lossLogPoisson(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossLogPoissonBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossLogPoissonFull(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLoss(sameDiff(), lossReduce, predictions, weights, label, true).outputVariable(); - } - - public SDVariable[] lossLogPoissonFullBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLossBp(sameDiff(), lossReduce, predictions, weights, label, true).outputVariables(); - } - - public SDVariable lossMeanPairwiseSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanPairwiseSquaredErrorLoss(sameDiff(), lossReduce, predictions, 
weights, label).outputVariable(); - } - - public SDVariable[] lossMeanPairwiseSquaredErrorBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanPairwiseSquaredErrorLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossMeanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanSquaredErrorLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossMeanSquaredErrorBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanSquaredErrorLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossSigmoidCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SigmoidCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariable(); - } - - public SDVariable[] lossSigmoidCrossEntropyBp(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SigmoidCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables(); - } - - public SDVariable lossSoftmaxCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SoftmaxCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariable(); - } - - public SDVariable[] lossSoftmaxCrossEntropyBp(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables(); - } - - public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, int 
classDim) { - return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, labels, classDim).outputVariable(); - } - - public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, int classDim) { - return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, labels, classDim).outputVariables(); - } - - public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){ - return new SparseSoftmaxCrossEntropyLossWithLogits(sameDiff(), logits, labels).outputVariable(); - } - - public SDVariable[] lossSparseSoftmaxCrossEntropyBp(SDVariable logits, SDVariable labels){ - return new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff(), logits, labels).outputVariables(); - } - - - public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) { - return new XwPlusB(sameDiff(), input, weights, bias).outputVariable(); - } - - public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { - return new ReluLayer(sameDiff(), input, weights, bias).outputVariable(); - } - - public SDVariable mmul(SDVariable x, - SDVariable y, - MMulTranspose mMulTranspose) { - validateDifferentialFunctionsameDiff(x); - validateDifferentialFunctionsameDiff(y); - return new Mmul(sameDiff(), x, y, mMulTranspose).outputVariable(); - } - - - public SDVariable mmul(SDVariable x, - SDVariable y) { - return mmul(x, y, MMulTranspose.allFalse()); - } - - public List mmulBp(SDVariable x, SDVariable y, SDVariable eps, MMulTranspose mt) { - return Arrays.asList(new MmulBp(sameDiff(), x, y, eps, mt).outputVariables()); - } - - public SDVariable[] batchMmul(SDVariable[] matricesA, - SDVariable[] matricesB) { - return batchMmul(matricesA, matricesB, false, false); - } - - - public SDVariable[] batchMmul(SDVariable[] matricesA, - SDVariable[] matricesB, - boolean transposeA, - boolean transposeB) { - return batchMmul(ArrayUtils.addAll(matricesA, matricesB), transposeA, transposeB); - } - - - public SDVariable[] 
batchMmul(SDVariable[] matrices, - boolean transposeA, - boolean transposeB) { - return new BatchMmul(sameDiff(), matrices, transposeA, transposeB).outputVariables(); - } - - - public SDVariable tensorMmul(SDVariable x, - SDVariable y, - int[][] dimensions) { - validateDifferentialFunctionsameDiff(x); - validateDifferentialFunctionsameDiff(y); - return new TensorMmul(sameDiff(), x, y, dimensions).outputVariable(); - } - - public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) { - return new DotProductAttention(sameDiff(), queries, keys, values, mask, scaled, false).outputVariable(); - } - - public List dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights) { - return Arrays.asList(new DotProductAttention(sameDiff(), queries, keys, values, mask, scaled, withWeights).outputVariables()); - } - - public List dotProductAttentionBp(SDVariable queries, SDVariable keys, SDVariable values, SDVariable gradient, SDVariable mask, boolean scaled) { - return Arrays.asList(new DotProductAttentionBp(sameDiff(), queries, keys, values, gradient, mask, scaled).outputVariables()); - } - - public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled) { - return new MultiHeadDotProductAttention(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); - } - - public List multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values,SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights) { - return Arrays.asList(new MultiHeadDotProductAttention(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights).outputVariables()); - } - - public List multiHeadDotProductAttentionBp(SDVariable 
queries, SDVariable keys, SDVariable values,SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable gradient, SDVariable mask, boolean scaled) { - return Arrays.asList(new MultiHeadDotProductAttentionBp(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, gradient, mask, scaled).outputVariables()); - } - - public SDVariable softmaxDerivative(SDVariable functionInput, SDVariable wrt, Integer dimension) { - validateDifferentialFunctionsameDiff(functionInput); - return new SoftmaxBp(sameDiff(), functionInput, wrt, dimension).outputVariable(); - } - - - public SDVariable logSoftmax(SDVariable i_v) { - validateDifferentialFunctionsameDiff(i_v); - return new LogSoftMax(sameDiff(), i_v).outputVariable(); - - } - - - public SDVariable logSoftmax(SDVariable i_v, int dimension) { - validateDifferentialFunctionsameDiff(i_v); - return new LogSoftMax(sameDiff(), i_v, dimension).outputVariable(); - - } - - - public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt) { - validateDifferentialFunctionsameDiff(arg); - return new LogSoftMaxDerivative(sameDiff(), arg, wrt).outputVariable(); - } - - - public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt, int dimension) { - validateDifferentialFunctionsameDiff(arg); - return new LogSoftMaxDerivative(sameDiff(), arg, wrt, dimension).outputVariable(); - } - - public SDVariable logSumExp(SDVariable arg, boolean keepDims, int... 
dimension) { - return new LogSumExp(sameDiff(), arg, keepDims, dimension).outputVariable(); - } - - - public SDVariable selu(SDVariable arg) { - validateDifferentialFunctionsameDiff(arg); - return new SELU(sameDiff(), arg, false).outputVariable(); - } - - public SDVariable seluBp(SDVariable in, SDVariable epsilon) { - validateDifferentialFunctionsameDiff(in); - return new SeluBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #seluBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable seluDerivative(SDVariable arg) { - validateDifferentialFunctionsameDiff(arg); - return new SELUDerivative(sameDiff(), arg, false).outputVariable(); - } - - - public SDVariable rsub(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RSubOp(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - public List rsubBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new RSubBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable rdiv(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RDivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public List rdivBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new RDivBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable rdivi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RDivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - - public SDVariable rsubi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RSubOp(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - public SDVariable add(SDVariable differentialFunction, 
SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new AddOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - - } - - public SDVariable mergeAdd(SDVariable... differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - return new MergeAddOp(sameDiff(), differentialFunctions, false).outputVariable(); - } - - public SDVariable mergeMax(SDVariable... differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - return new MergeMax(sameDiff(), differentialFunctions).outputVariable(); - } - - public SDVariable mergeAvg(SDVariable... differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - return new MergeAvg(sameDiff(), differentialFunctions).outputVariable(); - } - - public SDVariable diag(SDVariable sdVariable) { - validateDifferentialFunctionsameDiff(sdVariable); - return new Diag(sameDiff(), new SDVariable[]{sdVariable}, false).outputVariable(); - } - - public SDVariable diagPart(SDVariable sdVariable) { - validateDifferentialFunctionsameDiff(sdVariable); - return new DiagPart(sameDiff(), new SDVariable[]{sdVariable}, false).outputVariable(); - } - - public SDVariable setDiag(SDVariable in, SDVariable diag) { - return new MatrixSetDiag(sameDiff(), in, diag, false).outputVariable(); - } - - - public SDVariable batchToSpace(SDVariable differentialFunction, int[] blocks, int[][] crops) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new BatchToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocks, crops, false) - .outputVariable(); - } - - public SDVariable spaceToBatch(SDVariable differentialFunction, int[] blocks, int[][] padding) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SpaceToBatch(sameDiff(), new SDVariable[]{differentialFunction}, blocks, padding, 
false) - .outputVariable(); - } - - public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DepthToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) - .outputVariable(); - } - - public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SpaceToDepth(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) - .outputVariable(); - } - - public SDVariable[] dynamicPartition(SDVariable differentialFunction, SDVariable partitions, int numPartitions) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DynamicPartition(sameDiff(), differentialFunction, partitions, numPartitions) - .outputVariables(); - } - - public SDVariable[] dynamicPartitionBp(SDVariable input, SDVariable partitions, SDVariable[] grads, int numPartitions){ - return new DynamicPartitionBp(sameDiff(), input, partitions, grads, numPartitions).outputVariables(); - } - - public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - - return new DynamicStitch(sameDiff(), indices, differentialFunctions).outputVariable(); - } - - public SDVariable segmentMax(SDVariable data, SDVariable segmentIds){ - return new SegmentMax(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentMaxBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentMaxBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentMin(SDVariable data, SDVariable segmentIds){ - return new SegmentMin(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentMinBp(SDVariable data, SDVariable 
segmentIds, SDVariable gradient){ - return new SegmentMinBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentMean(SDVariable data, SDVariable segmentIds){ - return new SegmentMean(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentMeanBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentMeanBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentProd(SDVariable data, SDVariable segmentIds){ - return new SegmentProd(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentProdBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentProdBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentSum(SDVariable data, SDVariable segmentIds){ - return new SegmentSum(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentSumBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentSumBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - - public SDVariable unsortedSegmentMax(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentMax(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentMaxBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentMaxBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentMin(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentMin(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentMinBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentMinBp(sameDiff(), data, segmentIds, gradient, 
numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentMean(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentMean(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentMeanBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentMeanBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentProd(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentProd(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentProdBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentProdBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentSum(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentSum(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentSumBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentSumBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentSqrtN(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentSqrtN(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentSqrtNBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentSqrtNBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - - - - public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, - int[] rates, boolean isSameMode) { - validateDifferentialFunctionsameDiff(df); - return new Dilation2D(sameDiff(), new SDVariable[]{df, 
weights}, strides, rates, isSameMode, false) - .outputVariable(); - } - - public SDVariable shape(SDVariable df) { - validateDifferentialFunctionsameDiff(df); - return new org.nd4j.linalg.api.ops.impl.shape.Shape(sameDiff(), df, false).outputVariable(); - } - - public SDVariable size(SDVariable in) { - return new Size(sameDiff(), in).outputVariable(); - } - - public SDVariable sizeAt(SDVariable in, int dimension){ - return new SizeAt(sameDiff(), in, dimension).outputVariable(); - } - - public SDVariable rank(SDVariable df) { - return new Rank(sameDiff(), df, false).outputVariable(); - } - - public SDVariable gather(SDVariable df, int[] indices, int axis) { - validateDifferentialFunctionsameDiff(df); - return new Gather(sameDiff(), df, indices, axis, false).outputVariable(); - } - - public SDVariable gather(SDVariable df, SDVariable indices, int axis) { - validateDifferentialFunctionsameDiff(df); - return new Gather(sameDiff(), df, indices, axis, false).outputVariable(); - } - - public SDVariable gatherNd(SDVariable df, SDVariable indices) { - validateDifferentialFunctionsameDiff(df); - return new GatherNd(sameDiff(), df, indices).outputVariable(); - } - - public SDVariable trace(SDVariable in){ - return new Trace(sameDiff(), in).outputVariable(); - } - - public SDVariable cross(SDVariable a, SDVariable b) { - validateDifferentialFunctionsameDiff(a); - return new Cross(sameDiff(), new SDVariable[]{a, b}).outputVariable(); - } - - public SDVariable erf(SDVariable differentialFunction) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new Erf(sameDiff(), differentialFunction, false).outputVariable(); - } - - public SDVariable erfc(SDVariable differentialFunction) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new Erfc(sameDiff(), differentialFunction, false).outputVariable(); - } - - public SDVariable addi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - 
return new AddOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - public List addBp(SDVariable x, SDVariable y, SDVariable grad) { - SDVariable[] ret = new AddBpOp(sameDiff(), x, y, grad).outputVariables(); - return Arrays.asList(ret); - } - - - public SDVariable sub(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SubOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public SDVariable squaredDifference(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SquaredDifferenceOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false) - .outputVariable(); - } - - - public List subBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new SubBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable subi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SubOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - - } - - - public SDVariable mul(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public List mulBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new MulBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - public List modBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new ModBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } 
- - public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ModOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public SDVariable div(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public SDVariable truncatedDiv(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new TruncateDivOp(sameDiff(), differentialFunction, i_v, false).outputVariable(); - } - - public List divBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new DivBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable divi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - - public SDVariable rsub(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseSubtraction(sameDiff(), differentialFunction, i_v).outputVariable(); - - } - - - public SDVariable rdiv(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseDivision(sameDiff(), differentialFunction, i_v).outputVariable(); - - } - - - public SDVariable rdivi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseDivision(sameDiff(), differentialFunction, i_v, true).outputVariable(); - } - - - public SDVariable rsubi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new 
ScalarReverseSubtraction(sameDiff(), differentialFunction, i_v, true).outputVariable(); - - } - - - public SDVariable add(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarAdd(sameDiff(), differentialFunction, i_v, false).outputVariable(); - } - - - public SDVariable addi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarAdd(sameDiff(), differentialFunction, i_v, true).outputVariable(); - } - - - public SDVariable sub(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarSubtraction(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - - public SDVariable subi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarSubtraction(sameDiff(), differentialFunction, i_v, true).outputVariable(); - - } - - - public SDVariable mul(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarMultiplication(sameDiff(), differentialFunction, i_v).outputVariable(); - - } - - - public SDVariable muli(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarMultiplication(sameDiff(), differentialFunction, i_v, true).outputVariable(); - - } - - - public SDVariable div(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarDivision(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - - public SDVariable divi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarDivision(sameDiff(), differentialFunction, i_v, true).outputVariable(); - } - - - public SDVariable gt(SDVariable 
functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable lt(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable gti(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable lti(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable gte(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable lte(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable gtei(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - 
validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable ltOrEqi(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable gt(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThan(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable lt(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThan(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable gti(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThan(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable lti(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThan(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable gte(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThanOrEqual(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable lte(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThanOrEqual(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable gtei(SDVariable functionInput, double functionInput1) { - 
validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThanOrEqual(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable ltei(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThanOrEqual(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable eq(SDVariable iX, double i_y) { - return new ScalarEquals(sameDiff(), iX, i_y).outputVariable(); - } - - public SDVariable eqi(SDVariable iX, double i_y) { - return new ScalarEquals(sameDiff(), iX, i_y, true).outputVariable(); - } - - public SDVariable isNonDecreasing(SDVariable iX) { - validateDifferentialFunctionsameDiff(iX); - return new IsNonDecreasing(sameDiff(), new SDVariable[]{iX}, false).outputVariable(); - } - - public SDVariable isStrictlyIncreasing(SDVariable iX) { - validateDifferentialFunctionsameDiff(iX); - return new IsStrictlyIncreasing(sameDiff(), new SDVariable[]{iX}, false).outputVariable(); - } - - public SDVariable isNumericTensor(SDVariable iX) { - validateDifferentialFunctionsameDiff(iX); - return new IsNumericTensor(sameDiff(), new SDVariable[]{iX}, false).outputVariable(); - } - - public SDVariable slice(SDVariable input, int[] begin, int[] size) { - return new Slice(sameDiff(), input, begin, size).outputVariable(); - } - - public SDVariable slice(SDVariable input, SDVariable begin, SDVariable size) { - return new Slice(sameDiff(), input, begin, size).outputVariable(); - } - - public SDVariable sliceBp(SDVariable input, SDVariable gradient, int[] begin, int[] size) { - return new SliceBp(sameDiff(), input, gradient, begin, size).outputVariable(); - } - - public SDVariable sliceBp(SDVariable input, SDVariable gradient, SDVariable begin, SDVariable size) { - return new SliceBp(sameDiff(), input, gradient, begin, size).outputVariable(); - } - - - public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] 
strides) { - return new StridedSlice(sameDiff(), input, begin, end, strides).outputVariable(); - } - - public SDVariable stridedSlice(SDVariable input, long[] begin, long[] end, long[] strides) { - return new StridedSlice(sameDiff(), input, begin, end, strides).outputVariable(); - } - - - public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSlice(sameDiff(), in, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSlice(sameDiff(), in, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable stridedSliceBp(SDVariable in, SDVariable grad, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSliceBp(sameDiff(), in, grad, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable stridedSliceBp(SDVariable in, SDVariable grad, SDVariable begin, SDVariable end, SDVariable strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSliceBp(sameDiff(), in, grad, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterAdd(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterSub(sameDiff(), ref, indices, 
updates).outputVariable(); - } - - public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterMul(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterDiv(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterMax(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterMin(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterUpdate(sameDiff(), ref, indices, updates).outputVariable(); - } - - - public SDVariable merge(SDVariable... inputs){ - return new Merge(sameDiff(), inputs).outputVariable(); - } - - public SDVariable[] switchOp(SDVariable input, SDVariable predicate){ - return new Switch(sameDiff(), input, predicate).outputVariables(); - } - - - public void validateDifferentialFunctionsameDiff( - SDVariable function) { - - Preconditions.checkState(function != null, "Passed in function was null."); - Preconditions.checkState(function.getSameDiff() == sameDiff); - - Preconditions.checkState(function.getSameDiff() == this.getSameDiff(), - "Function applications must be contained " + - "in same sameDiff. The left %s must match this function %s", function, this); - Preconditions.checkState(sameDiff == this.getSameDiff(), "Function applications must be " + - "contained in same sameDiff. The left %s must match this function ", function, this); - } - - - public void validateDifferentialFunctionGraph(SDVariable function) { - Preconditions.checkState(function.getSameDiff() == this.getSameDiff(), - "Function applications must be contained in same graph. 
The left %s must match this function %s", - function, this); - - } - - - /** - * @param func - * @param input - * @return - */ - public SDVariable doRepeat(SDVariable func, - SDVariable input) { - validateDifferentialFunctionsameDiff(func); - validateDifferentialFunctionsameDiff(input); - - return tile(func, ArrayUtil.toInts(input.getShape())); - } - - public SDVariable enter(SDVariable x, String frameName){ - return new Enter(sameDiff, frameName, x).outputVariable(); - } - - public SDVariable enter(SDVariable x, String frameName, boolean isConstant){ - return new Enter(sameDiff, frameName, x, isConstant).outputVariable(); - } - - public SDVariable exit(SDVariable x){ - return new Exit(sameDiff, x).outputVariable(); - } - - public SDVariable nextIteration(SDVariable x){ - return new NextIteration(sameDiff, x).outputVariable(); - } - - public SDVariable adjustContrast(SDVariable in, SDVariable factor) { - return new AdjustContrast(sameDiff, in, factor).outputVariable(); - } - - public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) { - return new AdjustContrastV2(sameDiff, in, factor).outputVariable(); - } - - public SDVariable bitCast(SDVariable in, SDVariable dataType) { - return new BitCast(sameDiff, in, dataType).outputVariable(); - } - - public SDVariable compareAndBitpack(SDVariable threshold) { - return new CompareAndBitpack(sameDiff, threshold).outputVariable(); - } - - public SDVariable divideNoNan(SDVariable in1, SDVariable in2) { - return new DivideNoNan(sameDiff, in1, in2).outputVariable(); - } - - public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) { - return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable(); - } - - public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max, - int num_bits, boolean narrow) { - return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max,num_bits,narrow).outputVariable(); - } - - public SDVariable betainc( SDVariable a, SDVariable 
b, SDVariable x) { - return new BetaInc(sameDiff, a, b, x).outputVariable(); - } - - public SDVariable[] fusedBatchNorm(SDVariable x, SDVariable scale, SDVariable offset, - SDVariable dataFormat, SDVariable isTraining) { - return new FusedBatchNorm(sameDiff,x,scale,offset,dataFormat,isTraining).outputVariables(); - } - - public SDVariable matrixBandPart(SDVariable input, SDVariable minLower, SDVariable maxUpper) { - return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable(); - } - - public SDVariable[] maxPoolWithArgmax(SDVariable x, Pooling2DConfig pooling2DConfig) { - return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables(); - } - - public SDVariable polygamma(SDVariable n, SDVariable x) { - return new Polygamma(sameDiff, n,x).outputVariable(); - } - - public SDVariable roll(SDVariable input, int shift) { - return new Roll(sameDiff, input, shift).outputVariable(); - } - - public SDVariable toggleBits(SDVariable x) { - return new ToggleBits(sameDiff, x).outputVariable(); - } - - - public String toString() { - return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 3b29e6ccb..5ee0801d0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -253,7 +253,7 @@ public class SDVariable implements Serializable { * @return Negated variable */ public SDVariable neg(){ - return sameDiff.f().neg(this); + return sameDiff.math.neg(this); } /** @@ -579,7 +579,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable add(String varName, double scalar) { - val function = 
sameDiff.f().add(this,scalar); + val function = sameDiff.math.add(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -600,7 +600,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable add(String name, SDVariable x) { - val result = sameDiff.f().add(this, x); + val result = sameDiff.math.add(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -636,7 +636,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable sub(String varName, double scalar) { - val result = sameDiff.f().sub(this, scalar); + val result = sameDiff.math.sub(this, scalar); return sameDiff.updateVariableNameAndReference(result, varName); } @@ -657,7 +657,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable sub(String name, SDVariable x) { - val result = sameDiff.f().sub(this,x); + val result = sameDiff.math.sub(this,x); return sameDiff.updateVariableNameAndReference(result,name); } @@ -693,7 +693,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable div(String varName, double scalar) { - val function = sameDiff.f().div(this,scalar); + val function = sameDiff.math.div(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -714,7 +714,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable div(String name, SDVariable x) { - val result = sameDiff.f().div(this, x); + val result = sameDiff.math.div(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -728,7 +728,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable fdiv(String name, SDVariable x) { - val result = sameDiff.f().floorDiv(this, x); + val result = sameDiff.math.floorDiv(this, x); return 
sameDiff.updateVariableNameAndReference(result, name); } @@ -742,7 +742,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable mod(String name, SDVariable x) { - val result = sameDiff.f().mod(this, x); + val result = sameDiff.math.mod(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -762,7 +762,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable mul(String varName, double scalar) { - val function = sameDiff.f().mul(this, scalar); + val function = sameDiff.math.mul(this, scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -784,7 +784,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable mul(String name, SDVariable x) { - val result = sameDiff.f().mul(this, x); + val result = sameDiff.math.mul(this, x); return sameDiff.updateVariableNameAndReference(result,name); } @@ -820,7 +820,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable pow(String varName, double scalar) { - SDVariable ret = sameDiff.f().pow(this, scalar); + SDVariable ret = sameDiff.math.pow(this, scalar); return sameDiff.updateVariableNameAndReference(ret, varName); } @@ -840,7 +840,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable rsub(String varName, double scalar) { - val function = sameDiff.f().rsub(this,scalar); + val function = sameDiff.math.rsub(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -861,7 +861,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable rsub(String name, SDVariable x) { - val result = sameDiff.f().rsub(this,x); + val result = sameDiff.math.rsub(this,x); return sameDiff.updateVariableNameAndReference(result,name); } @@ -881,7 +881,7 @@ public class SDVariable 
implements Serializable { * @return Output variable */ public SDVariable rdiv(String varName, double scalar) { - val function = sameDiff.f().rdiv(this, scalar); + val function = sameDiff.math.rdiv(this, scalar); return sameDiff.updateVariableNameAndReference(function, varName); } @@ -902,34 +902,11 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable rdiv(String name, SDVariable x) { - val result = sameDiff.f().rdiv(this,x); + val result = sameDiff.math.rdiv(this,x); return sameDiff.updateVariableNameAndReference(result,name); } - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable truncatedDiv(SDVariable sameDiffVariable) { - return truncatedDiv(null,sameDiffVariable); - - } - - - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable truncatedDiv(String varName, SDVariable sameDiffVariable) { - val function = sameDiff.f().truncatedDiv(this, sameDiffVariable); - return sameDiff.updateVariableNameAndReference(function,varName); - - } - /** * See {@link #squaredDifference(String, SDVariable)} */ @@ -943,7 +920,7 @@ public class SDVariable implements Serializable { * @return squared difference between variables */ public SDVariable squaredDifference(String name, SDVariable x) { - val result = sameDiff.f().squaredDifference(this, x); + val result = sameDiff.math().squaredDifference(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -1431,7 +1408,7 @@ public class SDVariable implements Serializable { } public SDVariable permute(SDVariable dimensions){ - return sameDiff.permute(null, this, dimensions); + return sameDiff.permute( this, dimensions); } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index c51ac28a1..77d46b889 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -24,7 +24,6 @@ import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.listeners.*; import org.nd4j.autodiff.listeners.impl.HistoryListener; import org.nd4j.autodiff.listeners.records.History; @@ -53,8 +52,7 @@ import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray; import org.nd4j.linalg.api.ops.impl.transforms.Assert; @@ -95,7 +93,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.regex.Matcher; import java.util.regex.Pattern; -import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs; +import static org.nd4j.autodiff.util.SameDiffUtils.stackOutputs; /** * SameDiff is the entrypoint for ND4J's automatic differentiation functionality. 
@@ -141,7 +139,7 @@ public class SameDiff extends SDBaseOps { //////////////////////////////////////// - private DifferentialFunctionFactory functionFactory; +// private DifferentialFunctionFactory functionFactory; // counter for auto-naming variables private int variableId = 0; @@ -296,15 +294,6 @@ public class SameDiff extends SDBaseOps { return this; } - /** - * Returns this samediff instance's {@link DifferentialFunctionFactory} - * - * @return DifferentialFunctionFactory - */ - public DifferentialFunctionFactory f() { - return functionFactory; - } - /** * Set the current SameDiff-wide {@link Listener} instances. * @@ -917,7 +906,6 @@ public class SameDiff extends SDBaseOps { private SameDiff() { super(null); super.sd = this; - functionFactory = new DifferentialFunctionFactory(this); sameDiffFunctionInstances = new LinkedHashMap<>(); fieldVariableResolutionMapping = HashBasedTable.create(); } @@ -5945,7 +5933,7 @@ public class SameDiff extends SDBaseOps { if(switches.containsKey(argument.name())) return switches.get(argument.name())[1]; - SDVariable[] s = f().switchOp(argument, pred); + SDVariable[] s = switchOp(argument, pred); switches.put(argument.name(), s); return s[1]; } @@ -5955,7 +5943,7 @@ public class SameDiff extends SDBaseOps { this.removeArgumentInterceptor(); if(declared.contains(trueOut.name())) { - SDVariable[] s = f().switchOp(trueOut, pred); + SDVariable[] s = switchOp(trueOut, pred); switches.put(trueOut.name(), s); trueOut = s[1]; } @@ -5975,7 +5963,7 @@ public class SameDiff extends SDBaseOps { if(switches.containsKey(argument.name())) return switches.get(argument.name())[0]; - SDVariable[] s = f().switchOp(argument, pred); + SDVariable[] s = switchOp(argument, pred); switches.put(argument.name(), s); return s[0]; } @@ -5985,13 +5973,13 @@ public class SameDiff extends SDBaseOps { this.removeArgumentInterceptor(); if(declared2.contains(falseOut.name())) { - SDVariable[] s = f().switchOp(falseOut, pred); + SDVariable[] s = switchOp(falseOut, 
pred); switches.put(falseOut.name(), s); falseOut = s[0]; } falseScope.close(); - SDVariable output = f().merge(trueOut, falseOut); + SDVariable output = merge(trueOut, falseOut); ifScope.close(); @@ -6042,11 +6030,9 @@ public class SameDiff extends SDBaseOps { SDVariable[] entered = new SDVariable[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ - entered[i] = f().enter(loopVars[i], frameName); + entered[i] = new Enter(this, frameName, loopVars[i]).outputVariable(); } - //counter = SD.f().enter(counter, frameName); - SDVariable[] merged = new SDVariable[loopVars.length]; Merge[] mergeOps = new Merge[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ @@ -6072,19 +6058,16 @@ public class SameDiff extends SDBaseOps { SDVariable[] trueSwitches = new SDVariable[loopVars.length]; SDVariable[] exits = new SDVariable[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable[] s = f().switchOp(merged[i], cond_result); + SDVariable[] s = switchOp(merged[i], cond_result); trueSwitches[i] = s[1]; alreadyEntered.add(s[1].name()); - exits[i] = f().exit(s[0]); + exits[i] = new Exit(this, s[0]).outputVariable(); } - //SDVariable[] cs = SD.f().switchOp(counter, cond_result); - //SDVariable counterExit = SD.f().exit(cs[0]); - //counter = cs[1]; - final Set declared = Sets.newHashSet(this.variableMap().keySet()); final Map done = new HashMap<>(); + final SameDiff sd = this; this.addArgumentInterceptor(new ArgumentInterceptor() { @Override public SDVariable intercept(SDVariable argument) { @@ -6098,7 +6081,7 @@ public class SameDiff extends SDBaseOps { if(done.containsKey(argument.name())) return done.get(argument.name()); - SDVariable e = f().enter(argument, frameName, true); + SDVariable e = new Enter(sd, frameName, argument, true).outputVariable(); done.put(argument.name(), e); return e; } @@ -6112,7 +6095,7 @@ public class SameDiff extends SDBaseOps { //counter.add(1); for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable n = 
f().nextIteration(outs[i]); + SDVariable n = new NextIteration(this, outs[i]).outputVariable(); mergeOps[i].replaceArg(1,n); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java index 5cafa09aa..193229ff9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java @@ -27,7 +27,7 @@ import lombok.Setter; import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.util.TrainingUtils; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; @@ -165,7 +165,7 @@ public class OutputConfig { Preconditions.checkState(outputs.size() == 1, "Can only use execSingleBatches() when exactly one output is specified, there were %s", outputs.size()); - return TrainingUtils + return SameDiffUtils .getSingleOutput(sd.outputBatches(data, listeners, outputs.toArray(new String[0])), outputs.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 3b53e5b65..157ec1fbb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -37,12 +37,11 @@ public class SDBaseOps { /** * Boolean and array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable all(SDVariable x, int... dimensions) { - SDValidation.validateBool("all", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); } @@ -51,12 +50,11 @@ public class SDBaseOps { * Boolean and array reduction operation, optionally along specified dimensions
* * @param name name May be null. Name for the output variable - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable all(String name, SDVariable x, int... dimensions) { - SDValidation.validateBool("all", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); @@ -65,12 +63,11 @@ public class SDBaseOps { /** * Boolean or array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable any(SDVariable x, int... dimensions) { - SDValidation.validateBool("any", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); } @@ -79,12 +76,11 @@ public class SDBaseOps { * Boolean or array reduction operation, optionally along specified dimensions
* * @param name name May be null. Name for the output variable - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable any(String name, SDVariable x, int... dimensions) { - SDValidation.validateBool("any", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); @@ -196,6 +192,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions @@ -220,6 +218,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param in Input variable (NUMERIC type) @@ -246,6 +246,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) @@ -269,6 +271,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param in Input variable (NUMERIC type) @@ -744,6 +748,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -762,6 +768,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -964,6 +972,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -982,6 +992,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1032,6 +1044,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1050,6 +1064,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1245,6 +1261,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1263,6 +1281,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1313,6 +1333,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1331,6 +1353,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1581,6 +1605,8 @@ public class SDBaseOps { * Element-wise maximum operation: out[i] = max(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -1596,6 +1622,8 @@ public class SDBaseOps { * Element-wise maximum operation: out[i] = max(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param first First input array (NUMERIC type) @@ -1695,6 +1723,38 @@ public class SDBaseOps { return sd.updateVariableNameAndReference(out, name); } + /** + * The merge operation is a control operation that forwards either of the inputs to the output, when
+ * the first of them becomes available. If both are available, the output is undefined (either input could
+ * be forwarded to the output)
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable merge(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("merge", "x", x); + SDValidation.validateNumerical("merge", "y", y); + return new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(sd,x, y).outputVariable(); + } + + /** + * The merge operation is a control operation that forwards either of the inputs to the output, when
+ * the first of them becomes available. If both are available, the output is undefined (either input could
+ * be forwarded to the output)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable merge(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("merge", "x", x); + SDValidation.validateNumerical("merge", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
* @@ -1785,6 +1845,8 @@ public class SDBaseOps { * Element-wise minimum operation: out[i] = min(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -1800,6 +1862,8 @@ public class SDBaseOps { * Element-wise minimum operation: out[i] = min(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param first First input array (NUMERIC type) @@ -1916,6 +1980,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1934,6 +2000,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -4176,6 +4244,32 @@ public class SDBaseOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Switch operation
+ * Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output
+ * + * @param x Input variable (NDARRAY type) + * @param predicate Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + */ + public SDVariable[] switchOp(SDVariable x, SDVariable predicate) { + SDValidation.validateBool("switchOp", "predicate", predicate); + return new org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch(sd,x, predicate).outputVariables(); + } + + /** + * Switch operation
+ * Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input variable (NDARRAY type) + * @param predicate Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + */ + public SDVariable[] switchOp(String[] names, SDVariable x, SDVariable predicate) { + SDValidation.validateBool("switchOp", "predicate", predicate); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch(sd,x, predicate).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + /** * //TODO: Ops must be documented.
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index a58d4d180..ef030e952 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -258,7 +258,7 @@ public class SDImage extends SDOps { /** * Resize images to size using the specified method.
* - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. * @param antialis Whether to use an anti-aliasing filter when downsampling an image @@ -282,7 +282,7 @@ public class SDImage extends SDOps { * Resize images to size using the specified method.
* * @param name name May be null. Name for the output variable - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. * @param antialis Whether to use an anti-aliasing filter when downsampling an image @@ -306,7 +306,7 @@ public class SDImage extends SDOps { /** * Resize images to size using the specified method.
* - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. @@ -328,7 +328,7 @@ public class SDImage extends SDOps { * Resize images to size using the specified method.
* * @param name name May be null. Name for the output variable - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 1f89ba1d1..66d47f905 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -67,13 +67,13 @@ public class SDMath extends SDOps { * Looks up ids in a list of embedding tensors.
* * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); - SDValidation.validateInteger("EmbeddingLookup", "indices", indices); + SDValidation.validateNumerical("EmbeddingLookup", "indices", indices); return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); } @@ -82,14 +82,14 @@ public class SDMath extends SDOps { * * @param name name May be null. Name for the output variable * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public SDVariable embeddingLookup(String name, SDVariable x, SDVariable indices, PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); - SDValidation.validateInteger("EmbeddingLookup", "indices", indices); + SDValidation.validateNumerical("EmbeddingLookup", "indices", indices); SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -166,6 +166,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise addition operation, out = x + y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("add", "x", x); + SDValidation.validateNumerical("add", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise addition operation, out = x + y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("add", "x", x); + SDValidation.validateNumerical("add", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar add operation, out = in + scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(SDVariable x, double value) { + SDValidation.validateNumerical("add", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + } + + /** + * Scalar add operation, out = in + scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(String name, SDVariable x, double value) { + SDValidation.validateNumerical("add", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
* @@ -1126,6 +1188,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise division operation, out = x / y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("div", "x", x); + SDValidation.validateNumerical("div", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise division operation, out = x / y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("div", "x", x); + SDValidation.validateNumerical("div", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar division operation, out = in / scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(SDVariable x, double value) { + SDValidation.validateNumerical("div", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + } + + /** + * Scalar division operation, out = in / scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(String name, SDVariable x, double value) { + SDValidation.validateNumerical("div", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Entropy reduction: -sum(x * log(x))
* @@ -1552,6 +1676,104 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise floor division operation, out = floor(x / y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorDiv(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorDiv", "x", x); + SDValidation.validateNumerical("floorDiv", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise floor division operation, out = floor(x / y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorDiv(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorDiv", "x", x); + SDValidation.validateNumerical("floorDiv", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise Modulus division operation
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorMod", "x", x); + SDValidation.validateNumerical("floorMod", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise Modulus division operation
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorMod", "x", x); + SDValidation.validateNumerical("floorMod", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar floor modulus operation
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(SDVariable x, double value) { + SDValidation.validateNumerical("floorMod", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + } + + /** + * Scalar floor modulus operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(String name, SDVariable x, double value) { + SDValidation.validateNumerical("floorMod", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Hamming distance reduction operation. The output contains the cosine distance for each
* tensor/subset along the specified dimensions:
@@ -2260,6 +2482,42 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise max operation, out = max(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable max(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("max", "x", x); + SDValidation.validateNumerical("max", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + } + + /** + * Pairwise max operation, out = max(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable max(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("max", "x", x); + SDValidation.validateNumerical("max", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
* out = sum_i in[i]
@@ -2370,6 +2628,78 @@ public class SDMath extends SDOps { return sd.updateVariableNamesAndReferences(out, names); } + /** + * Pairwise min operation, out = min(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable min(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("min", "x", x); + SDValidation.validateNumerical("min", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + } + + /** + * Pairwise min operation, out = min(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable min(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("min", "x", x); + SDValidation.validateNumerical("min", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise modulus (remainder) operation, out = x % y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mod(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mod", "x", x); + SDValidation.validateNumerical("mod", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise modulus (remainder) operation, out = x % y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mod(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mod", "x", x); + SDValidation.validateNumerical("mod", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Calculate the mean and (population) variance for the input variable, for the specified axis
* @@ -2396,6 +2726,68 @@ public class SDMath extends SDOps { return sd.updateVariableNamesAndReferences(out, names); } + /** + * Pairwise multiplication operation, out = x * y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mul", "x", x); + SDValidation.validateNumerical("mul", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise multiplication operation, out = x * y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mul", "x", x); + SDValidation.validateNumerical("mul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar multiplication operation, out = in * scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(SDVariable x, double value) { + SDValidation.validateNumerical("mul", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + } + + /** + * Scalar multiplication operation, out = in * scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(String name, SDVariable x, double value) { + SDValidation.validateNumerical("mul", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Elementwise negative operation: out = -x
* @@ -2542,6 +2934,96 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Rational Tanh Approximation elementwise function, as described in the paper:
+ * Compact Convolutional Neural Network Cascade for Face Detection
+ * This is a faster Tanh approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rationalTanh(SDVariable x) { + SDValidation.validateNumerical("rationalTanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + } + + /** + * Rational Tanh Approximation elementwise function, as described in the paper:
+ * Compact Convolutional Neural Network Cascade for Face Detection
+ * This is a faster Tanh approximation
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rationalTanh(String name, SDVariable x) { + SDValidation.validateNumerical("rationalTanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise reverse division operation, out = y / x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rdiv", "x", x); + SDValidation.validateNumerical("rdiv", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise reverse division operation, out = y / x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rdiv", "x", x); + SDValidation.validateNumerical("rdiv", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar reverse division operation, out = scalar / in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(SDVariable x, double value) { + SDValidation.validateNumerical("rdiv", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + } + + /** + * Scalar reverse division operation, out = scalar / in
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(String name, SDVariable x, double value) { + SDValidation.validateNumerical("rdiv", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
* @@ -2566,6 +3048,30 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Rectified tanh operation: max(0, tanh(in))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rectifiedTanh(SDVariable x) { + SDValidation.validateNumerical("rectifiedTanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + } + + /** + * Rectified tanh operation: max(0, tanh(in))
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rectifiedTanh(String name, SDVariable x) { + SDValidation.validateNumerical("rectifiedTanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Element-wise round function: out = round(x).
* Rounds (up or down depending on value) to the nearest integer value.
@@ -2616,6 +3122,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise reverse subtraction operation, out = y - x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rsub", "x", x); + SDValidation.validateNumerical("rsub", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise reverse subtraction operation, out = y - x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rsub", "x", x); + SDValidation.validateNumerical("rsub", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar reverse subtraction operation, out = scalar - in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(SDVariable x, double value) { + SDValidation.validateNumerical("rsub", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + } + + /** + * Scalar reverse subtraction operation, out = scalar - in
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(String name, SDVariable x, double value) { + SDValidation.validateNumerical("rsub", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Set the diagonal value to the specified values
* If input is
@@ -2814,6 +3382,42 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise squared difference operation.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable squaredDifference(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("squaredDifference", "x", x); + SDValidation.validateNumerical("squaredDifference", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise squared difference operation.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable squaredDifference(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("squaredDifference", "x", x); + SDValidation.validateNumerical("squaredDifference", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Standardize input variable along given axis
*


@@ -2894,6 +3498,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise subtraction operation, out = x - y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("sub", "x", x); + SDValidation.validateNumerical("sub", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise subtraction operation, out = x - y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("sub", "x", x); + SDValidation.validateNumerical("sub", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar subtraction operation, out = in - scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(SDVariable x, double value) { + SDValidation.validateNumerical("sub", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + } + + /** + * Scalar subtraction operation, out = in - scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(String name, SDVariable x, double value) { + SDValidation.validateNumerical("sub", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Elementwise tangent operation: out = tan(x)
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 9633a0186..15d70aac5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -24,6 +24,7 @@ import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; public class SDNN extends SDOps { public SDNN(SameDiff sameDiff) { @@ -722,6 +723,39 @@ public class SDNN extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Padding operation
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(SDVariable input, SDVariable padding, PadMode PadMode, double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + } + + /** + * Padding operation
+ * + * @param name name May be null. Name for the output variable + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(String name, SDVariable input, SDVariable padding, PadMode PadMode, + double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Padding operation
* @@ -733,7 +767,7 @@ public class SDNN extends SDOps { public SDVariable pad(SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); } /** @@ -748,7 +782,35 @@ public class SDNN extends SDOps { public SDVariable pad(String name, SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the precise method
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable preciseGelu(SDVariable x) { + SDValidation.validateNumerical("preciseGelu", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the precise method
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable preciseGelu(String name, SDVariable x) { + SDValidation.validateNumerical("preciseGelu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java index 88792bddb..fc406caa6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,20 +14,21 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.indexing.conditions.Condition; -/** - * Abstract class for defining categories of operations - such as {@link SDMath} that is available via {@code SameDiff.math()} - * - * @author Alex Black - */ -public abstract class SDOps { - - protected final SameDiff sd; +public class SDOps { + protected SameDiff sd; public SDOps() { sd = null; @@ -37,11 +38,5 @@ public abstract class SDOps { this.sd = sameDiff; } - protected DifferentialFunctionFactory f() { - return sd.f(); - } - protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { - return sd.updateVariableNameAndReference(varToUpdate, newVarName); - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java index 97a47d257..2b91300eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java @@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SameDiff; /** * An 
OpPredicate defines whether an operation ({@link DifferentialFunction}) matches or not.
- * Used mainly in {@link org.nd4j.autodiff.functions.DifferentialFunctionFactory} * * @author Alex Black */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java new file mode 100644 index 000000000..a3f9ddea2 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.util; + +import java.util.*; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; +import org.nd4j.linalg.api.ops.impl.shape.ReductionShape; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.exception.ND4JException; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Utilities for SameDiff training and inference + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class SameDiffUtils { + + /** + * Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)} + */ + public static Map stackOutputs(List> outputs){ + Map> outs = new HashMap<>(); + for(Map batch : outputs){ + for(String k : batch.keySet()){ + if(!outs.containsKey(k)) + outs.put(k, new ArrayList()); + outs.get(k).add(batch.get(k)); + } + } + + Map ret = new HashMap<>(); + for(String k : outs.keySet()){ + try { + ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0]))); + } catch(Exception e){ + throw new ND4JException("Error concatenating batch outputs", e); + } + } + return ret; + } + + /** + * Get a list of batch outputs for a single variable from a list of batch outputs for all variables + */ + public static List getSingleOutput(List> outputs, String output){ + List batches = new ArrayList<>(); + for(Map batch : outputs) + batches.add(batch.get(output)); + + return batches; + } + + public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map externalGradients, SDVariable... 
inputs) { + Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" + + " be specified when using external errors: got %s", inputs); + ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients); + fn.outputVariable(); + return fn; + } + + public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, SDVariable[] inputs) { + return externalErrors(sameDiff, null, inputs); + } + + + + /** + * Add 1s as required to the array make an array possible to be broadcast with the original (pre-reduce) array. + *

+ * Example: if doing [a,b,c].sum(1), result is [a,c]. To 'undo' this in a way that can be auto-broadcast, + * we want to expand as required - i.e., [a,c] -> [a,1,c] which can be auto-broadcast with the original [a,b,c]. + * This is typically only used with reduction operations backprop. + * + * @param origRank Rank of the original array, before the reduction was executed + * @param reduceDims Dimensions that the original array was reduced from + * @param toExpand Array to add 1s to the shape to (such that it can be + * @return Reshaped array. + */ + public static SDVariable reductionBroadcastableWithOrigShape(int origRank, int[] reduceDims, SDVariable toExpand) { + if (Shape.isWholeArray(origRank, reduceDims)) { + //Output is [1,1] which is already broadcastable + return toExpand; + } else if (origRank == 2 && reduceDims.length == 1) { + //In this case: [a,b] -> [1,b] or [a,b] -> [a,1] + //both are already broadcastable + return toExpand; + } else { + //Example: [a,b,c].sum(1) -> [a,c]... 
want [a,1,c] + for (int d : reduceDims) { + toExpand = toExpand.getSameDiff().expandDims(toExpand, d); + } + return toExpand; + } + } + + public static SDVariable reductionBroadcastableWithOrigShape(SDVariable origInput, SDVariable axis, SDVariable toExpand) { + SDVariable shape = origInput.shape(); + SDVariable reduceShape = reductionShape(shape, axis, true); + SDVariable reshaped = toExpand.reshape(reduceShape); + return reshaped; + } + + public static SDVariable reductionShape(SDVariable shape, SDVariable axis, boolean keepDim){ + return new ReductionShape(shape.getSameDiff(), shape, axis, keepDim).outputVariable(); + } + + public static void validateDifferentialFunctionSameDiff(SameDiff sameDiff, SDVariable function, DifferentialFunction op) { + + Preconditions.checkState(function != null, "Passed in function was null."); + Preconditions.checkState(function.getSameDiff() == sameDiff); + + Preconditions.checkState(function.getSameDiff() == sameDiff, + "Function applications must be contained " + + "in same sameDiff. The left %s must match this function %s", function, op); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java deleted file mode 100644 index 289bd15be..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.nd4j.autodiff.util; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.exception.ND4JException; -import org.nd4j.linalg.factory.Nd4j; - -/** - * Utilities for SameDiff training and inference - */ -@NoArgsConstructor(access = AccessLevel.PRIVATE) -public class TrainingUtils { - - /** - * Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)} - */ - public static Map stackOutputs(List> outputs){ - Map> outs = new HashMap<>(); - for(Map batch : outputs){ - for(String k : batch.keySet()){ - if(!outs.containsKey(k)) - outs.put(k, new ArrayList()); - outs.get(k).add(batch.get(k)); - } - } - - Map ret = new HashMap<>(); - for(String k : outs.keySet()){ - try { - ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0]))); - } catch(Exception e){ - throw new ND4JException("Error concatenating batch outputs", e); - } - } - return ret; - } - - /** - * Get a list of batch outputs for a single variable from a list of batch outputs for all variables - */ - public static List getSingleOutput(List> outputs, String output){ - List batches = new ArrayList<>(); - for(Map batch : outputs) - batches.add(batch.get(output)); - - return batches; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java new file mode 100644 index 000000000..4802ebdaf --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java @@ -0,0 
+1,29 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Padding format */ +public enum PadMode { + CONSTANT, + + REFLECT, + + SYMMETRIC +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java new file mode 100644 index 000000000..865d23282 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Weights format: [kH, kW, iC, oC] or [oC, iC, kH, kW], or [oC, kH, kW, iC] */ +public enum WeightsFormat { + YXIO, + + OIYX, + + OYXI +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java index afdc11aa4..e2ca9329e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java @@ -22,6 +22,7 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -56,8 +57,6 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp int[] dimension) { super(sameDiff, inPlace, new Object[]{i_v2}); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; @@ -80,9 +79,6 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp super(sameDiff, extraArgs); this.dimension = dimension; if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); - this.sameDiff = sameDiff; 
sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); @@ -107,7 +103,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp super(sameDiff, inPlace, extraArgs); this.dimension = dimension; if (i_v != null) { - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); sameDiff.addArgsFor(new SDVariable[]{i_v},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java index 291b66d2b..56201560a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java @@ -22,6 +22,7 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -57,8 +58,6 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { int[] dimension) { super(sameDiff, inPlace, new Object[]{i_v2}); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; @@ -80,8 +79,8 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { super(sameDiff, extraArgs); this.dimension = dimension; if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v1, this); + 
SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.sameDiff = sameDiff; sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); @@ -107,7 +106,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { super(sameDiff, inPlace, extraArgs); this.dimension = dimension; if (i_v != null) { - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); sameDiff.addArgsFor(new SDVariable[]{i_v},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java index 8b598242c..502874cc1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -46,7 +47,6 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum super(sameDiff,null); if (i_v != null) { this.dimensions = dimensions; - f().validateDifferentialFunctionsameDiff(i_v); sameDiff.addArgsFor(new SDVariable[]{i_v},this); this.xVertexId = i_v.name(); @@ -65,8 +65,8 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum super(sameDiff,null); if (i_v != null) { this.dimensions = dimensions; - f().validateDifferentialFunctionsameDiff(i_v); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); + 
SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.xVertexId = i_v.name(); this.yVertexId = i_v2.name(); sameDiff.addArgsFor(new SDVariable[]{i_v,i_v2},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 9e5b8f67b..66c3e95d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -23,6 +23,7 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -59,7 +60,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { dimensions = new int[] {Integer.MAX_VALUE}; this.dimensions = dimensions; - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); this.keepDims = keepDims; this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); @@ -83,8 +84,8 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { this.xVertexId = i_v.name(); this.yVertexId = i_v2.name(); - f().validateDifferentialFunctionsameDiff(i_v); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.keepDims = keepDims; sameDiff.addArgsFor(new String[]{xVertexId,yVertexId},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java index 8cb7e50b4..858b6a81c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -73,7 +74,7 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp { if (i_v != null) { this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); } else { throw new IllegalArgumentException("Input not null variable."); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java index 254069929..66a204602 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java @@ -21,6 +21,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -94,7 +95,7 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp { 
this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar); this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java index 4e498edeb..7f8e0487e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -52,8 +53,8 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { boolean inPlace) { super(sameDiff,inPlace,new Object[] {i_v2}); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v1, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.sameDiff = sameDiff; this.inPlace = inPlace; this.xVertexId = i_v1.name(); @@ -77,8 +78,8 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { super(sameDiff,extraArgs); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v1, this); + 
SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.sameDiff = sameDiff; this.xVertexId = i_v1.name(); this.yVertexId = i_v2.name(); @@ -104,7 +105,7 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { super(sameDiff,inPlace,extraArgs); if (i_v != null) { - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); this.xVertexId = i_v.name(); sameDiff.addArgsFor(new SDVariable[]{i_v},this); } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java index b4cf2d05a..692571df9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java @@ -20,6 +20,7 @@ import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; @@ -38,6 +39,10 @@ public class NoOp extends DynamicCustomOp { super("noop", sd, new SDVariable[]{in}); } + public NoOp(INDArray in) { + addInputArgument(in); + } + @Override public List doDiff(List f1) { return Collections.singletonList(f1.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java index f1b7f7398..cea3b388a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java @@ -85,7 +85,7 @@ public class BiasAdd extends DynamicCustomOp { @Override public List doDiff(List gradient){ - return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0), nchw)); + return new BiasAddGrad(sameDiff, arg(0), arg(1), gradient.get(0), nchw).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java index 1bb451bf1..c1aff757d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java @@ -25,6 +25,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; @@ -41,6 +42,10 @@ public abstract class BaseCompatOp extends DynamicCustomOp { super(null, sameDiff, inputs); } + public BaseCompatOp(INDArray... 
inputs) { + addInputArgument(inputs); + } + public BaseCompatOp(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java index 993a5b11e..9adbd78df 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op.Type; import org.tensorflow.framework.AttrValue; @@ -36,6 +37,10 @@ public class Merge extends BaseCompatOp { super(sd, inputs); } + public Merge(INDArray... 
inputs) { + super(inputs); + } + public Merge(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java index c7d90e4c8..f302c752a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java @@ -50,6 +50,6 @@ public class StopGradient extends BaseDynamicTransformOp { @Override public List doDiff(List gradients){ - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java index 1b6c2f5e2..a7804f39f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op.Type; import org.tensorflow.framework.AttrValue; @@ -44,6 +45,10 @@ public class Switch extends BaseCompatOp { this.predicate = predicate; } + public Switch(INDArray input, INDArray predicate) { + addInputArgument(input, predicate); + } + public Switch(){ } /** diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java index b9c3962aa..181321d4f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java @@ -74,6 +74,6 @@ public class IAMax extends BaseIndexAccumulation { @Override public List doDiff(List grad){ - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java index 63f40ee6c..760fca314 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java @@ -76,6 +76,6 @@ public class IAMin extends BaseIndexAccumulation { @Override public List doDiff(List grad){ - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java index 8b7872b49..c01be78f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java @@ -83,6 +83,6 @@ public class IMax extends 
BaseIndexAccumulation { @Override public List doDiff(List f1) { //Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java index 06b3deb1c..e668f1ee0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java @@ -77,6 +77,6 @@ public class IMin extends BaseIndexAccumulation { @Override public List doDiff(List f1) { //Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 9635c6f36..5417b14cd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -60,7 +60,6 @@ public class Conv2D extends DynamicCustomOp { SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); } - @Builder(builderMethodName = "sameDiffBuilder") public Conv2D(SameDiff sameDiff, SDVariable[] inputFunctions, @@ -71,7 +70,7 @@ 
public class Conv2D extends DynamicCustomOp { } public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ - super(inputs, outputs); + super(inputs, outputs); initConfig(config); } @@ -103,7 +102,8 @@ public class Conv2D extends DynamicCustomOp { config.getDH(), config.getDW(), ArrayUtil.fromBoolean(config.isSameMode()), - config.getDataFormat().equalsIgnoreCase(Conv2DConfig.NCHW) ? 0 : 1); + config.getDataFormat().equalsIgnoreCase("NCHW") ? 0 : 1, + config.getWeightsFormat().ordinal()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index 436659443..d0b04b36a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -161,8 +161,7 @@ public class DeConv3D extends DynamicCustomOp { @Override public List doDiff(List f1) { SDVariable bias = args().length > 2 ? 
arg(2) : null; - SDVariable[] outVars = f().deconv3dDerivative(arg(0), arg(1), bias, f1.get(0), config); - return Arrays.asList(outVars); + return new DeConv3DDerivative(sameDiff, arg(0), arg(1), bias, f1.get(0), config).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java index 46f5f4e79..86bfbacc9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java @@ -90,7 +90,7 @@ public class Im2col extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().im2ColBp(arg(), grad.get(0), conv2DConfig)); + return new Im2colBp(sameDiff, arg(), grad.get(0), conv2DConfig).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java index df345a2f3..3370b6f30 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java @@ -99,7 +99,7 @@ public class Upsampling2d extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().upsampling2dBp(arg(), f1.get(0), nchw, scaleH, scaleW)); + return new Upsampling2dDerivative(sameDiff, arg(), f1.get(0), nchw, scaleH, scaleW).outputs(); } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java index 40a2a3908..92701a696 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java @@ -22,6 +22,7 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.nd4j.base.Preconditions; +import org.nd4j.enums.WeightsFormat; import org.nd4j.linalg.util.ConvConfigUtil; @Data @@ -50,9 +51,11 @@ public class Conv2DConfig extends BaseConvolutionConfig { private boolean isSameMode; @Builder.Default private String dataFormat = NCHW; + @Builder.Default + private WeightsFormat weightsFormat = WeightsFormat.YXIO; public Conv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode, - String dataFormat) { + String dataFormat, WeightsFormat weightsFormat) { this.kH = kH; this.kW = kW; @@ -64,6 +67,7 @@ public class Conv2DConfig extends BaseConvolutionConfig { this.dW = dW; this.isSameMode = isSameMode; this.dataFormat = dataFormat; + this.weightsFormat = weightsFormat; validate(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java index adc59e4e0..4f6539eee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java @@ -22,6 +22,7 
@@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; import java.util.Arrays; import java.util.List; @@ -58,7 +59,6 @@ public class AbsoluteDifferenceLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossAbsoluteDifferenceBP(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java index 7faa5f6b0..432910391 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp; import java.util.Arrays; import java.util.List; @@ -61,8 +62,7 @@ public class CosineDistanceLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient. 
//Args are: predictions, weights, label - SDVariable[] grads = f().lossCosineDistanceBp(arg(2), arg(0), arg(1), lossReduce, dimension); - return Arrays.asList(grads); + return new CosineDistanceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), dimension).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java index 5d85e4933..d021623d5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp; import java.util.Arrays; import java.util.List; @@ -56,8 +57,7 @@ public class HingeLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossHingeBp(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new HingeLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java index f08d90566..acb74c04c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import 
org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp; import java.util.Arrays; import java.util.List; @@ -63,8 +64,7 @@ public class HuberLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossHuberBp(arg(2), arg(0), arg(1), lossReduce, delta); - return Arrays.asList(grads); + return new HuberLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), delta).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java index e1fe56e5f..d36d36c2f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java @@ -62,6 +62,6 @@ public class L2Loss extends DynamicCustomOp { public List doDiff(List grad){ //L2 loss: L = 1/2 * sum(x_i^2) //dL/dxi = xi - return Collections.singletonList(f().identity(arg())); + return Collections.singletonList(sameDiff.identity(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java index a7a15f1b5..c13634ee1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; 
+import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp; import java.util.Arrays; import java.util.List; @@ -64,8 +65,7 @@ public class LogLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossLogBp(arg(2), arg(0), arg(1), lossReduce, epsilon); - return Arrays.asList(grads); + return new LogLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), epsilon).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java index a893e3f4a..2ec6e54b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp; import java.util.Arrays; import java.util.List; @@ -73,14 +74,7 @@ public class LogPoissonLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - - SDVariable[] grads; - if(full) { - grads = f().lossLogPoissonFullBp(arg(2), arg(0), arg(1), lossReduce); - }else{ - grads = f().lossLogPoissonBp(arg(2), arg(0), arg(1), lossReduce); - } - return Arrays.asList(grads); + return new LogPoissonLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), full).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java index 6c3c5d01b..676eec5e0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp; import java.util.Arrays; import java.util.List; @@ -54,7 +55,6 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossMeanPairwiseSquaredErrorBp(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new MeanPairwiseSquaredErrorLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java index a9cf27584..c40d9e432 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp; import java.util.Arrays; import java.util.List; @@ 
-56,8 +57,7 @@ public class MeanSquaredErrorLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossMeanSquaredErrorBp(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new MeanSquaredErrorLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java index 214380a8c..862b405d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java @@ -27,6 +27,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -80,7 +81,6 @@ public class SigmoidCrossEntropyLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossSigmoidCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing); - return Arrays.asList(grads); + return new SigmoidCrossEntropyLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), labelSmoothing).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java 
index 57576b78f..e97427e92 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java @@ -25,6 +25,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -99,7 +100,6 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossSoftmaxCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing); - return Arrays.asList(grads); + return new SoftmaxCrossEntropyLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), labelSmoothing).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java index 3ef7de264..defb8292b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp; import 
org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; @@ -73,8 +74,6 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp { public List doDiff(List grad){ //No external gradient //Args: logits, weigths, label - SDVariable[] args = args(); - SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(0), arg(1), classesDim); - return Arrays.asList(grads); + return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff, arg(0), arg(1), classesDim).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java index a0f3288a9..c58933134 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -96,7 +97,8 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp { @Override public List doDiff(List grad){ //args: label, logits - SDVariable[] ret = f().lossSparseSoftmaxCrossEntropyBp(arg(1), arg(0)); - return Arrays.asList(f().zerosLike(arg(0)), ret[0]); + SDVariable labelsGrad = sameDiff.zerosLike(arg(0)); + SDVariable logitsGrad = new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff, arg(1), arg(0)).outputVariable(); + return Arrays.asList(labelsGrad, 
logitsGrad); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index 46310893d..479c794c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -266,7 +266,7 @@ public class Mmul extends DynamicCustomOp { @Override public List doDiff(List gradients) { - return sameDiff.f().mmulBp(larg(),rarg(), gradients.get(0), mt); + return Arrays.asList(new MmulBp(sameDiff, larg(), rarg(), gradients.get(0), mt).outputVariables()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java index bf3cb4af1..89ba1549b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java @@ -25,6 +25,8 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -83,8 +85,8 @@ public class Moments extends DynamicCustomOp { public List doDiff(List grad){ SDVariable dLdMean = grad.get(0); SDVariable dLdVar = grad.get(1); //Note: non-bias-corrected variance - SDVariable meanBp = f().meanBp(arg(), dLdMean, false, axes); - SDVariable varBp = 
f().varianceBp(arg(), dLdVar, false, false, axes); + SDVariable meanBp = new MeanBp(sameDiff, arg(), dLdMean, false, axes).outputVariable(); + SDVariable varBp = new VarianceBp(sameDiff, arg(), dLdVar, false, false, axes).outputVariable(); return Collections.singletonList(meanBp.add(varBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index c613f107f..f58347492 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -138,13 +138,13 @@ public class TensorMmul extends DynamicCustomOp { //tensor matrix multiply gradient wrt second variable int[] firstPerm = argsort(combine(deletedAxes[0],keep(argsort(sumAxes[1]),sumAxes[0]))); SDVariable firstResult = doTensorMmul(i_v1.get(0), rarg(), firstAxes); - SDVariable permuted = f().permute(firstResult,firstPerm); + SDVariable permuted = sameDiff.permute(firstResult,firstPerm); ret.add(permuted); //tensor matrix multiply gradient wrt first variable int[] secondPerm = argsort(combine(keep(argsort(sumAxes[0]),sumAxes[1]),deletedAxes[1])); SDVariable secondResult = doTensorMmul(i_v1.get(0), larg(), secondAxes); - SDVariable secondPermuted = f().permute(secondResult,secondPerm); + SDVariable secondPermuted = sameDiff.permute(secondResult,secondPerm); ret.add(secondPermuted); return ret; } @@ -210,7 +210,7 @@ public class TensorMmul extends DynamicCustomOp { } - int[] newShapeB = {n3, -1}; + long[] newShapeB = {n3, -1}; long[] oldShapeB; if (listB.size() == 0) { oldShapeB = new long[] {1}; @@ -221,16 +221,12 @@ public class TensorMmul extends DynamicCustomOp { } - SDVariable at = f() - .reshape(f().permute - (a,newAxesA),newShapeA); - SDVariable bt = f() - 
.reshape(f() - .permute(b,newAxesB),newShapeB); + SDVariable at = sameDiff.reshape(sameDiff.permute(a,newAxesA),newShapeA); + SDVariable bt = sameDiff.reshape(sameDiff.permute(b,newAxesB),newShapeB); - SDVariable ret = f().mmul(at,bt); + SDVariable ret = sameDiff.mmul(at,bt); long[] aPlusB = Longs.concat(oldShapeA, oldShapeB); - return f().reshape(ret, aPlusB); + return sameDiff.reshape(ret, aPlusB); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java index a465728d1..8aa12d4d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java @@ -57,7 +57,7 @@ public class All extends BaseReduceBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java index d4522ca69..4d26e5b70 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java @@ -57,7 +57,7 @@ public class Any extends BaseReduceBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java index cb93a832e..5dfc23f8e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java @@ -68,7 +68,7 @@ public class IsInf extends BaseReduceBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java index c8cd72f2c..a78ae8bd5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java @@ -68,7 +68,7 @@ public class IsNaN extends BaseReduceBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java index 26eabf0ff..edc3298b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -82,10 +83,10 @@ public class LogSumExp extends DynamicCustomOp { //z = log(sum_i exp(x_i)) = log(s) //dL/dx = dL/dz * dz/ds * ds/dx //dz/ds = 1/s - SDVariable exp = f().exp(arg()); + SDVariable exp = sameDiff.math.exp(arg()); SDVariable sumExp = exp.sum(dimensions); SDVariable gradProd = f1.get(0).div(sumExp); - SDVariable dSumExpdx = f().sumBp(arg(), gradProd, keepDims, dimensions).mul(exp); + SDVariable dSumExpdx = new SumBp(sameDiff, arg(), gradProd, keepDims, dimensions).outputVariable().mul(exp); return Collections.singletonList(dSumExpdx); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java index 47cc728ab..e9481fa81 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; import java.util.Collections; import java.util.List; @@ -73,7 +74,7 @@ public class AMean extends BaseReduceFloatOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable meanBp = f().meanBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable meanBp = new MeanBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(meanBp)); } } 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java index bb0dd4997..913a573db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java @@ -16,12 +16,11 @@ package org.nd4j.linalg.api.ops.impl.reduce.floating; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -70,13 +69,13 @@ public class Entropy extends BaseReduceFloatOp { //Then we can do sumBp(z, -dL/dOut) //Note d/dx(x*log(x)) = log(x)+1 - return grad(f(), arg(), f1.get(0), dimensions); + return grad(sameDiff, arg(), f1.get(0), dimensions); } - public static List grad(DifferentialFunctionFactory f, SDVariable arg, SDVariable grad, int[] dimensions){ - SDVariable logx = f.log(arg); + public static List grad(SameDiff sd, SDVariable arg, SDVariable grad, int[] dimensions){ + SDVariable logx = sd.math.log(arg); SDVariable xLogX = arg.mul(logx); - SDVariable sumBp = f.sumBp(xLogX, grad.neg(), false, dimensions); + SDVariable sumBp = new SumBp(sd, xLogX, grad.neg(), false, dimensions).outputVariable(); return Collections.singletonList(sumBp.mul(logx.add(1.0))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java index 52970cc33..837d89c3a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java @@ -70,7 +70,7 @@ public class LogEntropy extends BaseReduceFloatOp { @Override public List doDiff(List f1) { //If y=log(x), and x=entropy(in) then dL/dx = dL/dy * dy/dx; d(log(x))/dx = 1/x - List entropyGrad = Entropy.grad(f(), arg(), f1.get(0), dimensions); - return Collections.singletonList(entropyGrad.get(0).div(f().exp(outputVariable()))); + List entropyGrad = Entropy.grad(sameDiff, arg(), f1.get(0), dimensions); + return Collections.singletonList(entropyGrad.get(0).div(sameDiff.math.exp(outputVariable()))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java index bf15f94d4..6309ccf28 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; import java.util.Collections; import java.util.List; @@ -67,7 +68,7 @@ public class Mean extends BaseReduceFloatOp { public List doDiff(List i_v1) { //If out = mean(in), then dL/dIn = 1/N * dL/dOut (broadcast to appropriate shape) //Note that N differs for "along dimension" vs. 
"whole array" reduce cases - return Collections.singletonList(f().meanBp(arg(), i_v1.get(0), keepDims, dimensions)); + return new MeanBp(sameDiff, arg(), i_v1.get(0), keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java index a2ba88927..96222d7c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; @@ -80,6 +81,6 @@ public class Norm1 extends BaseReduceFloatOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().norm1Bp(arg(), grad.get(0), keepDims, dimensions)); + return new Norm1Bp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java index f61c0dc43..be517f5e8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; @@ -72,7 +73,7 @@ public class Norm2 extends BaseReduceFloatOp { @Override public List doDiff(List grad) { //d norm2(in)/dx = x / norm2(in) - return Collections.singletonList(f().norm2Bp(arg(), grad.get(0), keepDims, dimensions)); + return new Norm2Bp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java index ece542857..a7cf398f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; @@ -77,7 +78,7 @@ public class NormMax extends BaseReduceFloatOp { public List doDiff(List grad) { //maxnorm(in) = max_i |x_i| //d maxnorm(in)/dx = 0 if x_i is not the max, or d|x|/dx otherwise - return Collections.singletonList(f().normmaxBp(arg(), grad.get(0), keepDims, dimensions)); + return new NormMaxBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java index 44504f855..963224ed8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -68,10 +69,10 @@ public class ShannonEntropy extends BaseReduceFloatOp { //Then we can do sumBp(z, -dL/dOut) //Note d/dx(x*log2(x)) = (log(x)+1)/log(2) - SDVariable log2x = f().log(arg(),2); - SDVariable logx = f().log(arg()); + SDVariable log2x = sameDiff.math.log(arg(),2); + SDVariable logx = sameDiff.math.log(arg()); SDVariable xLog2X = arg().mul(log2x); - SDVariable sumBp = f().sumBp(xLog2X, f1.get(0).neg(), false, dimensions); + SDVariable sumBp = new SumBp(sameDiff, xLog2X, f1.get(0).neg(), false, dimensions).outputVariable(); return Collections.singletonList(sumBp.mul(logx.add(1.0)).div(Math.log(2.0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java index b11fe5b1f..2af86c181 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import 
org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp; import java.util.Collections; import java.util.List; @@ -69,6 +70,6 @@ public class SquaredNorm extends BaseReduceFloatOp { @Override public List doDiff(List grad){ - return Collections.singletonList(f().squaredNormBp(arg(), grad.get(0), keepDims, dimensions)); + return new SquaredNormBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java index d27215a80..7376b0708 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java @@ -56,7 +56,7 @@ public class CountNonZero extends BaseReduceLongOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java index db13dfc85..27476cabc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java @@ -67,7 +67,7 @@ public class CountZero extends BaseReduceLongOp { @Override public List doDiff(List f1) { - return 
Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java index a6533441a..cadc77d4f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; import java.util.Collections; import java.util.List; @@ -65,7 +66,7 @@ public class AMax extends BaseReduceSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable maxBp = f().maxBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable maxBp = new MaxBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(maxBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java index 20a8be906..a01c9c1f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import 
org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -69,7 +70,7 @@ public class AMin extends BaseReduceSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable minBp = new MinBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(minBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java index 17d8a0bde..1a15c32ac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -72,7 +73,7 @@ public class ASum extends BaseReduceSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable meanBp = f().sumBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable meanBp = new SumBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(meanBp)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java index a29384a42..8c4563c95 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; import java.util.Collections; import java.util.List; @@ -79,7 +80,7 @@ public class Max extends BaseReduceSameOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().maxBp(arg(), grad.get(0), keepDims, dimensions)); + return new MaxBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java index 99c1e038b..1d644b671 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -77,6 +78,6 @@ public class Min extends BaseReduceSameOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().minBp(arg(), grad.get(0), keepDims, dimensions)); + return new MinBp(sameDiff, arg(), 
grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java index b5073d0f9..0247e3169 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.BaseNDArray; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp; import java.util.Collections; import java.util.List; @@ -82,6 +83,6 @@ public class Prod extends BaseReduceSameOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().prodBp(arg(), grad.get(0), keepDims, dimensions)); + return new ProdBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java index 859b89dac..e6fa79bb0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -76,7 +77,7 @@ public class Sum extends 
BaseReduceSameOp { // dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * 1 // But broadcast to shape of the input - return Collections.singletonList(f().sumBp(arg(), i_v1.get(0), keepDims, dimensions)); + return new SumBp(sameDiff, arg(), i_v1.get(0), keepDims, dimensions).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java index 44f1c49fc..a5aab468b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java @@ -84,7 +84,7 @@ public class CosineDistance extends BaseReduce3Op { //Cosine distance = 1 - cosine similarity //Therefore: just need to negate gradients from cosine similarity... - List diff = CosineSimilarity.doDiff(sameDiff, f(), larg(), rarg(), i_v1.get(0), keepDims, dimensions); - return Arrays.asList(f().neg(diff.get(0)), f().neg(diff.get(1))); + List diff = CosineSimilarity.doDiff(sameDiff, larg(), rarg(), i_v1.get(0), keepDims, dimensions); + return Arrays.asList(sameDiff.math.neg(diff.get(0)), sameDiff.math.neg(diff.get(1))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java index 27c14473d..b6edbe6fa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java @@ -16,9 +16,9 @@ package org.nd4j.linalg.api.ops.impl.reduce3; -import 
org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -93,14 +93,14 @@ public class CosineSimilarity extends BaseReduce3Op { //Then: // dc(x,y)/dx_i = 1/b * (y - x * a / (l2(x))^2) - return doDiff(sameDiff, f(), larg(), rarg(), i_v1.get(0), keepDims, dimensions); + return doDiff(sameDiff, larg(), rarg(), i_v1.get(0), keepDims, dimensions); } - public static List doDiff(SameDiff sameDiff, DifferentialFunctionFactory f, SDVariable x, SDVariable y, + public static List doDiff(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable gradOut, boolean keepDims, int... dimensions){ SDVariable a = sameDiff.sum(x.mul(y),true, dimensions); - SDVariable l2x = f.norm2(x, true, dimensions); - SDVariable l2y = f.norm2(y, true, dimensions); + SDVariable l2x = sameDiff.norm2(x, true, dimensions); + SDVariable l2y = sameDiff.norm2(y, true, dimensions); SDVariable b = l2x.mul(l2y); SDVariable l2xSq = sameDiff.math().square(l2x); @@ -110,7 +110,7 @@ public class CosineSimilarity extends BaseReduce3Op { //keepDims or full array reduction broadcastableGrad = gradOut; } else { - broadcastableGrad = sameDiff.f().reductionBroadcastableWithOrigShape(x, sameDiff.constant(Nd4j.createFromArray(dimensions)), gradOut); + broadcastableGrad = SameDiffUtils.reductionBroadcastableWithOrigShape(x, sameDiff.constant(Nd4j.createFromArray(dimensions)), gradOut); } SDVariable dcdx = y.sub(x.mul(a).div(l2xSq)).div(b); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java index 85f0b3e15..bdb172924 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp; import java.util.Arrays; import java.util.List; @@ -86,6 +87,6 @@ public class Dot extends BaseReduce3Op { @Override public List doDiff(List f1) { //TODO KEEP DIMS - return Arrays.asList(f().dotBp(arg(0), arg(1), f1.get(0), false, dimensions)); + return new DotBp(sameDiff, arg(0), arg(1), f1.get(0), false, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java index a25ba6d52..97ccd81e6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -89,11 +90,11 @@ public class EuclideanDistance extends BaseReduce3Op { SDVariable divBroadcastable = i_v1.get(0).div(euc); if(!keepDims && !(dimensions == null || dimensions.length == 0 || (dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE))){ //Not keep dims, and not full array reduction -> need to make broadcastable - divBroadcastable = f().reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), divBroadcastable); + 
divBroadcastable = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), divBroadcastable); } SDVariable gradX = difference.mul(divBroadcastable); - SDVariable gradY = f().neg(gradX); + SDVariable gradY = sameDiff.math.neg(gradX); return Arrays.asList(gradX, gradY); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java index 994003e78..c520a7c18 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -90,18 +91,18 @@ public class JaccardDistance extends BaseReduce3Op { //Jaccard distance: https://en.wikipedia.org/wiki/Jaccard_index#Generalized_Jaccard_similarity_and_distance //J(x,y) = 1 - sum_i min(x_i, y_i) / sum_i max(x_i, y_i) - SDVariable min = f().min(larg(), rarg()); - SDVariable max = f().max(larg(), rarg()); + SDVariable min = sameDiff.math.min(larg(), rarg()); + SDVariable max = sameDiff.math.max(larg(), rarg()); SDVariable sumMax = max.sum(true, dimensions); SDVariable sumMin = min.sum(true, dimensions); DataType d = arg().dataType(); - SDVariable xIsMin = f().eq(min, larg()).castTo(d); - SDVariable xIsMax = f().eq(max, larg()).castTo(d); - SDVariable yIsMin = f().eq(min, rarg()).castTo(d); - SDVariable yIsMax = f().eq(max, rarg()).castTo(d); + SDVariable xIsMin = sameDiff.eq(min, 
larg()).castTo(d); + SDVariable xIsMax = sameDiff.eq(max, larg()).castTo(d); + SDVariable yIsMin = sameDiff.eq(min, rarg()).castTo(d); + SDVariable yIsMax = sameDiff.eq(max, rarg()).castTo(d); - SDVariable sqSumMax = f().square(sumMax); + SDVariable sqSumMax = sameDiff.math.square(sumMax); SDVariable dldx = xIsMax.mul(sumMin).sub(xIsMin.mul(sumMax)).div(sqSumMax); SDVariable dldy = yIsMax.mul(sumMin).sub(yIsMin.mul(sumMax)).div(sqSumMax); @@ -110,7 +111,7 @@ public class JaccardDistance extends BaseReduce3Op { //KeepDims or full array reduction - already broadcastable bcGradOut = f1.get(0); } else { - bcGradOut = sameDiff.f().reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), f1.get(0)); + bcGradOut = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), f1.get(0)); } return Arrays.asList(dldx.mul(bcGradOut), dldy.mul(bcGradOut)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java index 0c007a261..9fdea3afb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -86,11 +87,11 @@ public class ManhattanDistance extends BaseReduce3Op { //keepDims or full array reduction gradBroadcastable = i_v1.get(0); } else { - gradBroadcastable = sameDiff.f().reductionBroadcastableWithOrigShape(arg(), 
sameDiff.constant(Nd4j.createFromArray(dimensions)), i_v1.get(0)); + gradBroadcastable = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), i_v1.get(0)); } SDVariable gradX = sameDiff.math().sign(difference).mul(gradBroadcastable); - SDVariable gradY = f().neg(gradX); + SDVariable gradY = sameDiff.math().neg(gradX); return Arrays.asList(gradX, gradY); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java index 000b0414c..44514ee1a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java @@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -108,7 +109,7 @@ public class LeakyReLU extends BaseScalarOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().leakyReluBp(arg(), i_v.get(0), alpha)); + return new LeakyReLUBp(sameDiff, arg(), i_v.get(0), alpha).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java index f9e30be9c..4b9b37026 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java @@ -29,6 +29,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp; /** * Parameterized ReLU op @@ -80,6 +81,6 @@ public class PRelu extends DynamicCustomOp { @Override public List doDiff(List i_v) { - return Arrays.asList(f().preluBp(arg(0), arg(1), i_v.get(0), sharedAxes)); + return new PReluBp(sameDiff, arg(0), arg(1), i_v.get(0), sharedAxes).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java index 5cfab3768..ec15ea537 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -87,7 +88,7 @@ public class Pow extends BaseScalarOp { @Override public List doDiff(List i_v1) { - SDVariable g = f().powDerivative(arg(), this.pow).mul(i_v1.get(0)); - return Arrays.asList(g); + SDVariable g = new PowDerivative(sameDiff, arg(), false, this.pow).outputVariable().mul(i_v1.get(0)); + return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java index 944d4d095..98df920bc 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; import java.util.Arrays; import java.util.Collections; @@ -81,6 +82,6 @@ public class RectifiedLinear extends BaseScalarOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().thresholdReluBp(arg(), i_v.get(0), scalarValue.getDouble(0))); + return new ThresholdReluBp(sameDiff, arg(), i_v.get(0), scalarValue.getDouble(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java index c80d3c8f9..9b11925c3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java @@ -23,6 +23,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -99,6 +100,6 @@ public class Relu6 extends BaseScalarOp { @Override public List doDiff(List i_v) { SDVariable dLdOut = i_v.get(0); - return Collections.singletonList(f().relu6Derivative(arg(), dLdOut, scalarValue.getDouble(0))); + 
return new Relu6Derivative(sameDiff, arg(), dLdOut, scalarValue.getDouble(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java index 3aa31771a..f8831c68e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scalar; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -49,9 +50,12 @@ public class ScalarAdd extends BaseScalarOp { this(arr, 0); } + public ScalarAdd(@NonNull SameDiff sameDiff, @NonNull SDVariable i_v, Number scalar) { + this(sameDiff, i_v, scalar, false); + } + public ScalarAdd(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) { super(sameDiff, i_v, scalar, inPlace); - } public ScalarAdd(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace, Object[] extraArgs) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java index 1fec8f808..463012875 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** 
@@ -73,8 +74,8 @@ public class ScalarReverseDivision extends BaseScalarOp { @Override public List doDiff(List i_v1) { - SDVariable g = f().rdiv(f().pow(arg(), 2), -scalarValue.getDouble(0)).mul(i_v1.get(0)); - return Arrays.asList(g); + SDVariable g = sameDiff.math.rdiv(sameDiff.math.pow(arg(), 2), -scalarValue.getDouble(0)).mul(i_v1.get(0)); + return Collections.singletonList(g); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java index d362620e4..972f4ec10 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -79,8 +80,8 @@ public class ScalarReverseSubtraction extends BaseScalarOp { @Override public List doDiff(List i_v1) { - SDVariable g = f().neg(i_v1.get(0)); - return Arrays.asList(g); + SDVariable g = sameDiff.math.neg(i_v1.get(0)); + return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java index a8ec8c7f3..d3a8c7f67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java @@ -76,7 +76,7 @@ public class ScalarSet extends BaseScalarOp { @Override public List 
doDiff(List i_v1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java index 65f653d64..04bd39622 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java @@ -96,6 +96,6 @@ public class Step extends BaseScalarOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java index 160556867..1846ab8f8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java @@ -88,9 +88,9 @@ public class ScatterAdd extends DynamicCustomOp { List ret = new ArrayList<>(3); ret.add(gradOut.get(0)); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gather = f().gather(gradOut.get(0), arg(1), 0); //Updates + SDVariable gather = sameDiff.gather(gradOut.get(0), arg(1), 0); //Updates ret.add(gather); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java 
index 5d6b60c88..75badc9c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java @@ -77,13 +77,13 @@ public class ScatterDiv extends DynamicCustomOp { SDVariable updates = arg(2); List ret = new ArrayList<>(3); - SDVariable gradRef = f().scatterDiv(gradOut.get(0), indices, updates); + SDVariable gradRef = sameDiff.scatterDiv(gradOut.get(0), indices, updates); ret.add(gradRef); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gatherOutGrad = f().gather(gradOut.get(0), indices, 0); //Updates - SDVariable gatherRef = f().gather(ref, indices, 0); - SDVariable updateGrad = gatherOutGrad.mul(gatherRef).div(f().square(updates)).neg(); + SDVariable gatherOutGrad = sameDiff.gather(gradOut.get(0), indices, 0); //Updates + SDVariable gatherRef = sameDiff.gather(ref, indices, 0); + SDVariable updateGrad = gatherOutGrad.mul(gatherRef).div(sameDiff.math.square(updates)).neg(); ret.add(updateGrad); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java index 7f814d928..2dead9742 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java @@ -87,12 +87,12 @@ public class ScatterMax extends DynamicCustomOp { SDVariable notModified = arg(0).eq(outputVariable()).castTo(arg(0).dataType()); //0 if modified, 1 otherwise SDVariable refGrad = gradOut.get(0).mul(notModified); - SDVariable gatherOut = f().gather(outputVariable(), arg(1), 0); - SDVariable gatherGrad 
= f().gather(gradOut.get(0), arg(1), 0); + SDVariable gatherOut = sameDiff.gather(outputVariable(), arg(1), 0); + SDVariable gatherGrad = sameDiff.gather(gradOut.get(0), arg(1), 0); SDVariable outIsUpdate = gatherOut.eq(arg(2)).castTo(arg(2).dataType()); SDVariable updateGrad = gatherGrad.mul(outIsUpdate); - return Arrays.asList(refGrad, f().zerosLike(arg(1)), updateGrad); + return Arrays.asList(refGrad, sameDiff.zerosLike(arg(1)), updateGrad); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java index 2539a3d56..4af8a2cd2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java @@ -88,12 +88,12 @@ public class ScatterMin extends DynamicCustomOp { SDVariable notModified = arg(0).eq(outputVariable()).castTo(arg(0).dataType()); //0 if modified, 1 otherwise SDVariable refGrad = gradOut.get(0).mul(notModified); - SDVariable gatherOut = f().gather(outputVariable(), arg(1), 0); - SDVariable gatherGrad = f().gather(gradOut.get(0), arg(1), 0); + SDVariable gatherOut = sameDiff.gather(outputVariable(), arg(1), 0); + SDVariable gatherGrad = sameDiff.gather(gradOut.get(0), arg(1), 0); SDVariable outIsUpdate = gatherOut.eq(arg(2)).castTo(arg(2).dataType()); SDVariable updateGrad = gatherGrad.mul(outIsUpdate); - return Arrays.asList(refGrad, f().zerosLike(arg(1)), updateGrad); + return Arrays.asList(refGrad, sameDiff.zerosLike(arg(1)), updateGrad); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java index 
411c59188..48e1a00bd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java @@ -91,12 +91,12 @@ public class ScatterMul extends DynamicCustomOp { SDVariable updates = arg(2); List ret = new ArrayList<>(3); - SDVariable gradRef = f().scatterMul(gradOut.get(0), indices, updates); + SDVariable gradRef = sameDiff.scatterMul(gradOut.get(0), indices, updates); ret.add(gradRef); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gatherOutGrad = f().gather(gradOut.get(0), indices, 0); //Updates - SDVariable gatherRef = f().gather(ref, indices, 0); + SDVariable gatherOutGrad = sameDiff.gather(gradOut.get(0), indices, 0); //Updates + SDVariable gatherRef = sameDiff.gather(ref, indices, 0); SDVariable updateGrad = gatherOutGrad.mul(gatherRef); ret.add(updateGrad); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java index 83c4cc222..f66f7d689 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java @@ -74,9 +74,9 @@ public class ScatterSub extends DynamicCustomOp { List ret = new ArrayList<>(3); ret.add(gradOut.get(0)); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gather = f().gather(gradOut.get(0), arg(1), 0); //Updates + SDVariable gather = sameDiff.gather(gradOut.get(0), arg(1), 0); //Updates ret.add(gather.neg()); return ret; diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java index 93e1e5995..c5644faa5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java @@ -98,12 +98,12 @@ public class ScatterUpdate extends DynamicCustomOp { SDVariable updates = arg(2); List ret = new ArrayList<>(3); - SDVariable zerosUpdate = f().zerosLike(updates); - SDVariable gradRef = f().scatterMul(gradOut.get(0), indices, zerosUpdate); //TODO optimize + SDVariable zerosUpdate = sameDiff.zerosLike(updates); + SDVariable gradRef = sameDiff.scatterMul(gradOut.get(0), indices, zerosUpdate); //TODO optimize ret.add(gradRef); //Reference array gradient - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gather = f().gather(gradOut.get(0), indices, 0); //Updates + SDVariable gather = sameDiff.gather(gradOut.get(0), indices, 0); //Updates ret.add(gather); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index a9f964844..ede375163 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -117,6 +117,6 @@ public class Linspace extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().zerosLike(arg(0)), f().zerosLike(arg(1)), f().zerosLike(arg(2))); + return Arrays.asList(sameDiff.zerosLike(arg(0)), 
sameDiff.zerosLike(arg(1)), sameDiff.zerosLike(arg(2))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java index cfd0bd7ed..870a72a2b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java @@ -77,10 +77,10 @@ public class Permute extends Transpose { SDVariable ret; if(args().length == 1) { //Static dimensions - ret = f().permute(i_v.get(0), reverseDims); + ret = sameDiff.permute(i_v.get(0), reverseDims); } else { //Dynamic dimensions - ret = f().permute(i_v.get(0), sameDiff.invertPermutation(arg(1))); + ret = sameDiff.permute(i_v.get(0), sameDiff.invertPermutation(arg(1))); } return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 2126dfe27..6c5c0f9d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -152,8 +152,8 @@ public class Reshape extends DynamicCustomOp { @Override public List doDiff(List i_v) { - SDVariable origShape = f().shape(arg()); - SDVariable ret = f().reshape(i_v.get(0), origShape); + SDVariable origShape = sameDiff.shape(arg()); + SDVariable ret = sameDiff.reshape(i_v.get(0), origShape); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index 3c3baf1f6..94a6e6c2b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -120,7 +120,7 @@ public class SequenceMask extends DynamicCustomOp { @Override public List doDiff(List grad){ //Input is integer indices - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java index 593830327..d9c4c4578 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java @@ -65,7 +65,7 @@ public class ShapeN extends DynamicCustomOp { public List doDiff(List i_v) { List out = new ArrayList<>(); for(SDVariable in : args()){ - out.add(f().zerosLike(in)); + out.add(sameDiff.zerosLike(in)); } return out; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java index 46b8f6286..b9bcff540 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java @@ -25,6 +25,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; import java.util.*; @@ -82,10 +83,10 @@ public class Slice extends DynamicCustomOp { @Override public List doDiff(List grad) { if(args().length == 1) { - return Collections.singletonList(f().sliceBp(arg(), grad.get(0), begin, size)); + return new SliceBp(sameDiff, arg(), grad.get(0), begin, size).outputs(); } else { //Dynamic begin/size - return Collections.singletonList(f().sliceBp(arg(0), grad.get(0), arg(1), arg(2))); + return new SliceBp(sameDiff, arg(0), grad.get(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index 17a8beb3c..28e92930c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -129,7 +129,7 @@ public class Stack extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Arrays.asList(f().unstack(f1.get(0), jaxis, args().length)); + return Arrays.asList(sameDiff.unstack(f1.get(0), jaxis, args().length)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java index 456edfe1c..33c79e217 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java @@ -27,6 +27,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping; import 
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; @@ -259,12 +260,12 @@ public class StridedSlice extends DynamicCustomOp { public List doDiff(List i_v) { if(args().length == 1) { //Array inputs for begin/end/strides - return Collections.singletonList(f().stridedSliceBp(arg(), i_v.get(0), begin, end, strides, beginMask, endMask, - ellipsisMask, newAxisMask, shrinkAxisMask)); + return new StridedSliceBp(sameDiff, arg(), i_v.get(0), begin, end, strides, beginMask, endMask, + ellipsisMask, newAxisMask, shrinkAxisMask).outputs(); } else { //SDVariable inputs for begin/end/strides - return Collections.singletonList(f().stridedSliceBp(arg(), i_v.get(0), arg(1), arg(2), arg(3), beginMask, endMask, - ellipsisMask, newAxisMask, shrinkAxisMask)); + return new StridedSliceBp(sameDiff, arg(), i_v.get(0), arg(1), arg(2), arg(3), beginMask, endMask, + ellipsisMask, newAxisMask, shrinkAxisMask).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java index c2e476f60..e90e31427 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java @@ -24,6 +24,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; import org.tensorflow.framework.AttrValue; import 
org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -126,9 +127,9 @@ public class Tile extends DynamicCustomOp { @Override public List doDiff(List i_v) { if(jaxis != null){ - return Collections.singletonList(f().tileBp(arg(), i_v.get(0), jaxis)); + return new TileBp(sameDiff, arg(), i_v.get(0), jaxis).outputs(); }else{ - return Collections.singletonList(f().tileBp(arg(0), arg(1), i_v.get(0))); + return new TileBp(sameDiff, arg(0), arg(1), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java index 211fec834..2b6359985 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -101,7 +102,7 @@ public class StandardDeviation extends Variance { //If out = stdev(in) then: //dL/dIn = dL/dOut * dOut/dIn //dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) - return Collections.singletonList(f().stdBp(arg(), grad.get(0), biasCorrected, keepDims, dimensions)); + return new StandardDeviationBp(sameDiff, arg(), grad.get(0), biasCorrected, keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index adc92549b..64948880c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceOp; import org.nd4j.linalg.api.ops.OpContext; +import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -115,7 +116,7 @@ public class Variance extends BaseReduceOp { //If out = var(in) then: //dL/dIn = dL/dOut * dOut/dIn // with dOut/dIn = (in-mean) * 2/(n-1) - return Collections.singletonList(f().varianceBp(arg(), grad.get(0), biasCorrected, keepDims, dimensions)); + return new VarianceBp(sameDiff, arg(), grad.get(0), biasCorrected, keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java index 66eeb9b99..7fed19e02 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java @@ -44,6 +44,6 @@ public class Angle extends DynamicCustomOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index b7bd0e0f6..ef8283be8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -46,6 +47,24 @@ public class Pad extends DynamicCustomOp { public Pad(){ } + private static Mode adaptMode(PadMode mode) { + Mode legacyMode = Mode.CONSTANT; + + if (mode == PadMode.CONSTANT) { + legacyMode = Mode.CONSTANT; + } + else if (mode == PadMode.REFLECT) { + legacyMode = Mode.REFLECT; + } + else if (mode == PadMode.SYMMETRIC) { + legacyMode = Mode.SYMMETRIC; + } + return legacyMode; + } + + public Pad(SameDiff sd, SDVariable in, SDVariable padding, PadMode mode, double padValue) { + this(sd, in, padding, adaptMode(mode), padValue); + } public Pad(SameDiff sd, SDVariable in, SDVariable padding, Mode mode, double padValue) { super(sd, new SDVariable[]{in, padding}, false); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); @@ -62,6 +81,10 @@ public class Pad extends DynamicCustomOp { this(in, padding, null, Mode.CONSTANT, padValue); } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, @NonNull PadMode mode, double padValue) { + this(in, padding, null, adaptMode(mode), padValue); + } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, 
INDArray out, @NonNull Mode mode, double padValue){ super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out}); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); @@ -70,6 +93,10 @@ public class Pad extends DynamicCustomOp { addTArgument(padValue); } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull PadMode mode, double padValue) { + this(in, padding, out, adaptMode(mode), padValue); + } + @Override public String opName(){ return "pad"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java index 218ec66db..54f91f6d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java @@ -73,7 +73,7 @@ public class IsMax extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java index 4dd948b4f..053199731 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java @@ -75,6 +75,6 @@ public class BooleanNot extends BaseTransformBoolOp { @Override public List doDiff(List f1) { - return 
Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java index 8df844943..8cf81febf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java @@ -73,7 +73,7 @@ public class IsFinite extends BaseTransformBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java index 44cb362a4..95d75e9be 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java @@ -73,7 +73,7 @@ public class IsInf extends BaseTransformBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java index daf9b0ea3..9f8e9ea74 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java @@ -74,7 +74,7 @@ public class IsNaN extends BaseTransformBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java index fadd8720e..ef1ebb38a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java @@ -73,7 +73,7 @@ public class ClipByNorm extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(new ClipByNormBp(f().sameDiff(), arg(), grad.get(0), clipValue, dimensions).outputVariable()); + return new ClipByNormBp(sameDiff, arg(), grad.get(0), clipValue, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java index fa465b251..44cde0abb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java @@ -83,8 +83,8 @@ public class ClipByValue extends DynamicCustomOp { @Override public List doDiff(List grad) { 
//dOut/dIn is 0 if clipped, 1 otherwise - SDVariable notClippedLower = f().gt(arg(), clipValueMin).castTo(arg().dataType()); - SDVariable notClippedUpper = f().lt(arg(), clipValueMax).castTo(arg().dataType()); + SDVariable notClippedLower = sameDiff.gt(arg(), clipValueMin).castTo(arg().dataType()); + SDVariable notClippedUpper = sameDiff.lt(arg(), clipValueMax).castTo(arg().dataType()); SDVariable ret = notClippedLower.mul(notClippedUpper).mul(grad.get(0)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java index d6230e153..e847142a6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java @@ -85,17 +85,7 @@ public class ATan2 extends BaseDynamicTransformOp { SDVariable y = larg(); SDVariable x = rarg(); -/* SDVariable r = y.div(x); - - SDVariable dOutdr = f().square(r).add(1.0).rdiv(1.0); - SDVariable drdy = x.rdiv(1.0); - SDVariable drdx = f().neg(y).div(f().square(x)); - - SDVariable xGrad = dOutdr.mul(drdx).mul(i_v.get(0)); - SDVariable yGrad = dOutdr.mul(drdy).mul(i_v.get(0)); -*/ - - val xGrad = f().neg(y.div(x.pow(2).add(y.pow(2)))).mul(i_v.get(0)); + val xGrad = sameDiff.math.neg(y.div(x.pow(2).add(y.pow(2)))).mul(i_v.get(0)); val yGrad = x.div(x.pow(2).add(y.pow(2))).mul(i_v.get(0)); return Arrays.asList(yGrad, xGrad); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java index 35c209870..ca466ae34 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java @@ -89,7 +89,7 @@ public class Assign extends DynamicCustomOp { @Override public List doDiff(List f1){ //TODO replace with assign backprop op from libnd4j (that handles the broadcast case properly) - return Arrays.asList(f().zerosLike(larg()), f1.get(0)); + return Arrays.asList(sameDiff.zerosLike(larg()), f1.get(0)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java index 0be0b08ad..34e3e5f1d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java @@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -142,7 +143,7 @@ public class CumProd extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().cumprodBp(arg(0), grad.get(0), exclusive, reverse, jaxis)); + return new CumProdBp(sameDiff, arg(0), grad.get(0), exclusive, reverse, jaxis).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java index c24693b01..97c53f4e2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java @@ -29,6 +29,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -142,7 +143,7 @@ public class CumSum extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().cumsumBp(arg(0), grad.get(0), exclusive, reverse, jaxis)); + return new CumSumBp(sameDiff, arg(0), grad.get(0), exclusive, reverse, jaxis).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java index d3a5c9676..300f8277a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java @@ -70,7 +70,8 @@ public class DotProductAttention extends DynamicCustomOp { @Override public List doDiff(List gradient) { - return sameDiff.f().dotProductAttentionBp(arg(0), arg(1), arg(2), gradient.get(0), args().length > 3 ? arg(3) : null, scaled); + SDVariable mask = args().length == 4 ? 
arg(3) : null; + return Arrays.asList(new DotProductAttentionBp(sameDiff, arg(0), arg(1), arg(2), gradient.get(0), mask, scaled).outputVariables()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index 3efc13af0..718120bf7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -24,6 +24,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -74,7 +75,7 @@ public class DynamicPartition extends DynamicCustomOp { @Override public List doDiff(List i_v) { - return Arrays.asList(f().dynamicPartitionBp(arg(0), arg(1), i_v.toArray(new SDVariable[i_v.size()]), numPartitions)); + return new DynamicPartitionBp(sameDiff, arg(0), arg(1), i_v.toArray(new SDVariable[i_v.size()]), numPartitions).outputs(); } protected void addArgs() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java index 94c34d108..60b2bf942 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java @@ -83,7 +83,7 @@ public class DynamicStitch extends DynamicCustomOp { SDVariable[] partition = sameDiff.dynamicPartition(gradient, partitions, numPartitions); List ret = new ArrayList<>(); for (SDVariable i : indices) - ret.add(f().zerosLike(i)); + ret.add(sameDiff.zerosLike(i)); Collections.addAll(ret, partition); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java index 6048c9dff..9612c4dea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -67,8 +68,8 @@ public class InvertPermutation extends BaseDynamicTransformOp { @Override public List doDiff(List grad) { SDVariable gradient = grad.get(0); - SDVariable invertedGradient = f().invertPermutation(gradient, false); - return Arrays.asList(invertedGradient); + SDVariable invertedGradient = sameDiff.invertPermutation(gradient); + return Collections.singletonList(invertedGradient); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java index 0c4990bb2..f16a92318 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java @@ -100,13 +100,11 @@ public class LayerNorm extends DynamicCustomOp { @Override public List doDiff(List gradient) { - SDVariable[] ret; - if(noBias){ - ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), channelsFirst, dimensions); - }else{ - ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions); + if (noBias) { + return new LayerNormBp(sameDiff, arg(0), arg(1), gradient.get(0), channelsFirst, dimensions).outputs(); + } else { + return new LayerNormBp(sameDiff, arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions).outputs(); } - return Arrays.asList(ret); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java index 86c9d9c0a..2de57451f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative; import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; @@ -76,11 +77,9 @@ public class LogSoftMax extends DynamicCustomOp { @Override public List doDiff(List i_v) { if(dimension == null) { - SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0)); - return Collections.singletonList(ret); + 
return new LogSoftMaxDerivative(sameDiff, arg(), i_v.get(0)).outputs(); } else { - SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0), dimension); - return Collections.singletonList(ret); + return new LogSoftMaxDerivative(sameDiff, arg(), i_v.get(0), dimension).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java index 19d139cbb..9f4b97576 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java @@ -57,8 +57,8 @@ public class MatrixSetDiag extends DynamicCustomOp { @Override public List doDiff(List i_v) { SDVariable grad = i_v.get(0); - SDVariable in1Grad = f().setDiag(grad, sameDiff.zerosLike(arg(1))); - SDVariable in2Grad = f().diagPart(grad); + SDVariable in1Grad = sameDiff.math.setDiag(grad, sameDiff.zerosLike(arg(1))); + SDVariable in2Grad = sameDiff.math.diagPart(grad); return Arrays.asList(in1Grad, in2Grad); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java index 54167bd8b..98765ed96 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java @@ -79,7 +79,7 @@ public class MultiHeadDotProductAttention extends DynamicCustomOp { @Override public List 
doDiff(List gradient) { - return sameDiff.f().multiHeadDotProductAttentionBp(arg(0), arg(1), arg(2), arg(3), arg(4), arg(5), arg(6), gradient.get(0), args().length > 7 ? arg(7) : null, scaled); + return Arrays.asList(new MultiHeadDotProductAttentionBp(sameDiff, arg(0), arg(1), arg(2), arg(3), arg(4), arg(5), arg(6), gradient.get(0), args().length > 7 ? arg(7) : null, scaled).outputVariables()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java index e155a4f2a..0f8286769 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.PowBp; import java.util.Arrays; import java.util.Collections; @@ -68,8 +69,7 @@ public class Pow extends DynamicCustomOp { SDVariable dldb = outputVariable().mul(sameDiff.math().log(a)).mul(f1.get(0)); return Arrays.asList(dlda, dldb);*/ - SDVariable[] g = f().powBp(arg(0), arg(1), f1.get(0)); - return Arrays.asList(g); + return new PowBp(sameDiff, arg(0), arg(1), f1.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java index d1648abab..372f96c18 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java @@ -100,8 +100,8 @@ public class Reverse extends DynamicCustomOp { @Override public List doDiff(List f1) { - SDVariable ret = f().reverse(f1.get(0), dimensions); - return Arrays.asList(ret); + SDVariable ret = sameDiff.reverse(f1.get(0), dimensions); + return Collections.singletonList(ret); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java index 11897fef8..50332daf6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java @@ -115,8 +115,8 @@ public class ReverseSequence extends DynamicCustomOp { @Override public List doDiff(List f1) { - SDVariable ret = f().reverseSequence(f1.get(0), arg(1), seqDim, batchDim); - return Arrays.asList(ret, f().zerosLike(arg(1))); + SDVariable ret = sameDiff.reverseSequence(f1.get(0), arg(1), seqDim, batchDim); + return Arrays.asList(ret, sameDiff.zerosLike(arg(1))); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java index 24c2353c1..737e76f3b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import 
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; import java.util.Collections; import java.util.List; @@ -106,8 +107,7 @@ public class SoftMax extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().softmaxDerivative(arg(), i_v.get(0), this.dimension); - return Collections.singletonList(ret); + return new SoftmaxBp(sameDiff, arg(), i_v.get(0), this.dimension).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java index 467b36a4e..8acef4029 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java @@ -63,8 +63,7 @@ public class Standardize extends DynamicCustomOp { @Override public List doDiff(List grad) { - SDVariable ret = f().standardizeBp(arg(0), grad.get(0), dimensions); - return Arrays.asList(ret); + return new StandardizeBp(sameDiff, arg(0), grad.get(0), dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java index 82e2ae6e3..1688c03c4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; /** * Threshold ReLU op. The genral case of {@link RectifiedLinear}. @@ -72,6 +73,6 @@ public class ThresholdRelu extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().thresholdReluBp(arg(), f1.get(0), cutoff)); + return new ThresholdReluBp(sameDiff, arg(), f1.get(0), cutoff).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java index 24d79f234..9faca3403 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java @@ -51,12 +51,12 @@ public class Trace extends DynamicCustomOp { @Override public List doDiff(List gradAtOutput){ - SDVariable rows = f().reshape(f().sizeAt(arg(), -2), new long[]{1}); - SDVariable cols = f().reshape(f().sizeAt(arg(), -1), new long[]{1}); - SDVariable eye = sameDiff.math().eye(/*f().shape(gradAtOutput.get(0)),*/ rows, cols); + SDVariable rows = sameDiff.reshape(sameDiff.sizeAt(arg(), -2), 1); + SDVariable cols = sameDiff.reshape(sameDiff.sizeAt(arg(), -1), 1); + SDVariable eye = sameDiff.math().eye(/*sameDiff.shape(gradAtOutput.get(0)),*/ rows, cols); //Reshape gradient from [x,y,z] to [x,y,z,1,1] - SDVariable reshapedGrad = f().expandDims(gradAtOutput.get(0), -1); - reshapedGrad = 
f().expandDims(reshapedGrad, -1); + SDVariable reshapedGrad = sameDiff.expandDims(gradAtOutput.get(0), -1); + reshapedGrad = sameDiff.expandDims(reshapedGrad, -1); return Collections.singletonList(reshapedGrad.mul(eye)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java index 5b6cd2517..c98cd7d5b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentMax extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentMaxBp(arg(0), arg(1), gradients.get(0))); + return new SegmentMaxBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java index d0a9a6784..eca108b2a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java @@ -22,6 +22,7 
@@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentMean extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentMeanBp(arg(0), arg(1), gradients.get(0))); + return new SegmentMeanBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java index 2bc369f2a..f070dc8d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentMin extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentMinBp(arg(0), arg(1), gradients.get(0))); + return new SegmentMinBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java index 3be3625e7..71d0dd2c3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentProd extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentProdBp(arg(0), arg(1), gradients.get(0))); + return new SegmentProdBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java index 5de847162..a74aded65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentSum extends 
DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentSumBp(arg(0), arg(1), gradients.get(0))); + return new SegmentSumBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java index d588ef4a8..b168adb43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java @@ -129,7 +129,7 @@ public class Cast extends BaseDynamicTransformOp { if(arg().dataType().isFPType()){ return Collections.singletonList(i_v.get(0).castTo(arg().dataType())); } else { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java index df5cdbcc7..9471c8bca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java @@ -75,7 +75,7 @@ public class RSqrt extends BaseTransformFloatOp { @Override public List doDiff(List i_v) { - SDVariable xPowNeg32 = f().pow(arg(), -1.5).mul(-0.5); + SDVariable xPowNeg32 = sameDiff.math.pow(arg(), -1.5).mul(-0.5); return Collections.singletonList(i_v.get(0).mul(xPowNeg32)); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java index b00b29b75..fdbbafa99 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -41,6 +42,10 @@ public class SELUDerivative extends BaseTransformStrictOp { private static final double SELU_ALPHA = 1.6732632423543772848170429916717; private static final double SELU_LAMBDA = 1.0507009873554804934193349852946; + public SELUDerivative(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public SELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } @@ -79,9 +84,8 @@ public class SELUDerivative extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().div(arg(),f().seluDerivative(arg())); - - return Arrays.asList(ret); + SDVariable ret = sameDiff.math.div(arg(), new SELUDerivative(sameDiff, arg()).outputVariable()); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java index 16afb4316..2a0d6021a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java @@ -84,7 +84,7 @@ public class TanhDerivative extends DynamicCustomOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().div(sameDiff.onesLike(outputVariables()[0]), f().pow(f().cosh(arg()), 2)); + SDVariable ret = sameDiff.math.div(sameDiff.onesLike(outputVariables()[0]), sameDiff.math.pow(sameDiff.math.cosh(arg()), 2)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java index 672159a3e..4069967c6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp; import java.util.List; @@ -34,14 +36,18 @@ public class AddOp extends BaseDynamicTransformOp { public AddOp() { } - public AddOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public AddOp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); } public AddOp(INDArray first, INDArray second, INDArray result){ this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); } + public AddOp(@NonNull INDArray x, @NonNull INDArray y) { + this(new INDArray[]{x,y}, null); + } + public AddOp(INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -63,7 +69,7 @@ public class AddOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().addBp(larg(), rarg(), i_v.get(0)); + return new AddBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java index b76942e95..2ce0101cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java @@ -16,11 +16,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp; +import java.util.Arrays; import java.util.List; /** @@ -33,14 +36,18 @@ public class DivOp extends BaseDynamicTransformOp { public DivOp() {} - public DivOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public DivOp( @NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); } public DivOp(INDArray first, INDArray second, INDArray result){ this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); } + public DivOp( @NonNull INDArray x, INDArray y) { + this(new INDArray[]{x,y}, null); + } + public DivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -65,7 +72,7 @@ public class DivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().divBp(larg(), rarg(), i_v.get(0)); + return Arrays.asList(new DivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputVariables()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java index c2314cdc4..408d86a75 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java @@ -22,6 +22,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp; import java.util.List; @@ -83,6 +84,6 @@ public class FModOp extends BaseTransformSameOp { @Override public List doDiff(List f1) { - return f().floorModBp(larg(), rarg(), f1.get(0)); + return new FloorModBpOp(sameDiff, larg(), rarg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java index debfc5a5d..7ed2c6c1c 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp; import java.util.List; @@ -39,6 +41,10 @@ public class FloorDivOp extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } + public FloorDivOp(@NonNull INDArray x, @NonNull INDArray y) { + this(new INDArray[]{x, y}, null); + } + public FloorDivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -63,6 +69,6 @@ public class FloorDivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().floorDivBp(larg(), rarg(), i_v.get(0)); + return new FloorDivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java index e7286816f..29799a221 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java @@ -16,12 +16,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import 
org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp; import org.nd4j.linalg.api.shape.Shape; import java.util.Collections; @@ -39,6 +41,10 @@ public class FloorModOp extends BaseDynamicTransformOp { super(sameDiff, new SDVariable[]{x, y}, false); } + public FloorModOp(@NonNull INDArray x, @NonNull INDArray y) { + this(new INDArray[]{x, y}, null); + } + public FloorModOp(INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -60,7 +66,7 @@ public class FloorModOp extends BaseDynamicTransformOp { @Override public List doDiff(List f1) { - return f().floorModBp(larg(), rarg(), f1.get(0)); + return new FloorModBpOp(sameDiff, larg(), rarg(), f1.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java index 51f2e449d..0d634766e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java @@ -84,7 +84,7 @@ public class MergeAddOp extends BaseDynamicTransformOp { public List calculateOutputDataTypes(List dataTypes){ DataType first = dataTypes.get(0); for( int i=1; i doDiff(List i_v) { - return f().modBp(larg(), rarg(), i_v.get(0)); + return new ModBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } - - } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java index 4636f9bc8..307a46557 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp; import java.util.List; @@ -33,12 +35,16 @@ public class MulOp extends BaseDynamicTransformOp { public MulOp() {} - public MulOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public MulOp( @NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); + } + + public MulOp(INDArray first, INDArray second){ + this(first, second, null); } public MulOp(INDArray first, INDArray second, INDArray result){ - this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); + this(new INDArray[]{first, second}, wrapOrNull(result)); } public MulOp( INDArray[] inputs, INDArray[] outputs) { @@ -66,7 +72,7 @@ public class MulOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().mulBp(larg(), rarg(), i_v.get(0)); + return new MulBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java index d54d91dbc..9891464fd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java @@ -16,12 +16,15 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp; +import java.util.Arrays; import java.util.List; /** @@ -34,14 +37,18 @@ public class RDivOp extends BaseDynamicTransformOp { public RDivOp() {} - public RDivOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public RDivOp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); } public RDivOp(INDArray first, INDArray second, INDArray result){ this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); } + public RDivOp(@NonNull INDArray x, @NonNull INDArray y){ + this(new INDArray[]{x, y}, null); + } + public RDivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -64,6 +71,6 @@ public class RDivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().rdivBp(larg(), rarg(), i_v.get(0)); + return Arrays.asList(new RDivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputVariables()); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java index 12c852949..5b233eb17 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp; import java.util.List; @@ -45,8 +46,16 @@ public class RSubOp extends BaseDynamicTransformOp { this(sameDiff, new SDVariable[]{i_v1, i_v2}, inPlace); } + public RSubOp(INDArray first, INDArray second){ + this(first, second, null); + } + public RSubOp(INDArray first, INDArray second, INDArray result){ - this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); + this(new INDArray[]{first, second}, wrapOrNull(result)); + } + + public RSubOp( INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); } public RSubOp() {} @@ -61,13 +70,9 @@ public class RSubOp extends BaseDynamicTransformOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - public RSubOp( INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); - } - @Override public List doDiff(List i_v) { - return f().rsubBp(larg(), rarg(), i_v.get(0)); + return new RSubBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java index 04e72b5db..2fe1f150f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp; import java.util.List; @@ -60,7 +61,7 @@ public class RealDivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().divBp(larg(), rarg(), i_v.get(0)); + return new DivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java index acbf840f1..e1ad183e7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java @@ -36,14 +36,21 @@ public class SquaredDifferenceOp extends BaseDynamicTransformOp { public SquaredDifferenceOp() {} - public SquaredDifferenceOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public SquaredDifferenceOp(SameDiff sameDiff, SDVariable x, SDVariable y, boolean inPlace) { + super(sameDiff, new SDVariable[]{x,y}, inPlace); } - public SquaredDifferenceOp(INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); + public SquaredDifferenceOp(SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, x, y, false); } + public SquaredDifferenceOp(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x,y}, new INDArray[]{output}); + } + + public SquaredDifferenceOp(INDArray x, INDArray y) { + addInputArgument(new INDArray[]{x,y}); + } @Override public String opName() { @@ -63,8 +70,7 @@ public class SquaredDifferenceOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v1) { - SDVariable[] outputs = new SquaredDifferenceBpOp(f().sameDiff(), new SDVariable[]{larg(), rarg(), i_v1.get(0)}).outputVariables(); - return Arrays.asList(outputs); + return new SquaredDifferenceBpOp(sameDiff, new SDVariable[]{larg(), rarg(), i_v1.get(0)}).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java index 0d222329e..da6e77a42 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp; import java.util.List; @@ -33,12 +35,16 @@ public class SubOp extends BaseDynamicTransformOp { public SubOp() {} - public SubOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public SubOp( @NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); + } + + public SubOp(INDArray first, INDArray second){ + this(first, second, null); } public SubOp(INDArray first, INDArray second, INDArray result){ - this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); + this(new INDArray[]{first, second}, wrapOrNull(result)); } public SubOp( INDArray[] inputs, INDArray[] outputs) { @@ -65,7 +71,7 @@ public class SubOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().subBp(larg(), rarg(), i_v.get(0)); + return new SubBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java index 973ecd7ba..4c99b479b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java @@ -64,8 +64,8 @@ public class TruncateDivOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - SDVariable gradWrtX = f().div(i_v.get(0),rarg()); - SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg())); + SDVariable gradWrtX = sameDiff.math.div(i_v.get(0),rarg()); + SDVariable gradWrtY = sameDiff.math.mul(sameDiff.math.neg(gradWrtX),sameDiff.math.div(larg(),rarg())); List ret = new ArrayList<>(2); ret.add(gradWrtX); ret.add(gradWrtY); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java index 95bd0bf41..d70de3b3e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java @@ -69,6 +69,6 @@ public class Not extends BaseTransformBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java index e5ea01a60..d1e5917c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -57,7 +58,7 @@ public class AMax extends BaseTransformSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable minBp = new MinBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(minBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java index bf6a37a55..b1ded6b55 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java @@ -22,6 +22,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -57,7 +58,7 @@ public class AMin extends BaseTransformSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable minBp = new MinBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(minBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java index 4c6bf0ad9..ef21623c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java @@ -77,7 +77,7 @@ public class Abs extends BaseTransformSameOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().sign(arg()).mul(i_v.get(0)); + SDVariable ret = sameDiff.math.sign(arg()).mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java index f2d163f3f..bc86ae999 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java 
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java @@ -75,6 +75,6 @@ public class Ceil extends BaseTransformSameOp { public List doDiff(List f1) { //not continuously differentiable, but dOut/dIn = 0 in most places - return Arrays.asList(f().zerosLike(arg())); + return Arrays.asList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java index 6422e8df8..b4550cb4e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java @@ -25,6 +25,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; import java.util.Arrays; import java.util.List; @@ -77,6 +78,6 @@ public class Cube extends BaseTransformSameOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().cubeBp(arg(), f1.get(0))); + return new CubeBp(sameDiff, arg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java index db682174c..f5aec6b48 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java @@ -21,6 +21,8 @@ import 
org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; import java.util.Collections; import java.util.List; @@ -56,9 +58,7 @@ public class Max extends BaseTransformSameOp { @Override public List doDiff(List f1) { - SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); - return Collections.singletonList(sgn.mul(minBp)); + return new MaximumBp(sameDiff, larg(), rarg(), f1.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java index 6585ace19..1560a0e80 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java @@ -21,7 +21,9 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -56,9 +58,10 @@ public class Min extends BaseTransformSameOp { @Override public List doDiff(List f1) { - SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); - return Collections.singletonList(sgn.mul(minBp)); + //TODO optimize + SDVariable gt = arg(0).gt(arg(1)).castTo(arg(0).dataType()); + SDVariable lt = 
arg(0).lt(arg(1)).castTo(arg(1).dataType()); + return Arrays.asList(lt.mul(f1.get(0)), gt.mul(f1.get(0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java index 37b370fe9..f03805eb6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java @@ -73,7 +73,7 @@ public class Negative extends BaseTransformSameOp { @Override public List doDiff(List i_v) { - return Arrays.asList(f().neg(i_v.get(0))); + return Arrays.asList(sameDiff.math.neg(i_v.get(0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java index 1e11fa34d..8d2049f25 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java @@ -74,7 +74,7 @@ public class Reciprocal extends BaseTransformSameOp { @Override public List doDiff(List i_v1) { // -1/(x^2) - SDVariable g = f().pow(arg(), 2).rdiv(-1).mul(i_v1.get(0)); + SDVariable g = sameDiff.math.pow(arg(), 2).rdiv(-1).mul(i_v1.get(0)); return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java index 375a8acb5..de9e0b685 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java @@ -75,6 +75,6 @@ public class Round extends BaseTransformSameOp { @Override public List doDiff(List f1) { - return Arrays.asList(f().zerosLike(arg())); + return Arrays.asList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java index c63e00114..010955baf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java @@ -21,8 +21,10 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -72,7 +74,7 @@ public class Square extends BaseTransformSameOp { @Override public List doDiff(List i_v) { - SDVariable g = f().powDerivative(arg(), 2).mul(i_v.get(0)); - return Arrays.asList(g); + SDVariable g = new PowDerivative(sameDiff, arg(), false, 2).outputVariable().mul(i_v.get(0)); + return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index 1506ac5f3..aeb543ea8 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp; import java.util.*; @@ -59,7 +60,7 @@ public class UnsortedSegmentMax extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentMaxBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentMaxBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index 4338cf33d..e869a84eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp; import java.util.Arrays; import java.util.Collections; @@ -61,7 +62,7 @@ public class UnsortedSegmentMean extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentMeanBp(arg(0), arg(1), 
gradients.get(0), numSegments)); + return new UnsortedSegmentMeanBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 2f8aab0b1..fd0f5fd05 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp; import java.util.Arrays; import java.util.Collections; @@ -61,7 +62,7 @@ public class UnsortedSegmentMin extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentMinBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentMinBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index 7afd75fac..12ec63222 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -23,6 +23,7 @@ import 
org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; import java.util.Arrays; import java.util.Collections; @@ -61,7 +62,7 @@ public class UnsortedSegmentProd extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentProdBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentProdBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index 77474855c..9d7aceb96 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; import java.util.ArrayList; import java.util.Arrays; @@ -62,7 +63,7 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentSqrtNBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentSqrtNBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java index 336c756ac..5e5cfd12e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; @@ -62,7 +63,7 @@ public class UnsortedSegmentSum extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentSumBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentSumBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java index 3e0c60bb0..44dca1ed1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,9 +76,9 @@ public class ACos 
extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //dacos(x)/dx = -1 / sqrt(1-x^2) - SDVariable oneSubSq = f().square(arg()).rsub(1.0); - SDVariable sqrt = f().sqrt(oneSubSq); + SDVariable oneSubSq = sameDiff.math.square(arg()).rsub(1.0); + SDVariable sqrt = sameDiff.math.sqrt(oneSubSq); SDVariable ret = sqrt.rdiv(-1.0).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java index 49ef2fb09..25f20e011 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java @@ -75,8 +75,8 @@ public class ASinh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //dasinh(x)/dx = 1 / sqrt(x^2+1) - SDVariable xSqPlus1 = f().square(arg()).add(1.0); - SDVariable ret = i_v.get(0).div(f().sqrt(xSqPlus1)); + SDVariable xSqPlus1 = sameDiff.math.square(arg()).add(1.0); + SDVariable ret = i_v.get(0).div(sameDiff.math.sqrt(xSqPlus1)); return Arrays.asList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java index 483896dfd..a7a741759 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import 
org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -76,8 +77,8 @@ public class ATan extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //d(atan(x))/dx = 1/(x^2+1) - SDVariable xSqPlus1 = f().square(arg()).add(1.0); + SDVariable xSqPlus1 = sameDiff.math.square(arg()).add(1.0); SDVariable ret = xSqPlus1.rdiv(1.0).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java index 21076ad6e..35ed040c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java @@ -64,7 +64,7 @@ public class Cos extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().neg(f().sin(arg())).mul(i_v.get(0)); + SDVariable ret = sameDiff.math.neg(sameDiff.math.sin(arg())).mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java index dc08ead5f..5144315da 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java @@ -74,7 +74,7 @@ public class Cosh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().sinh(arg()).mul(i_v.get(0)); + SDVariable 
ret = sameDiff.math.sinh(arg()).mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java index c4fc245b7..cc3a5e116 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java @@ -23,6 +23,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; import java.util.Collections; import java.util.List; @@ -83,7 +84,7 @@ public class ELU extends DynamicCustomOp { public List doDiff(List i_v) { //ELU: e^x-1 if x<0, x otherwise //dL/dIn = dL/Out * dOut/dIn - return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha)); + return new EluBp(sameDiff, arg(), i_v.get(0), alpha).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java index 21aa49522..f9288d1d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java @@ -73,7 +73,7 @@ public class Exp extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().mul(f().exp(arg()), i_v.get(0)); + SDVariable ret = sameDiff.math.mul(sameDiff.math.exp(arg()), i_v.get(0)); return 
Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java index 538f6a003..5b093a3ec 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java @@ -75,7 +75,7 @@ public class Expm1 extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().mul(f().exp(arg()), i_v.get(0)); + SDVariable ret = sameDiff.math.mul(sameDiff.math.exp(arg()), i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java index b784ddde0..009492924 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java @@ -68,7 +68,7 @@ public class GELU extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().geluDerivative(arg(), false).mul(i_v.get(0)); + SDVariable ret = new GELUDerivative(sameDiff, arg(), false).outputVariable().mul(i_v.get(0)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java index ddaa8631f..91c4eb8ae 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative; import java.util.Collections; @@ -74,7 +75,7 @@ public class HardSigmoid extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().hardSigmoidBp(arg(), f1.get(0))); + return new HardSigmoidBp(sameDiff, arg(), f1.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java index fa80bf880..fc44f3c22 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; import java.util.Arrays; import java.util.List; @@ -75,6 +76,6 @@ public class HardTanh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().hardTanhBp(arg(), i_v.get(0))); + return new HardTanhBp(sameDiff, 
arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java index a937e1d63..1fd8ac430 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,8 +76,7 @@ public class Log extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - f().validateDifferentialFunctionsameDiff(arg()); - SDVariable toInverse = sameDiff.setupFunction(f().div(i_v.get(0), arg())); - return Arrays.asList(toInverse); + SDVariable toInverse = sameDiff.math.div(i_v.get(0), arg()); + return Collections.singletonList(toInverse); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java index 131986d15..d61504e39 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; 
import org.nd4j.linalg.api.ops.BaseTransformFloatOp; @@ -73,7 +74,7 @@ public class Log1p extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - f().validateDifferentialFunctionsameDiff(arg()); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, arg(), this); return Collections.singletonList(i_v.get(0).div(arg().add(1.0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java index 353ced004..6a118a062 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import java.util.Arrays; import java.util.Collections; @@ -74,10 +75,8 @@ public class LogSigmoid extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { -// SDVariable ret = f().logSigmoidDerivative(arg(), i_v.get(0)); -// return Arrays.asList(ret); - SDVariable sigmDeriv = f().sigmoidDerivative(arg(), i_v.get(0)).div(f().sigmoid(arg())); - return Collections.singletonList(sigmDeriv); + SDVariable v = new SigmoidDerivative(sameDiff, arg(), i_v.get(0)).outputVariable().div(sameDiff.nn.sigmoid(arg())); + return Collections.singletonList(v); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java index 416f74133..05ca68aa0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -69,8 +70,8 @@ public class Mish extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().mishDerivative(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = new MishDerivative(sameDiff, arg(), false).outputVariable().mul(i_v.get(0)); + return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java index ab565b30f..e2d8c8b5a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java @@ -37,6 +37,10 @@ public class PreciseGELU extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public PreciseGELU(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false, true); + } + public PreciseGELU() { } @@ -72,7 +76,7 @@ public class PreciseGELU extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().geluDerivative(arg(), true).mul(i_v.get(0)); + SDVariable ret = new PreciseGELUDerivative(sameDiff, arg(), false, 
true).outputVariable().mul(i_v.get(0)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java index a05e34637..ecf85c9a9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp; import java.util.Collections; import java.util.List; @@ -35,6 +36,10 @@ public class RationalTanh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public RationalTanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public RationalTanh() {} public RationalTanh(INDArray x, INDArray z) { @@ -68,6 +73,6 @@ public class RationalTanh extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().tanhRationalBp(arg(), f1.get(0))); + return new RationalTanhBp(sameDiff, arg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java index d5fbf1294..8956f1b66 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -42,6 +43,10 @@ public class RectifiedTanh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public RectifiedTanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public RectifiedTanh() {} public RectifiedTanh(INDArray x, INDArray z) { @@ -85,6 +90,6 @@ public class RectifiedTanh extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().tanhRectifiedBp(arg(), f1.get(0))); + return new RectifiedTanhBp(sameDiff, arg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java index 00592f0e2..472c4eece 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; import java.util.Arrays; import java.util.List; @@ -81,7 +82,7 @@ public class SELU extends BaseTransformStrictOp { 
@Override public List doDiff(List i_v) { - return Collections.singletonList(f().seluBp(arg(), i_v.get(0))); + return new SeluBp(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java index 37ef4b743..2452d5906 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import java.util.Arrays; import java.util.List; @@ -74,8 +75,7 @@ public class Sigmoid extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().sigmoidDerivative(arg(), i_v.get(0)); - return Arrays.asList(ret); + return new SigmoidDerivative(sameDiff, arg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java index 0fa918c11..bfdde52d8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import 
java.util.Collections; import java.util.List; /** @@ -75,8 +76,8 @@ public class Sin extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().cos(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = sameDiff.math.cos(arg()).mul(i_v.get(0)); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java index d5e3be988..2f5b981bd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,8 +76,8 @@ public class Sinh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().cosh(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = sameDiff.math.cosh(arg()).mul(i_v.get(0)); + return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java index 11ffb2ef8..f3eeda670 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import 
org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -73,8 +74,8 @@ public class SoftPlus extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //dL/dIn = dL/Out * dOut/dIn - SDVariable ret = f().sigmoid(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = sameDiff.nn.sigmoid(arg()).mul(i_v.get(0)); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java index 8be5ea2d4..057fda972 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java @@ -16,15 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; -import java.util.Collections; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp; -import java.util.Arrays; import java.util.List; /** @@ -78,7 +75,7 @@ public class SoftSign extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().softsignBp(arg(), i_v.get(0))); + return new SoftSignBp(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java index 0794e0b57..7f694f481 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java @@ -74,7 +74,7 @@ public class Swish extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().swishDerivative(arg()).mul(i_v.get(0)); + SDVariable ret = new SwishDerivative(sameDiff, arg()).outputVariable().mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java index 350fb194e..552308859 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java @@ -39,8 +39,8 @@ public class SwishDerivative extends BaseTransformStrictOp { super(sameDiff, i_v1, i_v2, inPlace); } - public SwishDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + public SwishDerivative(SameDiff sameDiff, SDVariable i_v) { + super(sameDiff, i_v, false); } public SwishDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java index 3244925b1..2c9a603ee 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java @@ -76,7 +76,7 @@ public class Tan extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //d(tan(x))/dx = (sec(x))^2 = 1 / (cos(x))^2 - SDVariable cosx = f().cos(arg()); + SDVariable cosx = sameDiff.math.cos(arg()); SDVariable cosSqx = sameDiff.math().square(cosx); return Collections.singletonList(i_v.get(0).div(cosSqx)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java index 136d0bbea..ca45a549a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative; import java.util.Arrays; import java.util.List; @@ -74,7 +75,6 @@ public class Tanh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().tanhDerivative(arg(), i_v.get(0)); - return Arrays.asList(ret); + return new TanhDerivative(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java index 64e6a96b1..cd2a2b540 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -34,12 +34,11 @@ public class NDBase { /** * Boolean and array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray all(INDArray x, int... dimensions) { - NDValidation.validateBool("all", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.All(x, dimensions)); } @@ -47,12 +46,11 @@ public class NDBase { /** * Boolean or array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray any(INDArray x, int... dimensions) { - NDValidation.validateBool("any", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(x, dimensions)); } @@ -114,6 +112,8 @@ public class NDBase { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions @@ -138,6 +138,8 @@ public class NDBase { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) @@ -369,6 +371,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -472,6 +476,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -504,6 +510,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -602,6 +610,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -634,6 +644,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -760,6 +772,8 @@ public class NDBase { * Element-wise maximum operation: out[i] = max(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -812,6 +826,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, false, dimensions)); } + /** + * The merge operation is a control operation that forwards either of the inputs to the output, when
+ * the first of them becomes available. If both are available, the output is undefined (either input could
+ * be forwarded to the output)
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public INDArray merge(INDArray x, INDArray y) { + NDValidation.validateNumerical("merge", "x", x); + NDValidation.validateNumerical("merge", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(x, y))[0]; + } + /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
* @@ -857,6 +886,8 @@ public class NDBase { * Element-wise minimum operation: out[i] = min(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -919,6 +950,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1978,6 +2011,18 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, false, dimensions)); } + /** + * Switch operation
+ * Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output
+ * + * @param x Input variable (NDARRAY type) + * @param predicate Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + */ + public INDArray[] switchOp(INDArray x, INDArray predicate) { + NDValidation.validateBool("switchOp", "predicate", predicate); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch(x, predicate)); + } + /** * //TODO: Ops must be documented.
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java index 03b9f8571..536633cd2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java @@ -138,7 +138,7 @@ public class NDImage { /** * Resize images to size using the specified method.
* - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. * @param antialis Whether to use an anti-aliasing filter when downsampling an image @@ -161,7 +161,7 @@ public class NDImage { /** * Resize images to size using the specified method.
* - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index 8e8923834..1deddfd0a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -50,13 +50,13 @@ public class NDMath { * Looks up ids in a list of embedding tensors.
* * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public INDArray embeddingLookup(INDArray x, INDArray indices, PartitionMode PartitionMode) { NDValidation.validateNumerical("EmbeddingLookup", "x", x); - NDValidation.validateInteger("EmbeddingLookup", "indices", indices); + NDValidation.validateNumerical("EmbeddingLookup", "indices", indices); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(x, indices, PartitionMode))[0]; } @@ -93,6 +93,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(x)); } + /** + * Pairwise addition operation, out = x + y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray add(INDArray x, INDArray y) { + NDValidation.validateNumerical("add", "x", x); + NDValidation.validateNumerical("add", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(x, y))[0]; + } + + /** + * Scalar add operation, out = in + scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray add(INDArray x, double value) { + NDValidation.validateNumerical("add", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(x, value)); + } + /** * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
* @@ -540,6 +569,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.DiagPart(x))[0]; } + /** + * Pairwise division operation, out = x / y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray div(INDArray x, INDArray y) { + NDValidation.validateNumerical("div", "x", x); + NDValidation.validateNumerical("div", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(x, y))[0]; + } + + /** + * Scalar division operation, out = in / scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray div(INDArray x, double value) { + NDValidation.validateNumerical("div", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(x, value)); + } + /** * Entropy reduction: -sum(x * log(x))
* @@ -739,6 +797,52 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(x)); } + /** + * Pairwise floor division operation, out = floor(x / y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray floorDiv(INDArray x, INDArray y) { + NDValidation.validateNumerical("floorDiv", "x", x); + NDValidation.validateNumerical("floorDiv", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(x, y))[0]; + } + + /** + * Pairwise Modulus division operation
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray floorMod(INDArray x, INDArray y) { + NDValidation.validateNumerical("floorMod", "x", x); + NDValidation.validateNumerical("floorMod", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(x, y))[0]; + } + + /** + * Scalar floor modulus operation
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray floorMod(INDArray x, double value) { + NDValidation.validateNumerical("floorMod", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(x, value)); + } + /** * Hamming distance reduction operation. The output contains the cosine distance for each
* tensor/subset along the specified dimensions:
@@ -1069,6 +1173,23 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(in))[0]; } + /** + * Pairwise max operation, out = max(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public INDArray max(INDArray x, INDArray y) { + NDValidation.validateNumerical("max", "x", x); + NDValidation.validateNumerical("max", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(x, y))[0]; + } + /** * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
* out = sum_i in[i]
@@ -1120,6 +1241,40 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(inputs, cartesian)); } + /** + * Pairwise min operation, out = min(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public INDArray min(INDArray x, INDArray y) { + NDValidation.validateNumerical("min", "x", x); + NDValidation.validateNumerical("min", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(x, y))[0]; + } + + /** + * Pairwise modulus (remainder) operation, out = x % y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mod(INDArray x, INDArray y) { + NDValidation.validateNumerical("mod", "x", x); + NDValidation.validateNumerical("mod", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(x, y))[0]; + } + /** * Calculate the mean and (population) variance for the input variable, for the specified axis
* @@ -1132,6 +1287,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Moments(input, axes)); } + /** + * Pairwise multiplication operation, out = x * y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mul(INDArray x, INDArray y) { + NDValidation.validateNumerical("mul", "x", x); + NDValidation.validateNumerical("mul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(x, y))[0]; + } + + /** + * Scalar multiplication operation, out = in * scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray mul(INDArray x, double value) { + NDValidation.validateNumerical("mul", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(x, value)); + } + /** * Elementwise negative operation: out = -x
* @@ -1200,6 +1384,48 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(x, y))[0]; } + /** + * Rational Tanh Approximation elementwise function, as described in the paper:
+ * Compact Convolutional Neural Network Cascade for Face Detection
+ * This is a faster Tanh approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rationalTanh(INDArray x) { + NDValidation.validateNumerical("rationalTanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(x)); + } + + /** + * Pairwise reverse division operation, out = y / x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rdiv(INDArray x, INDArray y) { + NDValidation.validateNumerical("rdiv", "x", x); + NDValidation.validateNumerical("rdiv", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(x, y))[0]; + } + + /** + * Scalar reverse division operation, out = scalar / in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray rdiv(INDArray x, double value) { + NDValidation.validateNumerical("rdiv", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(x, value)); + } + /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
* @@ -1211,6 +1437,17 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(x)); } + /** + * Rectified tanh operation: max(0, tanh(in))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rectifiedTanh(INDArray x) { + NDValidation.validateNumerical("rectifiedTanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(x)); + } + /** * Element-wise round function: out = round(x).
* Rounds (up or down depending on value) to the nearest integer value.
@@ -1234,6 +1471,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(x)); } + /** + * Pairwise reverse subtraction operation, out = y - x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rsub(INDArray x, INDArray y) { + NDValidation.validateNumerical("rsub", "x", x); + NDValidation.validateNumerical("rsub", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(x, y))[0]; + } + + /** + * Scalar reverse subtraction operation, out = scalar - in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray rsub(INDArray x, double value) { + NDValidation.validateNumerical("rsub", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(x, value)); + } + /** * Set the diagonal value to the specified values
* If input is
@@ -1326,6 +1592,23 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Square(x)); } + /** + * Pairwise squared difference operation: out = (x - y) * (x - y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray squaredDifference(INDArray x, INDArray y) { + NDValidation.validateNumerical("squaredDifference", "x", x); + NDValidation.validateNumerical("squaredDifference", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(x, y))[0]; + } + /** * Standardize input variable along given axis
*


@@ -1364,6 +1647,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Step(x, value)); } + /** + * Pairwise subtraction operation, out = x - y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sub(INDArray x, INDArray y) { + NDValidation.validateNumerical("sub", "x", x); + NDValidation.validateNumerical("sub", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(x, y))[0]; + } + + /** + * Scalar subtraction operation, out = in - scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray sub(INDArray x, double value) { + NDValidation.validateNumerical("sub", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(x, value)); + } + /** * Elementwise tangent operation: out = tan(x)
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 06fb92b64..e2a8af245 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; @@ -355,6 +356,21 @@ public class NDNN { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0]; } + /** + * Padding operation
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public INDArray pad(INDArray input, INDArray padding, PadMode PadMode, double constant) { + NDValidation.validateNumerical("pad", "input", input); + NDValidation.validateNumerical("pad", "padding", padding); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode, constant))[0]; + } + /** * Padding operation
* @@ -366,7 +382,20 @@ public class NDNN { public INDArray pad(INDArray input, INDArray padding, double constant) { NDValidation.validateNumerical("pad", "input", input); NDValidation.validateNumerical("pad", "padding", padding); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, constant))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode.CONSTANT, constant))[0]; + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This implementation uses the precise (exact) formulation of GELU, not the tanh-based approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray preciseGelu(INDArray x) { + NDValidation.validateNumerical("preciseGelu", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(x)); } /** diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml index 6a3cc6eda..3e4367992 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml @@ -30,7 +30,6 @@ nd4j-tests-tensorflow - 1.8 1.8 @@ -216,8 +215,10 @@ **/*.java - org.nd4j.linalg.jcublas.JCublasBackend - org.nd4j.linalg.jcublas.JCublasBackend + org.nd4j.linalg.jcublas.JCublasBackend + + org.nd4j.linalg.jcublas.JCublasBackend + - + nd4j-backends org.nd4j @@ -29,7 +30,6 @@ nd4j-tests - 1.8 1.8 @@ -179,7 +179,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -191,8 +192,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + + com.google.code.findbugs + * + + + diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java new file mode 100644 index 000000000..f229069ae --- /dev/null +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java @@ -0,0 +1,82 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j; + +import lombok.extern.slf4j.Slf4j; +import org.junit.Test; +import org.reflections.Reflections; +import org.reflections.scanners.MethodAnnotationsScanner; +import org.reflections.util.ClasspathHelper; +import org.reflections.util.ConfigurationBuilder; + +import java.lang.reflect.Method; +import java.util.*; + +import static org.junit.Assert.assertEquals; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public abstract class AbstractAssertTestsClass extends BaseND4JTest { + + protected abstract Set> getExclusions(); + + protected abstract String getPackageName(); + + protected abstract Class getBaseClass(); + + @Override + public long getTimeoutMilliseconds() { + return 240000L; + } + + @Test + public void checkTestClasses(){ + + Reflections reflections = new Reflections(new ConfigurationBuilder() + .setUrls(ClasspathHelper.forPackage(getPackageName())) + .setScanners(new MethodAnnotationsScanner())); + Set methods = reflections.getMethodsAnnotatedWith(Test.class); + Set> s = new HashSet<>(); + for(Method m : methods){ + s.add(m.getDeclaringClass()); + } + + List> l = new ArrayList<>(s); + Collections.sort(l, new Comparator>() { + @Override + public int compare(Class aClass, Class t1) { + return aClass.getName().compareTo(t1.getName()); + } + }); + + int count = 0; + for(Class c : l){ + if(!getBaseClass().isAssignableFrom(c) && !getExclusions().contains(c)){ + log.error("Test {} does not extend {} (directly or indirectly). 
All tests must extend this class for proper memory tracking and timeouts", + c, getBaseClass()); + count++; + } + } + assertEquals("Number of tests not extending BaseND4JTest", 0, count); + } +} diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml index 0ece4c8b0..e640ed219 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml @@ -15,7 +15,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + 4.0.0 @@ -34,9 +35,9 @@ ${project.version} - junit - junit - test + junit + junit + test diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml index 619af0d7b..3ba5a156a 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml @@ -42,9 +42,9 @@ - junit - junit - test + junit + junit + test diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml index 4ed7c8b7b..7c2783904 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml @@ -49,9 +49,9 @@ - junit - junit - test + junit + junit + test diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml index d7729f179..5537216ca 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/pom.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -1,185 +1,185 @@ - 4.0.0 - jar + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + 4.0.0 + jar - - org.nd4j - nd4j-remote - 1.0.0-SNAPSHOT - + + org.nd4j + nd4j-remote + 1.0.0-SNAPSHOT + - nd4j-json-server - 
nd4j-json-server + nd4j-json-server + nd4j-json-server - - UTF-8 - 1.7 - 1.7 - + + UTF-8 + 1.7 + 1.7 + - - - junit - junit - test - - - - org.nd4j - nd4j-json-client - ${project.version} - - - - org.slf4j - slf4j-api - - - - org.nd4j - nd4j-api - ${project.version} - - - - org.glassfish.jersey.core - jersey-client - ${jersey.version} - - - - org.glassfish.jersey.core - jersey-server - ${jersey.version} - - - - org.eclipse.jetty - jetty-server - 9.4.19.v20190610 - - - - org.eclipse.jetty - jetty-servlet - 9.4.19.v20190610 - - - - org.glassfish.jersey.inject - jersey-hk2 - ${jersey.version} - - - - org.glassfish.jersey.media - jersey-media-json-processing - ${jersey.version} - - - - org.glassfish.jersey.containers - jersey-container-servlet-core - ${jersey.version} - - - - ch.qos.logback - logback-core - ${logback.version} - test - - - - ch.qos.logback - logback-classic - ${logback.version} - test - - - - javax.xml.bind - jaxb-api - 2.3.0 - - - - com.sun.xml.bind - jaxb-impl - 2.3.0 - - - - com.sun.xml.bind - jaxb-core - 2.3.0 - - - - javax.activation - activation - 1.1 - - - - com.google.code.gson - gson - ${gson.version} - test - - - - - org.nd4j - nd4j-common-tests - ${project.version} - test - - - - - - - org.apache.maven.plugins - maven-compiler-plugin - - ${maven.compiler.source} - ${maven.compiler.target} - - - - - - - - nd4j-tests-cpu - + - org.nd4j - nd4j-native - ${project.version} - test + junit + junit + test - - - - nd4j-tests-cuda - - org.nd4j - nd4j-cuda-10.2 - ${project.version} - test + org.nd4j + nd4j-json-client + ${project.version} - - - - testresources - - + + org.slf4j + slf4j-api + + + + org.nd4j + nd4j-api + ${project.version} + + + + org.glassfish.jersey.core + jersey-client + ${jersey.version} + + + + org.glassfish.jersey.core + jersey-server + ${jersey.version} + + + + org.eclipse.jetty + jetty-server + 9.4.19.v20190610 + + + + org.eclipse.jetty + jetty-servlet + 9.4.19.v20190610 + + + + org.glassfish.jersey.inject + jersey-hk2 + 
${jersey.version} + + + + org.glassfish.jersey.media + jersey-media-json-processing + ${jersey.version} + + + + org.glassfish.jersey.containers + jersey-container-servlet-core + ${jersey.version} + + + + ch.qos.logback + logback-core + ${logback.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + javax.xml.bind + jaxb-api + 2.3.0 + + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + + javax.activation + activation + 1.1 + + + + com.google.code.gson + gson + ${gson.version} + test + + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + ${maven.compiler.source} + ${maven.compiler.target} + + + + + + + + nd4j-tests-cpu + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + nd4j-tests-cuda + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + test + + + + + + testresources + + diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml index 827afb23a..c94bf86af 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml +++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml @@ -16,179 +16,186 @@ - 4.0.0 + 4.0.0 - org.nd4j - nd4j-aeron - jar - - nd4j-aeron - - org.nd4j - nd4j-serde - 1.0.0-SNAPSHOT - - - 1.8 - 1.8 - 1.5.4 - 1.4.0 - UTF-8 - + nd4j-aeron + jar - - - jdk9 - - 1.9 - - - 8 - - - - testresources - + nd4j-aeron - - nd4j-tests-cpu - - false - - - - org.nd4j - nd4j-native - ${project.version} - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ - - src/test/java - - *.java - **/*.java - **/Test*.java - **/*Test.java - **/*TestCase.java - - junit:junit - - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend - - - -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g - - - - - + + + jdk9 + + 1.9 + + + 8 + + + + testresources + - - nd4j-tests-cuda - - false - - - - org.nd4j - 
nd4j-cuda-10.2 - ${project.version} - - - - - - org.apache.maven.plugins - maven-surefire-plugin + + nd4j-tests-cpu + + false + - - org.apache.maven.surefire - surefire-junit47 - 2.19.1 - + + org.nd4j + nd4j-native + ${project.version} + - - - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ - - src/test/java - - *.java - **/*.java - **/Test*.java - **/*Test.java - **/*TestCase.java - - junit:junit - - org.nd4j.linalg.jcublas.JCublasBackend - org.nd4j.linalg.jcublas.JCublasBackend - - - -Ddtype=float -Xmx6g - - - - - - + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g + + + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + org.apache.maven.surefire + surefire-junit47 + 2.19.1 + + + + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ + + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.jcublas.JCublasBackend + + org.nd4j.linalg.jcublas.JCublasBackend + + + + -Ddtype=float -Xmx6g + + + + + + - - - org.nd4j - nd4j-api - ${project.version} - - - io.aeron - aeron-all - ${aeron.version} - - - junit - junit - test - + + + org.nd4j + nd4j-api + ${project.version} + + + io.aeron + aeron-all + ${aeron.version} + + + junit + junit + test + - - ch.qos.logback - logback-classic - ${logback.version} - test - + + ch.qos.logback + logback-classic + ${logback.version} + test + - - ch.qos.logback - logback-core - ${logback.version} - test - + + ch.qos.logback + logback-core + ${logback.version} + test + - - org.nd4j - 
nd4j-common-tests - ${project.version} - test - - + + org.nd4j + nd4j-common-tests + ${project.version} + test + + diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml index 3a768c1a5..69879e965 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml +++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml @@ -88,7 +88,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -100,8 +101,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + - + nd4j-camel-routes org.nd4j diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml index f488bfde5..60de01b6e 100644 --- a/nd4j/nd4j-serde/nd4j-gson/pom.xml +++ b/nd4j/nd4j-serde/nd4j-gson/pom.xml @@ -15,7 +15,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + nd4j-serde org.nd4j @@ -41,9 +42,9 @@ - junit - junit - test + junit + junit + test @@ -79,7 +80,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -91,8 +93,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + - + nd4j-serde org.nd4j @@ -101,17 +102,17 @@ ${spark.version} provided - - com.google.code.findbugs - jsr305 - + + com.google.code.findbugs + jsr305 + - junit - junit - test + junit + junit + test @@ -147,7 +148,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -159,8 
+161,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend +