Operator :: for INDArray

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
Alexander Stoyakin 2019-10-10 10:43:32 +03:00
parent 0de8523ebd
commit be70bea359
5 changed files with 60 additions and 13 deletions

View File

@ -311,6 +311,10 @@ object Implicits {
def -> = IntRangeFrom(underlying) def -> = IntRangeFrom(underlying)
} }
implicit class IntRangeFromGen1(val underlying: Int) extends AnyVal {
def :: = IntRangeFromReverse(underlying)
}
implicit class IndexRangeWrapper(val underlying: Range) extends IndexNumberRange { implicit class IndexRangeWrapper(val underlying: Range) extends IndexNumberRange {
protected[nd4s] override def asRange(max: => Int): DRange = protected[nd4s] override def asRange(max: => Int): DRange =
DRange.from(underlying, max) DRange.from(underlying, max)
@ -381,17 +385,27 @@ object IndexNumberRange {
val endExclusive = if (endR >= 0) endR + diff else max + endR + diff val endExclusive = if (endR >= 0) endR + diff else max + endR + diff
(start, endExclusive) (start, endExclusive)
} }
NDArrayIndex.interval(start, step, end, false) NDArrayIndex.interval(start, step, end, false)
} }
} }
sealed trait IndexRange { /*sealed*/
trait IndexRange {
def hasNegative: Boolean def hasNegative: Boolean
} }
case class IntRangeFrom(underlying: Int) extends IndexRange { case class IntRangeFrom(underlying: Int) extends IndexRange {
def apply[T](a: T): (Int, T) = (underlying, a) def apply[T](a: T): (Int, T) =
(underlying, a)
override def toString: String = s"$underlying->"
override def hasNegative: Boolean = false
}
case class IntRangeFromReverse(underlying: Int) extends IndexRange {
def apply[T](a: T): (T, Int) =
(a, underlying)
override def toString: String = s"$underlying->" override def toString: String = s"$underlying->"

View File

@ -23,6 +23,12 @@ import org.slf4j.LoggerFactory
import _root_.scala.annotation.tailrec import _root_.scala.annotation.tailrec
package object ops {
case object :: extends IndexRange {
override def hasNegative: Boolean = false
}
}
trait SliceableNDArray[A <: INDArray] { trait SliceableNDArray[A <: INDArray] {
lazy val log = LoggerFactory.getLogger(classOf[SliceableNDArray[A]]) lazy val log = LoggerFactory.getLogger(classOf[SliceableNDArray[A]])
val underlying: A val underlying: A
@ -68,6 +74,8 @@ trait SliceableNDArray[A <: INDArray] {
@tailrec @tailrec
def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[DRange]): List[DRange] = input match { def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[DRange]): List[DRange] = input match {
case ops.:: :: t =>
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
case -> :: t => case -> :: t =>
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc) modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
case ---> :: t => case ---> :: t =>

View File

@ -54,9 +54,15 @@ class SameDiffWrapper {
sd.placeHolder(name, dataType, shape: _*) sd.placeHolder(name, dataType, shape: _*)
} }
case class SDIndexWrapper(start: Long) { case class SDIndexWrapper(end: Long) {
def ->(end: Long): SDIndex = def ::(start: Long): SDIndex =
SDIndex.interval(start, end)
}
case class SDIndexWrapper1(start: Int) {
def ::(end: Int): SDIndex =
SDIndex.interval(start, end) SDIndex.interval(start, end)
} }

View File

@ -17,6 +17,7 @@ package org.nd4s
import org.nd4s.Implicits._ import org.nd4s.Implicits._
import org.scalatest.FlatSpec import org.scalatest.FlatSpec
import org.nd4s.ops.::
class NDArrayExtractionInCOrderingTest extends NDArrayExtractionTestBase with COrderingForTest class NDArrayExtractionInCOrderingTest extends NDArrayExtractionTestBase with COrderingForTest
class NDArrayExtractionInFortranOrderingTest extends NDArrayExtractionTestBase with FortranOrderingForTest class NDArrayExtractionInFortranOrderingTest extends NDArrayExtractionTestBase with FortranOrderingForTest
@ -48,7 +49,7 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
assert(extracted == expected) assert(extracted == expected)
} }
/*it should "be able to extract a part of 2d matrix with alternative syntax" in { it should "be able to extract a part of 2d matrix with alternative syntax" in {
val ndArray = val ndArray =
Array( Array(
Array(1, 2, 3), Array(1, 2, 3),
@ -64,7 +65,25 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
Array(7, 8) Array(7, 8)
).mkNDArray(ordering) ).mkNDArray(ordering)
assert(extracted == expected) assert(extracted == expected)
}*/ }
it should "be able to extract a part of 2d matrix with mixed syntax" in {
val ndArray =
Array(
Array(1, 2, 3),
Array(4, 5, 6),
Array(7, 8, 9)
).mkNDArray(ordering)
val extracted = ndArray(1 -> 3, 0 :: 2)
val expected =
Array(
Array(4, 5),
Array(7, 8)
).mkNDArray(ordering)
assert(extracted == expected)
}
it should "be able to extract a part of 2d matrix with double data" in { it should "be able to extract a part of 2d matrix with double data" in {
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C) val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)

View File

@ -222,7 +222,7 @@ class MathTest extends FlatSpec with Matchers {
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
val slice2 = x(0 -> 2, ---).eval val slice2 = x(0 :: 2, ---).eval
slice1 shouldBe slice2 slice1 shouldBe slice2
} }
@ -237,10 +237,10 @@ class MathTest extends FlatSpec with Matchers {
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 -> 2, 0, 0).eval x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 -> 2, x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2,
0 -> 1, 0 :: 1,
0 -> 2).eval 0 :: 2).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 -> 2, 0 -> 1, ---).eval x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval
} }
} }