* Implementation for non_max_suppression_v3 was added. Initial version. * Added check for exceeding the threshold. * Added definition for V3 method. * Java remapping for NonMaxSuppressionV3 Signed-off-by: raver119 <raver119@gmail.com> * Fixed processing of an empty output, and its test. * Refactored op to less threshold data to float. * Implemented CUDA-based helper for non_max_suppression_v3 op. * Fixed fake_quant_with_min_max_vars op. * Fixed tests with float numbers. * - assert now stops execution - sortByKey/sortByValue now have input validation Signed-off-by: raver119 <raver119@gmail.com> * missing var Signed-off-by: raver119 <raver119@gmail.com> * Fixed proper processing for zero max_size inputs. * Refactored kernel callers. * Fixed return statement for logdet op helper. * Refactored unsorted segment SqrtN op. * Restored 8 tail bytes on CUDA Signed-off-by: raver119 <raver119@gmail.com> * Refactored segment prod ops and helpers for CUDA and tests. * Additional test. * CudaWorkspace tests updated for 8 tail bytes Signed-off-by: raver119 <raver119@gmail.com> * special atomic test Signed-off-by: raver119 <raver119@gmail.com> * atomicMul/atomicDiv fix for 16-bit values Signed-off-by: raver119 <raver119@gmail.com> * Eliminated unnecessary prints.
		
			
				
	
	
		
			60 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			60 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| /*******************************************************************************
 | |
|  * Copyright (c) 2015-2018 Skymind, Inc.
 | |
|  *
 | |
|  * This program and the accompanying materials are made available under the
 | |
|  * terms of the Apache License, Version 2.0 which is available at
 | |
|  * https://www.apache.org/licenses/LICENSE-2.0.
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | |
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | |
|  * License for the specific language governing permissions and limitations
 | |
|  * under the License.
 | |
|  *
 | |
|  * SPDX-License-Identifier: Apache-2.0
 | |
|  ******************************************************************************/
 | |
| 
 | |
| //
 | |
| // @author raver119@gmail.com
 | |
| //
 | |
| 
 | |
| #include "testlayers.h"
 | |
| #include <NDArray.h>
 | |
| #include <Workspace.h>
 | |
| #include <MemoryRegistrator.h>
 | |
| #include <MmulHelper.h>
 | |
| 
 | |
| using namespace nd4j;
 | |
| using namespace nd4j::memory;
 | |
| 
 | |
| class CudaWorkspaceTests : public testing::Test {
 | |
| 
 | |
| };
 | |
| 
 | |
| TEST_F(CudaWorkspaceTests, Basic_Tests_1) {
 | |
|     Workspace workspace(65536, 65536);
 | |
| 
 | |
|     ASSERT_EQ(0, workspace.getCurrentOffset());
 | |
|     LaunchContext ctx;
 | |
|     ctx.setWorkspace(&workspace);
 | |
|     auto array = NDArrayFactory::create<float>('c', {5, 5}, &ctx);
 | |
| 
 | |
|     ASSERT_EQ(108, workspace.getCurrentOffset());
 | |
|     ASSERT_EQ(0, workspace.getCurrentSecondaryOffset());
 | |
| 
 | |
|     array.e<int>(0);
 | |
| 
 | |
|     ASSERT_EQ(100, workspace.getCurrentSecondaryOffset());
 | |
| }
 | |
| 
 | |
| TEST_F(CudaWorkspaceTests, Basic_Tests_2) {
 | |
|     Workspace workspace(65536, 65536);
 | |
| 
 | |
|     ASSERT_EQ(0, workspace.getCurrentOffset());
 | |
|     LaunchContext ctx;
 | |
|     ctx.setWorkspace(&workspace);
 | |
|     auto array = NDArrayFactory::create<float>('c', {5, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, &ctx);
 | |
| 
 | |
|     ASSERT_EQ(108, workspace.getCurrentOffset());
 | |
|     ASSERT_EQ(0, workspace.getCurrentSecondaryOffset());
 | |
| } |