Indexing syntax changed
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
f515151f5f
commit
0de8523ebd
|
@ -276,6 +276,10 @@ object Implicits {
|
|||
override def hasNegative: Boolean = false
|
||||
}
|
||||
|
||||
case object --- extends IndexRange {
|
||||
override def hasNegative: Boolean = false
|
||||
}
|
||||
|
||||
implicit class IntRange(val underlying: Int) extends IndexNumberRange {
|
||||
protected[nd4s] override def asRange(max: => Int): DRange =
|
||||
DRange(underlying, underlying, true, 1, max)
|
||||
|
|
|
@ -137,6 +137,9 @@ trait SliceableNDArray[A <: INDArray] {
|
|||
case ---> :: t =>
|
||||
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
|
||||
modifyTargetIndices(ellipsised ::: t, i, acc)
|
||||
case --- :: t =>
|
||||
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
|
||||
modifyTargetIndices(ellipsised ::: t, i, acc)
|
||||
case IntRangeFrom(from: Int) :: t =>
|
||||
val max = originalShape(i)
|
||||
modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc)
|
||||
|
|
|
@ -54,9 +54,9 @@ class SameDiffWrapper {
|
|||
sd.placeHolder(name, dataType, shape: _*)
|
||||
}
|
||||
|
||||
case class SDIndexWrapper(end: Long) {
|
||||
case class SDIndexWrapper(start: Long) {
|
||||
|
||||
def ::(start: Long): SDIndex =
|
||||
def ->(end: Long): SDIndex =
|
||||
SDIndex.interval(start, end)
|
||||
}
|
||||
|
||||
|
|
|
@ -51,4 +51,12 @@ object Implicits {
|
|||
|
||||
implicit def LongToPoint(x: Long): SDIndex =
|
||||
SDIndex.point(x)
|
||||
|
||||
implicit def IntRangeToWrapper(start: Int): SDIndexWrapper = {
|
||||
val result = new SDIndexWrapper(start)
|
||||
result
|
||||
}
|
||||
|
||||
implicit def IntToPoint(x: Int): SDIndex =
|
||||
SDIndex.point(x)
|
||||
}
|
||||
|
|
|
@ -48,6 +48,24 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
|||
assert(extracted == expected)
|
||||
}
|
||||
|
||||
/*it should "be able to extract a part of 2d matrix with alternative syntax" in {
|
||||
val ndArray =
|
||||
Array(
|
||||
Array(1, 2, 3),
|
||||
Array(4, 5, 6),
|
||||
Array(7, 8, 9)
|
||||
).mkNDArray(ordering)
|
||||
|
||||
val extracted = ndArray(1 :: 3, 0 :: 2)
|
||||
|
||||
val expected =
|
||||
Array(
|
||||
Array(4, 5),
|
||||
Array(7, 8)
|
||||
).mkNDArray(ordering)
|
||||
assert(extracted == expected)
|
||||
}*/
|
||||
|
||||
it should "be able to extract a part of 2d matrix with double data" in {
|
||||
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
|
||||
|
||||
|
@ -171,6 +189,9 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
|||
|
||||
val ellipsised = ndArray(--->)
|
||||
assert(ellipsised == ndArray)
|
||||
|
||||
val ellipsised1 = ndArray(---)
|
||||
assert(ellipsised1 == ndArray)
|
||||
}
|
||||
|
||||
it should "accept partially ellipsis indices" in {
|
||||
|
|
|
@ -222,7 +222,7 @@ class MathTest extends FlatSpec with Matchers {
|
|||
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
|
||||
|
||||
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
|
||||
val slice2 = x(0 :: 2, ---).eval
|
||||
val slice2 = x(0 -> 2, ---).eval
|
||||
slice1 shouldBe slice2
|
||||
}
|
||||
|
||||
|
@ -237,10 +237,10 @@ class MathTest extends FlatSpec with Matchers {
|
|||
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
|
||||
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
|
||||
|
||||
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval
|
||||
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2,
|
||||
0 :: 1,
|
||||
0 :: 2).eval
|
||||
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval
|
||||
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 -> 2, 0, 0).eval
|
||||
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 -> 2,
|
||||
0 -> 1,
|
||||
0 -> 2).eval
|
||||
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 -> 2, 0 -> 1, ---).eval
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue