Indexing syntax changed
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
f515151f5f
commit
0de8523ebd
|
@ -276,6 +276,10 @@ object Implicits {
|
||||||
override def hasNegative: Boolean = false
|
override def hasNegative: Boolean = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case object --- extends IndexRange {
|
||||||
|
override def hasNegative: Boolean = false
|
||||||
|
}
|
||||||
|
|
||||||
implicit class IntRange(val underlying: Int) extends IndexNumberRange {
|
implicit class IntRange(val underlying: Int) extends IndexNumberRange {
|
||||||
protected[nd4s] override def asRange(max: => Int): DRange =
|
protected[nd4s] override def asRange(max: => Int): DRange =
|
||||||
DRange(underlying, underlying, true, 1, max)
|
DRange(underlying, underlying, true, 1, max)
|
||||||
|
|
|
@ -137,6 +137,9 @@ trait SliceableNDArray[A <: INDArray] {
|
||||||
case ---> :: t =>
|
case ---> :: t =>
|
||||||
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
|
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
|
||||||
modifyTargetIndices(ellipsised ::: t, i, acc)
|
modifyTargetIndices(ellipsised ::: t, i, acc)
|
||||||
|
case --- :: t =>
|
||||||
|
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
|
||||||
|
modifyTargetIndices(ellipsised ::: t, i, acc)
|
||||||
case IntRangeFrom(from: Int) :: t =>
|
case IntRangeFrom(from: Int) :: t =>
|
||||||
val max = originalShape(i)
|
val max = originalShape(i)
|
||||||
modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc)
|
modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc)
|
||||||
|
|
|
@ -54,9 +54,9 @@ class SameDiffWrapper {
|
||||||
sd.placeHolder(name, dataType, shape: _*)
|
sd.placeHolder(name, dataType, shape: _*)
|
||||||
}
|
}
|
||||||
|
|
||||||
case class SDIndexWrapper(end: Long) {
|
case class SDIndexWrapper(start: Long) {
|
||||||
|
|
||||||
def ::(start: Long): SDIndex =
|
def ->(end: Long): SDIndex =
|
||||||
SDIndex.interval(start, end)
|
SDIndex.interval(start, end)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,4 +51,12 @@ object Implicits {
|
||||||
|
|
||||||
implicit def LongToPoint(x: Long): SDIndex =
|
implicit def LongToPoint(x: Long): SDIndex =
|
||||||
SDIndex.point(x)
|
SDIndex.point(x)
|
||||||
|
|
||||||
|
implicit def IntRangeToWrapper(start: Int): SDIndexWrapper = {
|
||||||
|
val result = new SDIndexWrapper(start)
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
implicit def IntToPoint(x: Int): SDIndex =
|
||||||
|
SDIndex.point(x)
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,6 +48,24 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
||||||
assert(extracted == expected)
|
assert(extracted == expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*it should "be able to extract a part of 2d matrix with alternative syntax" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1, 2, 3),
|
||||||
|
Array(4, 5, 6),
|
||||||
|
Array(7, 8, 9)
|
||||||
|
).mkNDArray(ordering)
|
||||||
|
|
||||||
|
val extracted = ndArray(1 :: 3, 0 :: 2)
|
||||||
|
|
||||||
|
val expected =
|
||||||
|
Array(
|
||||||
|
Array(4, 5),
|
||||||
|
Array(7, 8)
|
||||||
|
).mkNDArray(ordering)
|
||||||
|
assert(extracted == expected)
|
||||||
|
}*/
|
||||||
|
|
||||||
it should "be able to extract a part of 2d matrix with double data" in {
|
it should "be able to extract a part of 2d matrix with double data" in {
|
||||||
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
|
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
|
||||||
|
|
||||||
|
@ -171,6 +189,9 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
||||||
|
|
||||||
val ellipsised = ndArray(--->)
|
val ellipsised = ndArray(--->)
|
||||||
assert(ellipsised == ndArray)
|
assert(ellipsised == ndArray)
|
||||||
|
|
||||||
|
val ellipsised1 = ndArray(---)
|
||||||
|
assert(ellipsised1 == ndArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
it should "accept partially ellipsis indices" in {
|
it should "accept partially ellipsis indices" in {
|
||||||
|
|
|
@ -222,7 +222,7 @@ class MathTest extends FlatSpec with Matchers {
|
||||||
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
|
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
|
||||||
|
|
||||||
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
|
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
|
||||||
val slice2 = x(0 :: 2, ---).eval
|
val slice2 = x(0 -> 2, ---).eval
|
||||||
slice1 shouldBe slice2
|
slice1 shouldBe slice2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -237,10 +237,10 @@ class MathTest extends FlatSpec with Matchers {
|
||||||
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
|
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
|
||||||
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
|
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
|
||||||
|
|
||||||
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval
|
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 -> 2, 0, 0).eval
|
||||||
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2,
|
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 -> 2,
|
||||||
0 :: 1,
|
0 -> 1,
|
||||||
0 :: 2).eval
|
0 -> 2).eval
|
||||||
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval
|
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 -> 2, 0 -> 1, ---).eval
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue