reshape tweak (#275)
* - expand dims tweak - reshape memcpy Signed-off-by: raver119 <raver119@gmail.com> * validation fix Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									b686368b82
								
							
						
					
					
						commit
						5c806d2fb5
					
				| @ -1229,7 +1229,7 @@ | ||||
| 
 | ||||
| /// graph definitions
 | ||||
| #define REQUIRE_OK(A)  if (nd4j::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION; | ||||
| #define REQUIRE_TRUE(...) if (nd4j::ops::conditionHelper(__FILE__, __LINE__, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed"); | ||||
| #define REQUIRE_TRUE(COND, ...) if (!(COND)) { if (nd4j::ops::conditionHelper(__FILE__, __LINE__, COND, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed");}; | ||||
| 
 | ||||
| #define DECLARE_ENTRY(NAME, ...)           template struct ND4J_EXPORT __registratorFloat<NAME<float>>; \ | ||||
|                                       template struct ND4J_EXPORT __registratorHalf<NAME<float16>>; \ | ||||
|  | ||||
| @ -41,9 +41,10 @@ namespace nd4j { | ||||
| 
 | ||||
|             REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf()+1, 0, "ExpandDims: axis should be in range of 0...%i in this case, but got %i instead", input->rankOf() + 1, axis); | ||||
| 
 | ||||
|             std::vector<Nd4jLong> shape; | ||||
|             std::vector<Nd4jLong> shape(input->rankOf()); | ||||
| 
 | ||||
|             for(int e = 0; e < input->rankOf(); e++) | ||||
|                 shape.emplace_back(input->sizeAt(e)); | ||||
|                 shape[input->sizeAt(e)]; | ||||
| 
 | ||||
|             shape.insert(shape.begin() + axis, 1); | ||||
| 
 | ||||
|  | ||||
| @ -122,13 +122,18 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { | ||||
|             nd4j_printv("Reshape: new shape", shapeNew); | ||||
|         } | ||||
| 
 | ||||
|         if (s->isEmpty()) { | ||||
|         if (s->isScalar()) { | ||||
|             // just a scalar
 | ||||
|             z->assign(x); | ||||
|         } else { | ||||
|             // in some cases we might go away with simple memcpy call instead of assign call
 | ||||
|             if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) { | ||||
|                 z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset()); | ||||
|             } else { | ||||
|                 auto xr = x->reshape(order, shapeNew); | ||||
|                 z->assign(xr); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return Status::OK(); | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user