commit
fa8105bf0f
|
@ -276,6 +276,10 @@ object Implicits {
|
|||
override def hasNegative: Boolean = false
|
||||
}
|
||||
|
||||
case object --- extends IndexRange {
|
||||
override def hasNegative: Boolean = false
|
||||
}
|
||||
|
||||
implicit class IntRange(val underlying: Int) extends IndexNumberRange {
|
||||
protected[nd4s] override def asRange(max: => Int): DRange =
|
||||
DRange(underlying, underlying, true, 1, max)
|
||||
|
@ -307,6 +311,10 @@ object Implicits {
|
|||
def -> = IntRangeFrom(underlying)
|
||||
}
|
||||
|
||||
implicit class IntRangeFromGen1(val underlying: Int) extends AnyVal {
|
||||
def :: = IntRangeFromReverse(underlying)
|
||||
}
|
||||
|
||||
implicit class IndexRangeWrapper(val underlying: Range) extends IndexNumberRange {
|
||||
protected[nd4s] override def asRange(max: => Int): DRange =
|
||||
DRange.from(underlying, max)
|
||||
|
@ -377,17 +385,27 @@ object IndexNumberRange {
|
|||
val endExclusive = if (endR >= 0) endR + diff else max + endR + diff
|
||||
(start, endExclusive)
|
||||
}
|
||||
|
||||
NDArrayIndex.interval(start, step, end, false)
|
||||
}
|
||||
}
|
||||
|
||||
sealed trait IndexRange {
|
||||
/*sealed*/
|
||||
trait IndexRange {
|
||||
def hasNegative: Boolean
|
||||
}
|
||||
|
||||
case class IntRangeFrom(underlying: Int) extends IndexRange {
|
||||
def apply[T](a: T): (Int, T) = (underlying, a)
|
||||
def apply[T](a: T): (Int, T) =
|
||||
(underlying, a)
|
||||
|
||||
override def toString: String = s"$underlying->"
|
||||
|
||||
override def hasNegative: Boolean = false
|
||||
}
|
||||
|
||||
case class IntRangeFromReverse(underlying: Int) extends IndexRange {
|
||||
def apply[T](a: T): (T, Int) =
|
||||
(a, underlying)
|
||||
|
||||
override def toString: String = s"$underlying->"
|
||||
|
||||
|
|
|
@ -23,6 +23,12 @@ import org.slf4j.LoggerFactory
|
|||
|
||||
import _root_.scala.annotation.tailrec
|
||||
|
||||
package object ops {
|
||||
case object :: extends IndexRange {
|
||||
override def hasNegative: Boolean = false
|
||||
}
|
||||
}
|
||||
|
||||
trait SliceableNDArray[A <: INDArray] {
|
||||
lazy val log = LoggerFactory.getLogger(classOf[SliceableNDArray[A]])
|
||||
val underlying: A
|
||||
|
@ -68,6 +74,8 @@ trait SliceableNDArray[A <: INDArray] {
|
|||
|
||||
@tailrec
|
||||
def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[DRange]): List[DRange] = input match {
|
||||
case ops.:: :: t =>
|
||||
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
|
||||
case -> :: t =>
|
||||
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
|
||||
case ---> :: t =>
|
||||
|
@ -137,6 +145,9 @@ trait SliceableNDArray[A <: INDArray] {
|
|||
case ---> :: t =>
|
||||
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
|
||||
modifyTargetIndices(ellipsised ::: t, i, acc)
|
||||
case --- :: t =>
|
||||
val ellipsised = List.fill(originalShape.length - i - t.size)(->)
|
||||
modifyTargetIndices(ellipsised ::: t, i, acc)
|
||||
case IntRangeFrom(from: Int) :: t =>
|
||||
val max = originalShape(i)
|
||||
modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc)
|
||||
|
|
|
@ -60,6 +60,12 @@ case class SDIndexWrapper(end: Long) {
|
|||
SDIndex.interval(start, end)
|
||||
}
|
||||
|
||||
case class SDIndexWrapper1(start: Int) {
|
||||
|
||||
def ::(end: Int): SDIndex =
|
||||
SDIndex.interval(start, end)
|
||||
}
|
||||
|
||||
object --- extends SDIndex {
|
||||
val thisIndex: SDIndex = SDIndex.all()
|
||||
}
|
||||
|
|
|
@ -51,4 +51,12 @@ object Implicits {
|
|||
|
||||
implicit def LongToPoint(x: Long): SDIndex =
|
||||
SDIndex.point(x)
|
||||
|
||||
implicit def IntRangeToWrapper(start: Int): SDIndexWrapper = {
|
||||
val result = new SDIndexWrapper(start)
|
||||
result
|
||||
}
|
||||
|
||||
implicit def IntToPoint(x: Int): SDIndex =
|
||||
SDIndex.point(x)
|
||||
}
|
||||
|
|
|
@ -48,6 +48,42 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
|||
assert(extracted == expected)
|
||||
}
|
||||
|
||||
it should "be able to extract a part of 2d matrix with alternative syntax" in {
|
||||
val ndArray =
|
||||
Array(
|
||||
Array(1, 2, 3),
|
||||
Array(4, 5, 6),
|
||||
Array(7, 8, 9)
|
||||
).mkNDArray(ordering)
|
||||
|
||||
val extracted = ndArray(1 :: 3, 0 :: 2)
|
||||
|
||||
val expected =
|
||||
Array(
|
||||
Array(4, 5),
|
||||
Array(7, 8)
|
||||
).mkNDArray(ordering)
|
||||
assert(extracted == expected)
|
||||
}
|
||||
|
||||
it should "be able to extract a part of 2d matrix with mixed syntax" in {
|
||||
val ndArray =
|
||||
Array(
|
||||
Array(1, 2, 3),
|
||||
Array(4, 5, 6),
|
||||
Array(7, 8, 9)
|
||||
).mkNDArray(ordering)
|
||||
|
||||
val extracted = ndArray(1 -> 3, 0 :: 2)
|
||||
|
||||
val expected =
|
||||
Array(
|
||||
Array(4, 5),
|
||||
Array(7, 8)
|
||||
).mkNDArray(ordering)
|
||||
assert(extracted == expected)
|
||||
}
|
||||
|
||||
it should "be able to extract a part of 2d matrix with double data" in {
|
||||
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
|
||||
|
||||
|
@ -171,6 +207,9 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
|||
|
||||
val ellipsised = ndArray(--->)
|
||||
assert(ellipsised == ndArray)
|
||||
|
||||
val ellipsised1 = ndArray(---)
|
||||
assert(ellipsised1 == ndArray)
|
||||
}
|
||||
|
||||
it should "accept partially ellipsis indices" in {
|
||||
|
|
Loading…
Reference in New Issue