[WIP] nd4s tests coverage (#59)
* Unit tests added Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Added operator + for left integer Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Added broadcast tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Build fixed after master changes Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Operatable tested Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * -sAdded tests * Projection tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Projection tests Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com> * Benchmarking Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
c3e684d648
commit
2fb4a52a02
|
@ -18,7 +18,7 @@ package org.nd4s
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.indexing.{ NDArrayIndex, SpecifiedIndex }
|
import org.nd4j.linalg.indexing.{ NDArrayIndex, SpecifiedIndex }
|
||||||
|
|
||||||
class ColumnProjectedNDArray(val array: INDArray, filtered: Array[Int]) {
|
class ColumnProjectedNDArray(val array: INDArray, val filtered: Array[Int]) {
|
||||||
def this(ndarray: INDArray) {
|
def this(ndarray: INDArray) {
|
||||||
this(ndarray, (0 until ndarray.columns()).toArray)
|
this(ndarray, (0 until ndarray.columns()).toArray)
|
||||||
}
|
}
|
||||||
|
|
|
@ -214,27 +214,50 @@ object Implicits {
|
||||||
def toScalar: INDArray = Nd4j.scalar(ev.toDouble(underlying))
|
def toScalar: INDArray = Nd4j.scalar(ev.toDouble(underlying))
|
||||||
}*/
|
}*/
|
||||||
|
|
||||||
|
// TODO: move ops to single trait
|
||||||
implicit class Float2Scalar(val underlying: Float) {
|
implicit class Float2Scalar(val underlying: Float) {
|
||||||
|
def +(x: INDArray) = underlying.toScalar + x
|
||||||
|
def *(x: INDArray) = underlying.toScalar * x
|
||||||
|
def /(x: INDArray) = underlying.toScalar / x
|
||||||
|
def \(x: INDArray) = underlying.toScalar \ x
|
||||||
def toScalar: INDArray = Nd4j.scalar(underlying)
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class Double2Scalar(val underlying: Double) {
|
implicit class Double2Scalar(val underlying: Double) {
|
||||||
|
def +(x: INDArray) = underlying.toScalar + x
|
||||||
|
def *(x: INDArray) = underlying.toScalar * x
|
||||||
|
def /(x: INDArray) = underlying.toScalar / x
|
||||||
|
def \(x: INDArray) = underlying.toScalar \ x
|
||||||
def toScalar: INDArray = Nd4j.scalar(underlying)
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class Long2Scalar(val underlying: Long) {
|
implicit class Long2Scalar(val underlying: Long) {
|
||||||
|
def +(x: INDArray) = underlying.toScalar + x
|
||||||
|
def *(x: INDArray) = underlying.toScalar * x
|
||||||
|
def /(x: INDArray) = underlying.toScalar / x
|
||||||
|
def \(x: INDArray) = underlying.toScalar \ x
|
||||||
def toScalar: INDArray = Nd4j.scalar(underlying)
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class Int2Scalar(val underlying: Int) {
|
implicit class Int2Scalar(val underlying: Int) {
|
||||||
|
def +(x: INDArray) = underlying.toScalar + x
|
||||||
|
def *(x: INDArray) = underlying.toScalar * x
|
||||||
|
def /(x: INDArray) = underlying.toScalar / x
|
||||||
|
def \(x: INDArray) = underlying.toScalar \ x
|
||||||
def toScalar: INDArray = Nd4j.scalar(underlying)
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class Byte2Scalar(val underlying: Byte) {
|
implicit class Byte2Scalar(val underlying: Byte) {
|
||||||
|
def +(x: INDArray) = underlying.toScalar + x
|
||||||
|
def *(x: INDArray) = underlying.toScalar * x
|
||||||
|
def /(x: INDArray) = underlying.toScalar / x
|
||||||
|
def \(x: INDArray) = underlying.toScalar \ x
|
||||||
def toScalar: INDArray = Nd4j.scalar(underlying)
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
implicit class Boolean2Scalar(val underlying: Boolean) {
|
implicit class Boolean2Scalar(val underlying: Boolean) {
|
||||||
|
def +(x: INDArray) = underlying.toScalar + x
|
||||||
|
def *(x: INDArray) = underlying.toScalar * x
|
||||||
def toScalar: INDArray = Nd4j.scalar(underlying)
|
def toScalar: INDArray = Nd4j.scalar(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4s
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.indexing.{ NDArrayIndex, SpecifiedIndex }
|
import org.nd4j.linalg.indexing.{ NDArrayIndex, SpecifiedIndex }
|
||||||
|
|
||||||
class RowProjectedNDArray(val array: INDArray, filtered: Array[Int]) {
|
class RowProjectedNDArray(val array: INDArray, val filtered: Array[Int]) {
|
||||||
def this(ndarray: INDArray) {
|
def this(ndarray: INDArray) {
|
||||||
this(ndarray, (0 until ndarray.rows()).toArray)
|
this(ndarray, (0 until ndarray.rows()).toArray)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ package org.nd4s
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.indexing.{ NDArrayIndex, SpecifiedIndex }
|
import org.nd4j.linalg.indexing.{ NDArrayIndex, SpecifiedIndex }
|
||||||
|
|
||||||
class SliceProjectedNDArray(val array: INDArray, filtered: Array[Int]) {
|
class SliceProjectedNDArray(val array: INDArray, val filtered: Array[Int]) {
|
||||||
def this(ndarray: INDArray) {
|
def this(ndarray: INDArray) {
|
||||||
this(ndarray, (0 until ndarray.slices().toInt).toArray)
|
this(ndarray, (0 until ndarray.slices().toInt).toArray)
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,7 @@ class FunctionalOpExecutioner extends OpExecutioner {
|
||||||
case DataType.FLOAT => op.op(op.x.getFloat(i.toLong))
|
case DataType.FLOAT => op.op(op.x.getFloat(i.toLong))
|
||||||
case DataType.INT => op.op(op.x.getInt(i))
|
case DataType.INT => op.op(op.x.getInt(i))
|
||||||
case DataType.SHORT => op.op(op.x.getInt(i))
|
case DataType.SHORT => op.op(op.x.getInt(i))
|
||||||
case (DataType.LONG) => op.op(op.x.getLong(i.toLong))
|
case DataType.LONG => op.op(op.x.getLong(i.toLong))
|
||||||
}
|
}
|
||||||
retVal.putScalar(i, filtered)
|
retVal.putScalar(i, filtered)
|
||||||
}
|
}
|
||||||
|
@ -466,6 +466,12 @@ class FunctionalOpExecutioner extends OpExecutioner {
|
||||||
|
|
||||||
def createConstantBuffer(values: Array[Double], desiredType: DataType): DataBuffer = ???
|
def createConstantBuffer(values: Array[Double], desiredType: DataType): DataBuffer = ???
|
||||||
|
|
||||||
|
def runFullBenchmarkSuit(x: Boolean): String =
|
||||||
|
Nd4j.getExecutioner.runFullBenchmarkSuit(x)
|
||||||
|
|
||||||
|
def runLightBenchmarkSuit(x: Boolean): String =
|
||||||
|
Nd4j.getExecutioner.runLightBenchmarkSuit(x)
|
||||||
|
|
||||||
@deprecated def scatterUpdate(op: ScatterUpdate.UpdateOp,
|
@deprecated def scatterUpdate(op: ScatterUpdate.UpdateOp,
|
||||||
array: INDArray,
|
array: INDArray,
|
||||||
indices: INDArray,
|
indices: INDArray,
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
package org.nd4s
|
package org.nd4s
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4s.Implicits._
|
import org.nd4s.Implicits._
|
||||||
import org.scalatest.{ FlatSpec, Matchers }
|
import org.scalatest.{ FlatSpec, Matchers }
|
||||||
|
|
||||||
|
@ -145,4 +146,44 @@ class NDArrayCollectionAPITest extends FlatSpec with Matchers {
|
||||||
//check if any element in nd meet the criteria.
|
//check if any element in nd meet the criteria.
|
||||||
assert(ndArray.exists(_ > 8))
|
assert(ndArray.exists(_ > 8))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
it should "provides existTyped API" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1, 2, 3),
|
||||||
|
Array(4, 5, 6),
|
||||||
|
Array(7, 8, 9)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
//check if any element in nd meet the criteria.
|
||||||
|
assert(ndArray.existsTyped[Int](_ > 8)(IntNDArrayEvidence))
|
||||||
|
}
|
||||||
|
|
||||||
|
"CollectionLikeNDArray" should "provides forAll API" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1, 2, 3),
|
||||||
|
Array(4, 5, 6),
|
||||||
|
Array(7, 8, 9)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val resultFalse = ndArray.forall(_ > 3)
|
||||||
|
assert(false == resultFalse)
|
||||||
|
|
||||||
|
val resultTrue = ndArray.forall(_ < 10)
|
||||||
|
assert(true == resultTrue)
|
||||||
|
}
|
||||||
|
|
||||||
|
"CollectionLikeNDArray" should "provides forAllTyped API" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1, 2, 3),
|
||||||
|
Array(4, 5, 6),
|
||||||
|
Array(7, 8, 9)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val results = ndArray.forallTyped[Int](_ > 3)(IntNDArrayEvidence)
|
||||||
|
assert(false == results)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
package org.nd4s
|
package org.nd4s
|
||||||
|
|
||||||
import org.scalatest.FlatSpec
|
|
||||||
import org.nd4s.Implicits._
|
import org.nd4s.Implicits._
|
||||||
|
import org.scalatest.{ FlatSpec, Matchers }
|
||||||
|
|
||||||
class NDArrayProjectionAPITest extends FlatSpec {
|
class NDArrayProjectionAPITest extends FlatSpec {
|
||||||
"ColumnProjectedNDArray" should "map column correctly" in {
|
"ColumnProjectedNDArray" should "map column correctly" in {
|
||||||
|
@ -41,7 +41,109 @@ class NDArrayProjectionAPITest extends FlatSpec {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
"RowProjectedNDArray" should "map row correctly" in {
|
"ColumnProjectedNDArray" should "map column correctly 2" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val result = ndArray.columnP map (input => input + 1)
|
||||||
|
assert(
|
||||||
|
result == Array(
|
||||||
|
Array(2d, 3d, 4d),
|
||||||
|
Array(5d, 6d, 7d),
|
||||||
|
Array(8d, 9d, 10d)
|
||||||
|
).toNDArray
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
"ColumnProjectedNDArray" should "map column correctly 3" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val result = ndArray.columnP flatMap (input => input + 1)
|
||||||
|
assert(
|
||||||
|
result == Array(
|
||||||
|
Array(2d, 3d, 4d),
|
||||||
|
Array(5d, 6d, 7d),
|
||||||
|
Array(8d, 9d, 10d)
|
||||||
|
).toNDArray
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
"ColumnProjectedNDArray" should "map column correctly in place " in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
ndArray.columnP flatMapi (input => input + 1)
|
||||||
|
assert(
|
||||||
|
ndArray == Array(
|
||||||
|
Array(2d, 3d, 4d),
|
||||||
|
Array(5d, 6d, 7d),
|
||||||
|
Array(8d, 9d, 10d)
|
||||||
|
).toNDArray
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
"ColumnProjectedNDArray" should "map column correctly 4" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val result = ndArray.columnP map (input => input + 1)
|
||||||
|
assert(
|
||||||
|
result == Array(
|
||||||
|
Array(2d, 3d, 4d),
|
||||||
|
Array(5d, 6d, 7d),
|
||||||
|
Array(8d, 9d, 10d)
|
||||||
|
).toNDArray
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
"ColumnProjectedNDArray" should "map column correctly 5" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
ndArray.columnP mapi (input => input + 1)
|
||||||
|
assert(
|
||||||
|
ndArray == Array(
|
||||||
|
Array(2d, 3d, 4d),
|
||||||
|
Array(5d, 6d, 7d),
|
||||||
|
Array(8d, 9d, 10d)
|
||||||
|
).toNDArray
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
"ColumnProjectedNDArray" should "flatmap column correctly" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val result = ndArray.columnP withFilter (input => false)
|
||||||
|
assert(result.filtered.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
"RowProjectedNDArray" should "map row correctly in for loop " in {
|
||||||
val ndArray =
|
val ndArray =
|
||||||
Array(
|
Array(
|
||||||
Array(1d, 2d, 3d),
|
Array(1d, 2d, 3d),
|
||||||
|
@ -60,6 +162,104 @@ class NDArrayProjectionAPITest extends FlatSpec {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
"RowProjectedNDArray" should "map row correctly " in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val result = ndArray.rowP map (input => input / 2)
|
||||||
|
|
||||||
|
assert(
|
||||||
|
result ==
|
||||||
|
Array[Double](0.5000, 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000, 4.5000).toNDArray.reshape(3, 3)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
"RowProjectedNDArray" should "filter rows correctly " in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val result = ndArray.rowP withFilter (input => false)
|
||||||
|
assert(result.filtered.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
"RowProjectedNDArray" should "flatMap rows correctly " in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val result = ndArray.rowP flatMap (input => input + 1)
|
||||||
|
val expected =
|
||||||
|
Array(
|
||||||
|
Array(2d, 3d, 4d),
|
||||||
|
Array(5d, 6d, 7d),
|
||||||
|
Array(8d, 9d, 10d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
assert(result == expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
"RowProjectedNDArray" should "map row correctly 2 " in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
val result = ndArray.rowP map (input => input / 2)
|
||||||
|
|
||||||
|
assert(
|
||||||
|
result ==
|
||||||
|
Array[Double](0.5000, 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000, 4.5000).toNDArray.reshape(3, 3)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
"RowProjectedNDArray" should "flatMap in place rows correctly " in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
ndArray.rowP flatMapi (input => input + 1)
|
||||||
|
val expected =
|
||||||
|
Array(
|
||||||
|
Array(2d, 3d, 4d),
|
||||||
|
Array(5d, 6d, 7d),
|
||||||
|
Array(8d, 9d, 10d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
assert(ndArray == expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
"RowProjectedNDArray" should "map in place rows correctly " in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1d, 2d, 3d),
|
||||||
|
Array(4d, 5d, 6d),
|
||||||
|
Array(7d, 8d, 9d)
|
||||||
|
).toNDArray
|
||||||
|
|
||||||
|
ndArray.rowP mapi (input => input / 2)
|
||||||
|
|
||||||
|
assert(
|
||||||
|
ndArray ==
|
||||||
|
Array[Double](0.5000, 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 3.5000, 4.0000, 4.5000).toNDArray.reshape(3, 3)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
"SliceProjectedNDArray" should "map slice correctly" in {
|
"SliceProjectedNDArray" should "map slice correctly" in {
|
||||||
val ndArray =
|
val ndArray =
|
||||||
(1d to 8d by 1).asNDArray(2, 2, 2)
|
(1d to 8d by 1).asNDArray(2, 2, 2)
|
||||||
|
@ -71,4 +271,40 @@ class NDArrayProjectionAPITest extends FlatSpec {
|
||||||
|
|
||||||
assert(result == List(25d, 36d, 49d, 64d).asNDArray(1, 2, 2))
|
assert(result == List(25d, 36d, 49d, 64d).asNDArray(1, 2, 2))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
"SliceProjectedNDArray" should "flatmap slice correctly" in {
|
||||||
|
val ndArray =
|
||||||
|
(1d to 8d by 1).asNDArray(2, 2, 2)
|
||||||
|
|
||||||
|
val result = ndArray.sliceP flatMap (input => input * 2)
|
||||||
|
val expected =
|
||||||
|
(2d to 16d by 2).asNDArray(2, 2, 2)
|
||||||
|
assert(result == expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
"SliceProjectedNDArray" should "flatmap slice correctly in place" in {
|
||||||
|
val ndArray =
|
||||||
|
(1d to 8d by 1).asNDArray(2, 2, 2)
|
||||||
|
|
||||||
|
ndArray.sliceP flatMapi (input => input * 2)
|
||||||
|
val expected =
|
||||||
|
(2d to 16d by 2).asNDArray(2, 2, 2)
|
||||||
|
assert(ndArray == expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
"SliceProjectedNDArray" should "map slice correctly in place" in {
|
||||||
|
val ndArray =
|
||||||
|
(1d to 8d by 1).asNDArray(2, 2, 2)
|
||||||
|
|
||||||
|
ndArray.sliceP mapi (input => input * 2)
|
||||||
|
val expected =
|
||||||
|
(2d to 16d by 2).asNDArray(2, 2, 2)
|
||||||
|
assert(ndArray == expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
"SliceProjectedNDArray" should "filter slice correctly" in {
|
||||||
|
val ndArray = (1d until 10d by 1).asNDArray(2, 2, 2)
|
||||||
|
val result = ndArray.sliceP withFilter (input => false)
|
||||||
|
assert(result.filtered.isEmpty)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
package org.nd4s
|
package org.nd4s
|
||||||
|
|
||||||
import org.junit.runner.RunWith
|
import org.junit.runner.RunWith
|
||||||
import org.nd4s.Implicits._
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
|
import org.nd4s.Implicits._
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import org.scalatest.junit.JUnitRunner
|
import org.scalatest.junit.JUnitRunner
|
||||||
import org.scalatest.{ FlatSpec, Matchers }
|
import org.scalatest.{ FlatSpec, Matchers }
|
||||||
|
@ -146,4 +146,136 @@ class OperatableNDArrayTest extends FlatSpec with Matchers {
|
||||||
val sumValueInFloatImplicit = ndArray.sumT
|
val sumValueInFloatImplicit = ndArray.sumT
|
||||||
sumValueInFloatImplicit shouldBe a[java.lang.Float]
|
sumValueInFloatImplicit shouldBe a[java.lang.Float]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
it should "provide matrix multiplicaton operations " in {
|
||||||
|
val a = Nd4j.create(Array[Float](4, 6, 5, 7)).reshape(2, 2)
|
||||||
|
val b = Nd4j.create(Array[Float](1, 3, 4, 8)).reshape(2, 2)
|
||||||
|
a **= b
|
||||||
|
val expected = Array[Float](28.0000f, 60.0000f, 33.0000f, 71.0000f).toNDArray.reshape(2, 2)
|
||||||
|
a shouldBe expected
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "provide matrix division operations " in {
|
||||||
|
val a = Nd4j.create(Array[Float](4, 6, 5, 7)).reshape(2, 2)
|
||||||
|
a /= 12
|
||||||
|
a.get(0) shouldBe (0.3333 +- 0.0001)
|
||||||
|
a.get(1) shouldBe (0.5 +- 0.0001)
|
||||||
|
a.get(2) shouldBe (0.4167 +- 0.0001)
|
||||||
|
a.get(3) shouldBe (0.5833 +- 0.0001)
|
||||||
|
|
||||||
|
val b = Nd4j.create(Array[Float](4, 6, 5, 7)).reshape(2, 2)
|
||||||
|
b %= 12
|
||||||
|
b.get(0) shouldBe (4.0)
|
||||||
|
b.get(1) shouldBe (6.0)
|
||||||
|
b.get(2) shouldBe (5.0)
|
||||||
|
b.get(3) shouldBe (-5.0)
|
||||||
|
|
||||||
|
val c = Nd4j.create(Array[Float](4, 6, 5, 7)).reshape(2, 2)
|
||||||
|
c \= 12
|
||||||
|
c.get(0) shouldBe (3.0)
|
||||||
|
c.get(1) shouldBe (2.0)
|
||||||
|
c.get(2) shouldBe (2.4000 +- 0.0001)
|
||||||
|
c.get(3) shouldBe (1.7143 +- 0.0001)
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "provide math operations for vectors " in {
|
||||||
|
val a = Nd4j.create(Array[Float](4, 6))
|
||||||
|
val b = Nd4j.create(Array[Float](1, 3))
|
||||||
|
a /= b
|
||||||
|
val expected1 = Nd4j.create(Array[Float](4, 2))
|
||||||
|
assert(a == expected1)
|
||||||
|
|
||||||
|
a *= b
|
||||||
|
val expected2 = Nd4j.create(Array[Float](4, 6))
|
||||||
|
assert(a == expected2)
|
||||||
|
|
||||||
|
a += b
|
||||||
|
val expected3 = Nd4j.create(Array[Float](5, 9))
|
||||||
|
assert(a == expected3)
|
||||||
|
|
||||||
|
a -= b
|
||||||
|
val expected4 = Nd4j.create(Array[Float](4, 6))
|
||||||
|
assert(a == expected4)
|
||||||
|
|
||||||
|
a \= b
|
||||||
|
val expected5 = Array[Float](0.25f, 0.5f).toNDArray
|
||||||
|
assert(a == expected5)
|
||||||
|
|
||||||
|
val c = a * b
|
||||||
|
val expected6 = Array[Float](0.25f, 1.5f).toNDArray
|
||||||
|
assert(c == expected6)
|
||||||
|
|
||||||
|
val d = a + b
|
||||||
|
val expected7 = Array[Float](1.25f, 3.5f).toNDArray
|
||||||
|
assert(d == expected7)
|
||||||
|
|
||||||
|
val e = a / b
|
||||||
|
e.get(0) should be(0.2500 +- 0.0001)
|
||||||
|
e.get(1) should be(0.1667 +- 0.0001)
|
||||||
|
|
||||||
|
val f = a \ b
|
||||||
|
f.get(0) should be(4.0 +- 0.0001)
|
||||||
|
f.get(1) should be(6.0 +- 0.0001)
|
||||||
|
|
||||||
|
val g = a ** b
|
||||||
|
g.get(0) shouldBe 1.7500
|
||||||
|
|
||||||
|
val h = a dot b
|
||||||
|
g.get(0) shouldBe 1.7500
|
||||||
|
|
||||||
|
d.sumT shouldBe 4.75
|
||||||
|
|
||||||
|
d.meanT shouldBe 2.375
|
||||||
|
|
||||||
|
d.norm1T shouldBe 4.75
|
||||||
|
|
||||||
|
d.maxT shouldBe 3.5
|
||||||
|
|
||||||
|
d.minT shouldBe 1.25
|
||||||
|
|
||||||
|
d.prodT shouldBe 4.375
|
||||||
|
|
||||||
|
d.varT shouldBe 2.53125
|
||||||
|
|
||||||
|
d.norm2T should be(3.7165 +- 0.0001)
|
||||||
|
|
||||||
|
d.stdT should be(1.5909 +- 0.0001)
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "provide arithmetic ops calls on integers " in {
|
||||||
|
val ndArray = Array(1, 2).toNDArray
|
||||||
|
val c = ndArray + 5
|
||||||
|
c shouldBe Array(6, 7).toNDArray
|
||||||
|
|
||||||
|
val d = 5 + ndArray
|
||||||
|
c shouldBe Array(6, 7).toNDArray
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "broadcast add ops calls on vectors with different length " in {
|
||||||
|
val x = Array(1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f).mkNDArray(Array(3, 5))
|
||||||
|
val y = Array[Float](1f, 1f, 1f, 1f, 1f).toNDArray
|
||||||
|
val e = x + 1f.toScalar
|
||||||
|
assert((x + y) == e)
|
||||||
|
|
||||||
|
val x1 = Array(1f, 1f, 1f, 1f, 1f, 1f).mkNDArray(Array(3, 1, 2))
|
||||||
|
val y1 = Array[Float](1f, 1f, 1f, 1f).toNDArray.reshape(2, 2)
|
||||||
|
val t1 = Array(1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f).mkNDArray(Array(3, 2, 2))
|
||||||
|
val e1 = t1 + 1f
|
||||||
|
assert((x1 + y1) == e1)
|
||||||
|
|
||||||
|
val e2 = 1f + t1
|
||||||
|
assert(e1 == e2)
|
||||||
|
}
|
||||||
|
|
||||||
|
it should "broadcast multiplication ops " in {
|
||||||
|
|
||||||
|
val x1 = Array(1f, 1f, 1f, 1f, 1f, 1f).mkNDArray(Array(3, 1, 2))
|
||||||
|
val y1 = Array[Float](1f, 1f, 1f, 1f).toNDArray.reshape(2, 2)
|
||||||
|
val t1 = Array(1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f).mkNDArray(Array(3, 2, 2))
|
||||||
|
val e1 = t1 * 1f
|
||||||
|
assert((x1 * y1) == e1)
|
||||||
|
|
||||||
|
val e2 = 1f * t1
|
||||||
|
assert(e1 == e2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue