Tests added
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
This commit is contained in:
		
							parent
							
								
									b36b3fa1aa
								
							
						
					
					
						commit
						4956048a11
					
				| @ -54,9 +54,9 @@ class SameDiffWrapper { | |||||||
|     sd.placeHolder(name, dataType, shape: _*) |     sd.placeHolder(name, dataType, shape: _*) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| case class SDIndexWrapper(start: Long) { | case class SDIndexWrapper(end: Long) { | ||||||
| 
 | 
 | ||||||
|   def ::(end: Long): SDIndex = |   def ::(start: Long): SDIndex = | ||||||
|     SDIndex.interval(start, end) |     SDIndex.interval(start, end) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -75,18 +75,12 @@ class SDVariableWrapper { | |||||||
|     thisVariable = variable |     thisVariable = variable | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   // Indexing | ||||||
|   def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index)) |   def apply(index: Long): SDVariable = thisVariable.get(SDIndex.point(index)) | ||||||
| 
 | 
 | ||||||
|   def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*) |   def apply(index: SDIndex*): SDVariable = thisVariable.get(index: _*) | ||||||
| 
 | 
 | ||||||
|   /*def apply(x: SDIndex, y: SDIndex): SDVariable = |   // Arithmetic | ||||||
|     (x, y) match { |  | ||||||
|       case (_, y) => thisVariable.get(SDIndex.all(), y) |  | ||||||
|       case (x, _) => thisVariable.get(x, SDIndex.all()) |  | ||||||
|       case (_, _) => thisVariable.get(SDIndex.all(), SDIndex.all()) |  | ||||||
|       case (x, y) => thisVariable.get(x, y) |  | ||||||
|     }*/ |  | ||||||
| 
 |  | ||||||
|   def add(other: Double): Unit = thisVariable.add(other) |   def add(other: Double): Unit = thisVariable.add(other) | ||||||
| 
 | 
 | ||||||
|   def *(other: SDVariable): SDVariable = |   def *(other: SDVariable): SDVariable = | ||||||
|  | |||||||
| @ -48,4 +48,7 @@ object Implicits { | |||||||
|     val result = new SDIndexWrapper(start) |     val result = new SDIndexWrapper(start) | ||||||
|     result |     result | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   implicit def LongToPoint(x: Long): SDIndex = | ||||||
|  |     SDIndex.point(x) | ||||||
| } | } | ||||||
|  | |||||||
| @ -219,16 +219,19 @@ class MathTest extends FlatSpec with Matchers { | |||||||
| 
 | 
 | ||||||
|     val x = sd.bind(arr) |     val x = sd.bind(arr) | ||||||
| 
 | 
 | ||||||
|     println(x(SDIndex.point(0), ---).getArr) |     x(0, ---).eval shouldBe x(SDIndex.point(0), SDIndex.all()).eval | ||||||
| 
 | 
 | ||||||
|     val data1 = x(SDIndex.interval(0: Long, 2: Long), ---).getArr |     val slice1 = x.get(SDIndex.interval(0, 2), SDIndex.all()).eval | ||||||
|     println(data1) |     val slice2 = x(0 :: 2, ---).eval | ||||||
|  |     slice1 shouldBe slice2 | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|     val data2 = x(0 :: 2, ---).getArr |   "SDVariable " should "be indexable in 3d" in { | ||||||
|     println(data2) |     implicit val sd = SameDiff.create | ||||||
| 
 | 
 | ||||||
|     //assert(indices.indices == List(0, 1, 2, 3, 4, 5)) |     val arr = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 18).reshape(3, 3, 2) | ||||||
|     //assert(indices.targetShape.toList == List(2, 3)) |     val x = sd.bind(arr) | ||||||
| 
 | 
 | ||||||
|  |     println(x(0,0,SDIndex.all()).eval) | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user