Indexing syntax changed

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
Alexander Stoyakin 2019-10-09 16:07:20 +03:00
parent f515151f5f
commit 0de8523ebd
6 changed files with 44 additions and 8 deletions

View File

@ -276,6 +276,10 @@ object Implicits {
override def hasNegative: Boolean = false override def hasNegative: Boolean = false
} }
case object --- extends IndexRange {
override def hasNegative: Boolean = false
}
implicit class IntRange(val underlying: Int) extends IndexNumberRange { implicit class IntRange(val underlying: Int) extends IndexNumberRange {
protected[nd4s] override def asRange(max: => Int): DRange = protected[nd4s] override def asRange(max: => Int): DRange =
DRange(underlying, underlying, true, 1, max) DRange(underlying, underlying, true, 1, max)

View File

@ -137,6 +137,9 @@ trait SliceableNDArray[A <: INDArray] {
case ---> :: t => case ---> :: t =>
val ellipsised = List.fill(originalShape.length - i - t.size)(->) val ellipsised = List.fill(originalShape.length - i - t.size)(->)
modifyTargetIndices(ellipsised ::: t, i, acc) modifyTargetIndices(ellipsised ::: t, i, acc)
case --- :: t =>
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
modifyTargetIndices(ellipsised ::: t, i, acc)
case IntRangeFrom(from: Int) :: t => case IntRangeFrom(from: Int) :: t =>
val max = originalShape(i) val max = originalShape(i)
modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc) modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc)

View File

@ -54,9 +54,9 @@ class SameDiffWrapper {
sd.placeHolder(name, dataType, shape: _*) sd.placeHolder(name, dataType, shape: _*)
} }
case class SDIndexWrapper(end: Long) { case class SDIndexWrapper(start: Long) {
def ::(start: Long): SDIndex = def ->(end: Long): SDIndex =
SDIndex.interval(start, end) SDIndex.interval(start, end)
} }

View File

@ -51,4 +51,12 @@ object Implicits {
implicit def LongToPoint(x: Long): SDIndex = implicit def LongToPoint(x: Long): SDIndex =
SDIndex.point(x) SDIndex.point(x)
implicit def IntRangeToWrapper(start: Int): SDIndexWrapper = {
val result = new SDIndexWrapper(start)
result
}
implicit def IntToPoint(x: Int): SDIndex =
SDIndex.point(x)
} }

View File

@ -48,6 +48,24 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
assert(extracted == expected) assert(extracted == expected)
} }
/*it should "be able to extract a part of 2d matrix with alternative syntax" in {
val ndArray =
Array(
Array(1, 2, 3),
Array(4, 5, 6),
Array(7, 8, 9)
).mkNDArray(ordering)
val extracted = ndArray(1 :: 3, 0 :: 2)
val expected =
Array(
Array(4, 5),
Array(7, 8)
).mkNDArray(ordering)
assert(extracted == expected)
}*/
it should "be able to extract a part of 2d matrix with double data" in { it should "be able to extract a part of 2d matrix with double data" in {
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C) val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
@ -171,6 +189,9 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
val ellipsised = ndArray(--->) val ellipsised = ndArray(--->)
assert(ellipsised == ndArray) assert(ellipsised == ndArray)
val ellipsised1 = ndArray(---)
assert(ellipsised1 == ndArray)
} }
it should "accept partially ellipsis indices" in { it should "accept partially ellipsis indices" in {

View File

@ -222,7 +222,7 @@ class MathTest extends FlatSpec with Matchers {
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
val slice2 = x(0 :: 2, ---).eval val slice2 = x(0 -> 2, ---).eval
slice1 shouldBe slice2 slice1 shouldBe slice2
} }
@ -237,10 +237,10 @@ class MathTest extends FlatSpec with Matchers {
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 -> 2, 0, 0).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2, x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 -> 2,
0 :: 1, 0 -> 1,
0 :: 2).eval 0 -> 2).eval
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 -> 2, 0 -> 1, ---).eval
} }
} }