Indexing syntax changed
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
This commit is contained in:
		
							parent
							
								
									f515151f5f
								
							
						
					
					
						commit
						0de8523ebd
					
				| @ -276,6 +276,10 @@ object Implicits { | ||||
|     override def hasNegative: Boolean = false | ||||
|   } | ||||
| 
 | ||||
|   case object --- extends IndexRange { | ||||
|     override def hasNegative: Boolean = false | ||||
|   } | ||||
| 
 | ||||
|   implicit class IntRange(val underlying: Int) extends IndexNumberRange { | ||||
|     protected[nd4s] override def asRange(max: => Int): DRange = | ||||
|       DRange(underlying, underlying, true, 1, max) | ||||
|  | ||||
| @ -137,6 +137,9 @@ trait SliceableNDArray[A <: INDArray] { | ||||
|         case ---> :: t => | ||||
|           val ellipsised = List.fill(originalShape.length - i - t.size)(->) | ||||
|           modifyTargetIndices(ellipsised ::: t, i, acc) | ||||
|         case --- :: t => | ||||
|           val ellipsised = List.fill(originalShape.length - i - t.size)(->) | ||||
|           modifyTargetIndices(ellipsised ::: t, i, acc) | ||||
|         case IntRangeFrom(from: Int) :: t => | ||||
|           val max = originalShape(i) | ||||
|           modifyTargetIndices(t, i + 1, IndexNumberRange.toNDArrayIndex(from, max, false, 1, max) :: acc) | ||||
|  | ||||
| @ -54,9 +54,9 @@ class SameDiffWrapper { | ||||
|     sd.placeHolder(name, dataType, shape: _*) | ||||
| } | ||||
| 
 | ||||
| case class SDIndexWrapper(end: Long) { | ||||
| case class SDIndexWrapper(start: Long) { | ||||
| 
 | ||||
|   def ::(start: Long): SDIndex = | ||||
|   def ->(end: Long): SDIndex = | ||||
|     SDIndex.interval(start, end) | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -51,4 +51,12 @@ object Implicits { | ||||
| 
 | ||||
|   implicit def LongToPoint(x: Long): SDIndex = | ||||
|     SDIndex.point(x) | ||||
| 
 | ||||
|   implicit def IntRangeToWrapper(start: Int): SDIndexWrapper = { | ||||
|     val result = new SDIndexWrapper(start) | ||||
|     result | ||||
|   } | ||||
| 
 | ||||
|   implicit def IntToPoint(x: Int): SDIndex = | ||||
|     SDIndex.point(x) | ||||
| } | ||||
|  | ||||
| @ -48,6 +48,24 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest => | ||||
|     assert(extracted == expected) | ||||
|   } | ||||
| 
 | ||||
|   /*it should "be able to extract a part of 2d matrix with alternative syntax" in { | ||||
|     val ndArray = | ||||
|       Array( | ||||
|         Array(1, 2, 3), | ||||
|         Array(4, 5, 6), | ||||
|         Array(7, 8, 9) | ||||
|       ).mkNDArray(ordering) | ||||
| 
 | ||||
|     val extracted = ndArray(1 :: 3, 0 :: 2) | ||||
| 
 | ||||
|     val expected = | ||||
|       Array( | ||||
|         Array(4, 5), | ||||
|         Array(7, 8) | ||||
|       ).mkNDArray(ordering) | ||||
|     assert(extracted == expected) | ||||
|   }*/ | ||||
| 
 | ||||
|   it should "be able to extract a part of 2d matrix with double data" in { | ||||
|     val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C) | ||||
| 
 | ||||
| @ -171,6 +189,9 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest => | ||||
| 
 | ||||
|     val ellipsised = ndArray(--->) | ||||
|     assert(ellipsised == ndArray) | ||||
| 
 | ||||
|     val ellipsised1 = ndArray(---) | ||||
|     assert(ellipsised1 == ndArray) | ||||
|   } | ||||
| 
 | ||||
|   it should "accept partially ellipsis indices" in { | ||||
|  | ||||
| @ -222,7 +222,7 @@ class MathTest extends FlatSpec with Matchers { | ||||
|     x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval | ||||
| 
 | ||||
|     val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval | ||||
|     val slice2 = x(0 :: 2, ---).eval | ||||
|     val slice2 = x(0 -> 2, ---).eval | ||||
|     slice1 shouldBe slice2 | ||||
|   } | ||||
| 
 | ||||
| @ -237,10 +237,10 @@ class MathTest extends FlatSpec with Matchers { | ||||
|     x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.all()).eval shouldBe x(0, 0, ---).eval | ||||
|     x.get(SDIndex.point(0), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0, 0, 0).eval | ||||
| 
 | ||||
|     x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 :: 2, 0, 0).eval | ||||
|     x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 :: 2, | ||||
|                                                                                                   0 :: 1, | ||||
|                                                                                                   0 :: 2).eval | ||||
|     x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 :: 2, 0 :: 1, ---).eval | ||||
|     x.get(SDIndex.interval(0, 2), SDIndex.point(0), SDIndex.point(0)).eval shouldBe x(0 -> 2, 0, 0).eval | ||||
|     x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.interval(0, 2)).eval shouldBe x(0 -> 2, | ||||
|                                                                                                   0 -> 1, | ||||
|                                                                                                   0 -> 2).eval | ||||
|     x.get(SDIndex.interval(0, 2), SDIndex.interval(0, 1), SDIndex.all()).eval shouldBe x(0 -> 2, 0 -> 1, ---).eval | ||||
|   } | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user