C++ NPY (#233)
* import .npy files in C++ Signed-off-by: raver119 <raver119@gmail.com> * reuse existing method Signed-off-by: raver119 <raver119@gmail.com> * add CPU_FEATURES to static lib Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									c9ffb6cbec
								
							
						
					
					
						commit
						f3fa4fd632
					
				@ -336,7 +336,7 @@ elseif(CPU_BLAS)
 | 
				
			|||||||
    if ("${LIBND4J_ALL_OPS}" AND "${LIBND4J_BUILD_MINIFIER}")
 | 
					    if ("${LIBND4J_ALL_OPS}" AND "${LIBND4J_BUILD_MINIFIER}")
 | 
				
			||||||
        message(STATUS "Building minifier...")
 | 
					        message(STATUS "Building minifier...")
 | 
				
			||||||
        add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
 | 
					        add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
 | 
				
			||||||
        target_link_libraries(minifier ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES})
 | 
					        target_link_libraries(minifier ${LIBND4J_NAME}static ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
 | 
				
			||||||
    endif()
 | 
					    endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
 | 
					    if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
 | 
				
			||||||
 | 
				
			|||||||
@ -108,6 +108,13 @@ namespace nd4j {
 | 
				
			|||||||
        template <typename T>
 | 
					        template <typename T>
 | 
				
			||||||
        static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
 | 
					        static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        /**
 | 
				
			||||||
 | 
					         * This method creates NDArray from .npy file
 | 
				
			||||||
 | 
					         * @param fileName
 | 
				
			||||||
 | 
					         * @return
 | 
				
			||||||
 | 
					         */
 | 
				
			||||||
 | 
					        static NDArray fromNpyFile(const char *fileName);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        /**
 | 
					        /**
 | 
				
			||||||
         * This factory create array from utf8 string
 | 
					         * This factory create array from utf8 string
 | 
				
			||||||
         * @return NDArray default dataType UTF8
 | 
					         * @return NDArray default dataType UTF8
 | 
				
			||||||
 | 
				
			|||||||
@ -24,11 +24,15 @@
 | 
				
			|||||||
#include <exceptions/cuda_exception.h>
 | 
					#include <exceptions/cuda_exception.h>
 | 
				
			||||||
#include <ConstantHelper.h>
 | 
					#include <ConstantHelper.h>
 | 
				
			||||||
#include <ConstantShapeHelper.h>
 | 
					#include <ConstantShapeHelper.h>
 | 
				
			||||||
 | 
					#include <GraphExecutioner.h>
 | 
				
			||||||
#include <ShapeUtils.h>
 | 
					#include <ShapeUtils.h>
 | 
				
			||||||
#include <type_traits>
 | 
					#include <type_traits>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <StringUtils.h>
 | 
					#include <StringUtils.h>
 | 
				
			||||||
 | 
					#include <NativeOps.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace nd4j {
 | 
					namespace nd4j {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -688,4 +692,27 @@ template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char
 | 
				
			|||||||
          return NDArray( shape, string, dtype, context);
 | 
					          return NDArray( shape, string, dtype, context);
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      NDArray NDArrayFactory::fromNpyFile(const char *fileName) {
 | 
				
			||||||
 | 
					          auto size = nd4j::graph::getFileSize(fileName);
 | 
				
			||||||
 | 
					          if (size < 0)
 | 
				
			||||||
 | 
					              throw std::runtime_error("File doesn't exit");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          auto pNPY = reinterpret_cast<char*>(::numpyFromFile(std::string(fileName)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          auto nBuffer = reinterpret_cast<void*>(::dataPointForNumpy(pNPY));
 | 
				
			||||||
 | 
					          auto shape = reinterpret_cast<Nd4jLong *>(::shapeBufferForNumpy(pNPY));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          auto length = shape::length(shape);
 | 
				
			||||||
 | 
					          int8_t *buffer = nullptr;
 | 
				
			||||||
 | 
					          nd4j::memory::Workspace *workspace = nullptr;
 | 
				
			||||||
 | 
					          auto byteLen = length * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          ALLOCATE(buffer, workspace, byteLen, int8_t);
 | 
				
			||||||
 | 
					          memcpy(buffer, nBuffer, byteLen);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          free(pNPY);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          return NDArray(buffer, shape, LaunchContext::defaultContext(), true);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1295,3 +1295,13 @@ TEST_F(NDArrayTest2, test_subarray_followed_by_reshape_1) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    ASSERT_EQ(e, r);
 | 
					    ASSERT_EQ(e, r);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TEST_F(NDArrayTest2, test_numpy_import_1) {
 | 
				
			||||||
 | 
					    std::string fname("./resources/arr_3,4_float32.npy");
 | 
				
			||||||
 | 
					    auto exp = NDArrayFactory::create<float>('c', {3, 4});
 | 
				
			||||||
 | 
					    exp.linspace(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto array = NDArrayFactory::fromNpyFile(fname.c_str());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ASSERT_EQ(exp, array);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										
											BIN
										
									
								
								libnd4j/tests_cpu/resources/arr_3,4_float32.npy
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								libnd4j/tests_cpu/resources/arr_3,4_float32.npy
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user