Tests added
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
This commit is contained in:
		
							parent
							
								
									b36b3fa1aa
								
							
						
					
					
						commit
						4956048a11
					
				| @ -54,9 +54,9 @@ class SameDiffWrapper { | ||||
|     sd.placeHolder(name, dataType, shape: _*) | ||||
| } | ||||
| 
 | ||||
| case class SDIndexWrapper(start: Long) { | ||||
| case class SDIndexWrapper(end: Long) { | ||||
| 
 | ||||
|   def ::(end: Long): SDIndex = | ||||
|   def ::(start: Long): SDIndex = | ||||
|     SDIndex.interval(start, end) | ||||
| } | ||||
| 
 | ||||
| @ -75,18 +75,12 @@ class SDVariableWrapper { | ||||
|     thisVariable = variable | ||||
|   } | ||||
| 
 | ||||
|   // Indexing | ||||
|   def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index)) | ||||
| 
 | ||||
|   def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*) | ||||
| 
 | ||||
|   /*def apply(x: SDIndex, y: SDIndex): SDVariable = | ||||
|     (x, y) match { | ||||
|       case (_, y) => thisVariable.get(SDIndex.all(), y) | ||||
|       case (x, _) => thisVariable.get(x, SDIndex.all()) | ||||
|       case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all()) | ||||
|       case (x, y) => thisVariable.get(x, y) | ||||
|     }*/ | ||||
| 
 | ||||
|   // Arithmetic | ||||
|   def add(other: Double): Unit = thisVariable.add(other) | ||||
| 
 | ||||
|   def *(other: SDVariable): SDVariable = | ||||
|  | ||||
| @ -48,4 +48,7 @@ object Implicits { | ||||
|     val result = new SDIndexWrapper(start) | ||||
|     result | ||||
|   } | ||||
| 
 | ||||
|   implicit def LongToPoint(x: Long): SDIndex = | ||||
|     SDIndex.point(x) | ||||
| } | ||||
|  | ||||
| @ -219,16 +219,19 @@ class MathTest extends FlatSpec with Matchers { | ||||
| 
 | ||||
|     val x = sd.bind(arr) | ||||
| 
 | ||||
|     println(x(SDIndex.point(0), ---).getArr) | ||||
|     x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval | ||||
| 
 | ||||
|     val data1 = x(SDIndex.interval(0: Long, 2: Long), ---).getArr | ||||
|     println(data1) | ||||
|     val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval | ||||
|     val slice2 = x(0 :: 2, ---).eval | ||||
|     slice1 shouldBe slice2 | ||||
|   } | ||||
| 
 | ||||
|     val data2 = x(0 :: 2, ---).getArr | ||||
|     println(data2) | ||||
|   "SDVariable " should "be indexable in 3d" in { | ||||
|     implicit val sd = SameDiff.create | ||||
| 
 | ||||
|     //assert(indices.indices == List(0, 1, 2, 3, 4, 5)) | ||||
|     //assert(indices.targetShape.toList == List(2, 3)) | ||||
|     val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2) | ||||
|     val x = sd.bind(arr) | ||||
| 
 | ||||
|     println(x(0,0,SDIndex.all()).eval) | ||||
|   } | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user