Operator :: for INDArray
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
0de8523ebd
commit
be70bea359
|
@ -311,6 +311,10 @@ object Implicits {
|
||||||
def -> = IntRangeFrom(underlying)
|
def -> = IntRangeFrom(underlying)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
implicit class IntRangeFromGen1(val underlying: Int) extends AnyVal {
|
||||||
|
def :: = IntRangeFromReverse(underlying)
|
||||||
|
}
|
||||||
|
|
||||||
implicit class IndexRangeWrapper(val underlying: Range) extends IndexNumberRange {
|
implicit class IndexRangeWrapper(val underlying: Range) extends IndexNumberRange {
|
||||||
protected[nd4s] override def asRange(max: => Int): DRange =
|
protected[nd4s] override def asRange(max: => Int): DRange =
|
||||||
DRange.from(underlying, max)
|
DRange.from(underlying, max)
|
||||||
|
@ -381,17 +385,27 @@ object IndexNumberRange {
|
||||||
val endExclusive = if (endR >= 0) endR + diff else max + endR + diff
|
val endExclusive = if (endR >= 0) endR + diff else max + endR + diff
|
||||||
(start, endExclusive)
|
(start, endExclusive)
|
||||||
}
|
}
|
||||||
|
|
||||||
NDArrayIndex.interval(start, step, end, false)
|
NDArrayIndex.interval(start, step, end, false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sealed trait IndexRange {
|
/*sealed*/
|
||||||
|
trait IndexRange {
|
||||||
def hasNegative: Boolean
|
def hasNegative: Boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
case class IntRangeFrom(underlying: Int) extends IndexRange {
|
case class IntRangeFrom(underlying: Int) extends IndexRange {
|
||||||
def apply[T](a: T): (Int, T) = (underlying, a)
|
def apply[T](a: T): (Int, T) =
|
||||||
|
(underlying, a)
|
||||||
|
|
||||||
|
override def toString: String = s"$underlying->"
|
||||||
|
|
||||||
|
override def hasNegative: Boolean = false
|
||||||
|
}
|
||||||
|
|
||||||
|
case class IntRangeFromReverse(underlying: Int) extends IndexRange {
|
||||||
|
def apply[T](a: T): (T, Int) =
|
||||||
|
(a, underlying)
|
||||||
|
|
||||||
override def toString: String = s"$underlying->"
|
override def toString: String = s"$underlying->"
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,12 @@ import org.slf4j.LoggerFactory
|
||||||
|
|
||||||
import _root_.scala.annotation.tailrec
|
import _root_.scala.annotation.tailrec
|
||||||
|
|
||||||
|
package object ops {
|
||||||
|
case object :: extends IndexRange {
|
||||||
|
override def hasNegative: Boolean = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
trait SliceableNDArray[A <: INDArray] {
|
trait SliceableNDArray[A <: INDArray] {
|
||||||
lazy val log = LoggerFactory.getLogger(classOf[SliceableNDArray[A]])
|
lazy val log = LoggerFactory.getLogger(classOf[SliceableNDArray[A]])
|
||||||
val underlying: A
|
val underlying: A
|
||||||
|
@ -68,6 +74,8 @@ trait SliceableNDArray[A <: INDArray] {
|
||||||
|
|
||||||
@tailrec
|
@tailrec
|
||||||
def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[DRange]): List[DRange] = input match {
|
def modifyTargetIndices(input: List[IndexRange], i: Int, acc: List[DRange]): List[DRange] = input match {
|
||||||
|
case ops.:: :: t =>
|
||||||
|
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
|
||||||
case -> :: t =>
|
case -> :: t =>
|
||||||
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
|
modifyTargetIndices(t, i + 1, DRange(0, originalShape(i), 1) :: acc)
|
||||||
case ---> :: t =>
|
case ---> :: t =>
|
||||||
|
|
|
@ -54,9 +54,15 @@ class SameDiffWrapper {
|
||||||
sd.placeHolder(name, dataType, shape: _*)
|
sd.placeHolder(name, dataType, shape: _*)
|
||||||
}
|
}
|
||||||
|
|
||||||
case class SDIndexWrapper(start: Long) {
|
case class SDIndexWrapper(end: Long) {
|
||||||
|
|
||||||
def ->(end: Long): SDIndex =
|
def ::(start: Long): SDIndex =
|
||||||
|
SDIndex.interval(start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
case class SDIndexWrapper1(start: Int) {
|
||||||
|
|
||||||
|
def ::(end: Int): SDIndex =
|
||||||
SDIndex.interval(start, end)
|
SDIndex.interval(start, end)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ package org.nd4s
|
||||||
|
|
||||||
import org.nd4s.Implicits._
|
import org.nd4s.Implicits._
|
||||||
import org.scalatest.FlatSpec
|
import org.scalatest.FlatSpec
|
||||||
|
import org.nd4s.ops.::
|
||||||
|
|
||||||
class NDArrayExtractionInCOrderingTest extends NDArrayExtractionTestBase with COrderingForTest
|
class NDArrayExtractionInCOrderingTest extends NDArrayExtractionTestBase with COrderingForTest
|
||||||
class NDArrayExtractionInFortranOrderingTest extends NDArrayExtractionTestBase with FortranOrderingForTest
|
class NDArrayExtractionInFortranOrderingTest extends NDArrayExtractionTestBase with FortranOrderingForTest
|
||||||
|
@ -48,7 +49,7 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
||||||
assert(extracted == expected)
|
assert(extracted == expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
/*it should "be able to extract a part of 2d matrix with alternative syntax" in {
|
it should "be able to extract a part of 2d matrix with alternative syntax" in {
|
||||||
val ndArray =
|
val ndArray =
|
||||||
Array(
|
Array(
|
||||||
Array(1, 2, 3),
|
Array(1, 2, 3),
|
||||||
|
@ -64,7 +65,25 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
||||||
Array(7, 8)
|
Array(7, 8)
|
||||||
).mkNDArray(ordering)
|
).mkNDArray(ordering)
|
||||||
assert(extracted == expected)
|
assert(extracted == expected)
|
||||||
}*/
|
}
|
||||||
|
|
||||||
|
it should "be able to extract a part of 2d matrix with mixed syntax" in {
|
||||||
|
val ndArray =
|
||||||
|
Array(
|
||||||
|
Array(1, 2, 3),
|
||||||
|
Array(4, 5, 6),
|
||||||
|
Array(7, 8, 9)
|
||||||
|
).mkNDArray(ordering)
|
||||||
|
|
||||||
|
val extracted = ndArray(1 -> 3, 0 :: 2)
|
||||||
|
|
||||||
|
val expected =
|
||||||
|
Array(
|
||||||
|
Array(4, 5),
|
||||||
|
Array(7, 8)
|
||||||
|
).mkNDArray(ordering)
|
||||||
|
assert(extracted == expected)
|
||||||
|
}
|
||||||
|
|
||||||
it should "be able to extract a part of 2d matrix with double data" in {
|
it should "be able to extract a part of 2d matrix with double data" in {
|
||||||
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
|
val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
|
||||||
|
|
|
@ -222,7 +222,7 @@ class MathTest extends FlatSpec with Matchers {
|
||||||
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
|
x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval
|
||||||
|
|
||||||
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
|
val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval
|
||||||
val slice2 = x(0 -> 2, ---).eval
|
val slice2 = x(0 :: 2, ---).eval
|
||||||
slice1 shouldBe slice2
|
slice1 shouldBe slice2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -237,10 +237,10 @@ class MathTest extends FlatSpec with Matchers {
|
||||||
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
|
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval
|
||||||
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
|
x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval
|
||||||
|
|
||||||
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 -> 2, 0, 0).eval
|
x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval
|
||||||
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 -> 2,
|
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2,
|
||||||
0 -> 1,
|
0 :: 1,
|
||||||
0 -> 2).eval
|
0 :: 2).eval
|
||||||
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 -> 2, 0 -> 1, ---).eval
|
x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue