From be70bea359aed74e9fa7b539867b11296086b09b Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Thu, 10 Oct 2019 10:43:32 +0300 Subject: [PATCH] Operator :: for INDArray Signed-off-by: Alexander Stoyakin --- nd4s/src/main/scala/org/nd4s/Implicits.scala | 20 +++++++++++++--- .../scala/org/nd4s/SliceableNDArray.scala | 8 +++++++ .../scala/org/nd4s/samediff/SameDiff.scala | 10 ++++++-- .../org/nd4s/NDArrayExtractionTest.scala | 23 +++++++++++++++++-- .../scala/org/nd4s/samediff/MathTest.scala | 12 +++++----- 5 files changed, 60 insertions(+), 13 deletions(-) diff --git a/nd4s/src/main/scala/org/nd4s/Implicits.scala b/nd4s/src/main/scala/org/nd4s/Implicits.scala index c1f440e0c..c1fa63e11 100644 --- a/nd4s/src/main/scala/org/nd4s/Implicits.scala +++ b/nd4s/src/main/scala/org/nd4s/Implicits.scala @@ -311,6 +311,10 @@ object Implicits { def -> = IntRangeFrom(underlying) } + implicit class IntRangeFromGen1(val underlying: Int) extends AnyVal { + def :: = IntRangeFromReverse(underlying) + } + implicit class IndexRangeWrapper(val underlying: Range) extends IndexNumberRange { protected[nd4s] override def asRange(max: => Int): DRange = DRange.from(underlying, max) @@ -381,17 +385,27 @@ object IndexNumberRange { val endExclusive = if (endR >= 0) endR + diff else max + endR + diff (start, endExclusive) } - NDArrayIndex.interval(start, step, end, false) } } -sealed trait IndexRange { +/*sealed*/ +trait IndexRange { def hasNegative: Boolean } case class IntRangeFrom(underlying: Int) extends IndexRange { - def apply[T](a: T): (Int, T) = (underlying, a) + def apply[T](a: T): (Int, T) = + (underlying, a) + + override def toString: String = s"$underlying->" + + override def hasNegative: Boolean = false +} + +case class IntRangeFromReverse(underlying: Int) extends IndexRange { + def apply[T](a: T): (T, Int) = + (a, underlying) override def toString: String = s"$underlying->" diff --git a/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala b/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala index 9c5b05c21..87163639e 100644 --- a/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala +++ b/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala @@ -23,6 +23,12 @@ import org.slf4j.LoggerFactory import _root_.scala.annotation.tailrec +package object ops { + case object :: extends IndexRange { + override def hasNegative: Boolean = false + } +} + trait SliceableNDArray[A <: INDArray] { lazy val log = LoggerFactory.getLogger(classOf[SliceableNDArray[A]]) val underlying: A @@ -68,6 +74,8 @@ trait SliceableNDArray[A <: INDArray] { @tailrec def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[DRange]): List[DRange] = input match { + case ops.:: :: t => + modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc) case -> :: t => modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc) case ---> :: t => diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index b71f3df94..e72ca37ed 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -54,9 +54,15 @@ class SameDiffWrapper { sd.placeHolder(name, dataType, shape: _*) } -case class SDIndexWrapper(start: Long) { +case class SDIndexWrapper(end: Long) { - def ->(end: Long): SDIndex = + def ::(start: Long): SDIndex = + SDIndex.interval(start, end) +} + +case class SDIndexWrapper1(start: Int) { + + def ::(end: Int): SDIndex = SDIndex.interval(start, end) } diff --git a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala index e19f4ccf7..95a86d550 100644 --- a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala +++ b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala @@ -17,6 +17,7 @@ package org.nd4s import org.nd4s.Implicits._ import org.scalatest.FlatSpec +import org.nd4s.ops.:: class NDArrayExtractionInCOrderingTest extends NDArrayExtractionTestBase with COrderingForTest class NDArrayExtractionInFortranOrderingTest extends NDArrayExtractionTestBase with FortranOrderingForTest @@ -48,7 +49,7 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest => assert(extracted == expected) } - /*it should "be able to extract a part of 2d matrix with alternative syntax" in { + it should "be able to extract a part of 2d matrix with alternative syntax" in { val ndArray = Array( Array(1, 2, 3), @@ -64,7 +65,25 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest => Array(7, 8) ).mkNDArray(ordering) assert(extracted == expected) - }*/ + } + + it should "be able to extract a part of 2d matrix with mixed syntax" in { + val ndArray = + Array( + Array(1, 2, 3), + Array(4, 5, 6), + Array(7, 8, 9) + ).mkNDArray(ordering) + + val extracted = ndArray(1 -> 3, 0 :: 2) + + val expected = + Array( + Array(4, 5), + Array(7, 8) + ).mkNDArray(ordering) + assert(extracted == expected) + } it should "be able to extract a part of 2d matrix with double data" in { val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C) diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index a650d19d1..dc41b31f6 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -222,7 +222,7 @@ class MathTest extends FlatSpec with Matchers { x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval - val slice2 = x(0 -> 2, ---).eval + val slice2 = x(0 :: 2, ---).eval slice1 shouldBe slice2 } @@ -237,10 +237,10 @@ class MathTest extends FlatSpec with Matchers { x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval - x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 -> 2, 0, 0).eval - x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 -> 2, - 0 -> 1, - 0 -> 2).eval - x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 -> 2, 0 -> 1, ---).eval + x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval + x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2, + 0 :: 1, + 0 :: 2).eval + x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval } }