diff --git a/nd4s/src/main/scala/org/nd4s/Implicits.scala b/nd4s/src/main/scala/org/nd4s/Implicits.scala index fc2b08b97..c1fa63e11 100644 --- a/nd4s/src/main/scala/org/nd4s/Implicits.scala +++ b/nd4s/src/main/scala/org/nd4s/Implicits.scala @@ -276,6 +276,10 @@ object Implicits { override def hasNegative: Boolean = false } + case object --- extends IndexRange { + override def hasNegative: Boolean = false + } + implicit class IntRange(val underlying: Int) extends IndexNumberRange { protected[nd4s] override def asRange(max: => Int): DRange = DRange(underlying, underlying, true, 1, max) @@ -307,6 +311,10 @@ object Implicits { def -> = IntRangeFrom(underlying) } + implicit class IntRangeFromGen1(val underlying: Int) extends AnyVal { + def :: = IntRangeFromReverse(underlying) + } + implicit class IndexRangeWrapper(val underlying: Range) extends IndexNumberRange { protected[nd4s] override def asRange(max: => Int): DRange = DRange.from(underlying, max) @@ -377,17 +385,27 @@ object IndexNumberRange { val endExclusive = if (endR >= 0) endR + diff else max + endR + diff (start, endExclusive) } - NDArrayIndex.interval(start, step, end, false) } } -sealed trait IndexRange { +/*sealed*/ +trait IndexRange { def hasNegative: Boolean } case class IntRangeFrom(underlying: Int) extends IndexRange { - def apply[T](a: T): (Int, T) = (underlying, a) + def apply[T](a: T): (Int, T) = + (underlying, a) + + override def toString: String = s"$underlying->" + + override def hasNegative: Boolean = false +} + +case class IntRangeFromReverse(underlying: Int) extends IndexRange { + def apply[T](a: T): (T, Int) = + (a, underlying) override def toString: String = s"$underlying->" diff --git a/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala b/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala index 565548e98..87163639e 100644 --- a/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala +++ b/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala @@ -23,6 +23,12 @@ import org.slf4j.LoggerFactory import _root_.scala.annotation.tailrec +package object ops { + case object :: extends IndexRange { + override def hasNegative: Boolean = false + } +} + trait SliceableNDArray[A <: INDArray] { lazy val log = LoggerFactory.getLogger(classOf[SliceableNDArray[A]]) val underlying: A @@ -68,6 +74,8 @@ trait SliceableNDArray[A <: INDArray] { @tailrec def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[DRange]): List[DRange] = input match { + case ops.:: :: t => + modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc) case -> :: t => modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc) case ---> :: t => @@ -137,6 +145,9 @@ trait SliceableNDArray[A <: INDArray] { case ---> :: t => val ellipsised = List.fill(originalShape.length - i - t.size)(->) modifyTargetIndices(ellipsised ::: t, i, acc) + case --- :: t => + val ellipsised = List.fill(originalShape.length - i - t.size)(->) + modifyTargetIndices(ellipsised ::: t, i, acc) case IntRangeFrom(from: Int) :: t => val max = originalShape(i) modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc) diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index 801309c85..e72ca37ed 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -60,6 +60,12 @@ case class SDIndexWrapper(end: Long) { SDIndex.interval(start, end) } +case class SDIndexWrapper1(start: Int) { + + def ::(end: Int): SDIndex = + SDIndex.interval(start, end) +} + object --- extends SDIndex { val thisIndex: SDIndex = SDIndex.all() } diff --git a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala index 004cb6703..1781d5b82 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala @@ -51,4 +51,12 @@ object Implicits { implicit def LongToPoint(x: Long): SDIndex = SDIndex.point(x) + + implicit def IntRangeToWrapper(start: Int): SDIndexWrapper = { + val result = new SDIndexWrapper(start) + result + } + + implicit def IntToPoint(x: Int): SDIndex = + SDIndex.point(x) } diff --git a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala index 6d1f9c795..b84270817 100644 --- a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala +++ b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala @@ -48,6 +48,42 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest => assert(extracted == expected) } + it should "be able to extract a part of 2d matrix with alternative syntax" in { + val ndArray = + Array( + Array(1, 2, 3), + Array(4, 5, 6), + Array(7, 8, 9) + ).mkNDArray(ordering) + + val extracted = ndArray(1 :: 3, 0 :: 2) + + val expected = + Array( + Array(4, 5), + Array(7, 8) + ).mkNDArray(ordering) + assert(extracted == expected) + } + + it should "be able to extract a part of 2d matrix with mixed syntax" in { + val ndArray = + Array( + Array(1, 2, 3), + Array(4, 5, 6), + Array(7, 8, 9) + ).mkNDArray(ordering) + + val extracted = ndArray(1 -> 3, 0 :: 2) + + val expected = + Array( + Array(4, 5), + Array(7, 8) + ).mkNDArray(ordering) + assert(extracted == expected) + } + it should "be able to extract a part of 2d matrix with double data" in { val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C) @@ -171,6 +207,9 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest => val ellipsised = ndArray(--->) assert(ellipsised == ndArray) + + val ellipsised1 = ndArray(---) + assert(ellipsised1 == ndArray) } it should "accept partially ellipsis indices" in {