diff --git a/nd4s/src/main/scala/org/nd4s/Implicits.scala b/nd4s/src/main/scala/org/nd4s/Implicits.scala index fc2b08b97..c1f440e0c 100644 --- a/nd4s/src/main/scala/org/nd4s/Implicits.scala +++ b/nd4s/src/main/scala/org/nd4s/Implicits.scala @@ -276,6 +276,10 @@ object Implicits { override def hasNegative: Boolean = false } + case object --- extends IndexRange { + override def hasNegative: Boolean = false + } + implicit class IntRange(val underlying: Int) extends IndexNumberRange { protected[nd4s] override def asRange(max: => Int): DRange = DRange(underlying, underlying, true, 1, max) diff --git a/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala b/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala index 565548e98..9c5b05c21 100644 --- a/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala +++ b/nd4s/src/main/scala/org/nd4s/SliceableNDArray.scala @@ -137,6 +137,9 @@ trait SliceableNDArray[A <: INDArray] { case ---> :: t => val ellipsised = List.fill(originalShape.length - i - t.size)(->) modifyTargetIndices(ellipsised ::: t, i, acc) + case --- :: t => + val ellipsised = List.fill(originalShape.length - i - t.size)(->) + modifyTargetIndices(ellipsised ::: t, i, acc) case IntRangeFrom(from: Int) :: t => val max = originalShape(i) modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc) diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index 801309c85..b71f3df94 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -54,9 +54,9 @@ class SameDiffWrapper { sd.placeHolder(name, dataType, shape: _*) } -case class SDIndexWrapper(end: Long) { +case class SDIndexWrapper(start: Long) { - def ::(start: Long): SDIndex = + def ->(end: Long): SDIndex = SDIndex.interval(start, end) } diff --git a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala index 004cb6703..1781d5b82 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala @@ -51,4 +51,12 @@ object Implicits { implicit def LongToPoint(x: Long): SDIndex = SDIndex.point(x) + + implicit def IntRangeToWrapper(start: Int): SDIndexWrapper = { + val result = new SDIndexWrapper(start) + result + } + + implicit def IntToPoint(x: Int): SDIndex = + SDIndex.point(x) } diff --git a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala index 6d1f9c795..e19f4ccf7 100644 --- a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala +++ b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala @@ -48,6 +48,24 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest => assert(extracted == expected) } + /*it should "be able to extract a part of 2d matrix with alternative syntax" in { + val ndArray = + Array( + Array(1, 2, 3), + Array(4, 5, 6), + Array(7, 8, 9) + ).mkNDArray(ordering) + + val extracted = ndArray(1 :: 3, 0 :: 2) + + val expected = + Array( + Array(4, 5), + Array(7, 8) + ).mkNDArray(ordering) + assert(extracted == expected) + }*/ + it should "be able to extract a part of 2d matrix with double data" in { val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C) @@ -171,6 +189,9 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest => val ellipsised = ndArray(--->) assert(ellipsised == ndArray) + + val ellipsised1 = ndArray(---) + assert(ellipsised1 == ndArray) } it should "accept partially ellipsis indices" in { diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index dc41b31f6..a650d19d1 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -222,7 +222,7 @@ class MathTest extends FlatSpec with Matchers { x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval - val slice2 = x(0 :: 2, ---).eval + val slice2 = x(0 -> 2, ---).eval slice1 shouldBe slice2 } @@ -237,10 +237,10 @@ class MathTest extends FlatSpec with Matchers { x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval - x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval - x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2, - 0 :: 1, - 0 :: 2).eval - x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval + x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 -> 2, 0, 0).eval + x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 -> 2, + 0 -> 1, + 0 -> 2).eval + x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 -> 2, 0 -> 1, ---).eval } }