[WIP] Random Uniform (#36)
* args Signed-off-by: raver119@gmail.com <raver119@gmail.com> * T args Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									24980efde3
								
							
						
					
					
						commit
						51f3a1371d
					
				@ -44,13 +44,27 @@ namespace nd4j {
 | 
				
			|||||||
            if (block.getIArguments()->size())
 | 
					            if (block.getIArguments()->size())
 | 
				
			||||||
                dtype = (DataType)INT_ARG(0);
 | 
					                dtype = (DataType)INT_ARG(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            auto min = block.width() > 1?INPUT_VARIABLE(1):(NDArray*)nullptr;
 | 
					            auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr;
 | 
				
			||||||
            auto max = block.width() > 2?INPUT_VARIABLE(2):(NDArray*)nullptr;
 | 
					            auto max = block.width() > 2 ? INPUT_VARIABLE(2) : (NDArray*) nullptr;
 | 
				
			||||||
 | 
					            bool disposable = false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (min == nullptr && max == nullptr && block.numT() >= 2) {
 | 
				
			||||||
 | 
					                min = NDArrayFactory::create_('c', {}, dtype);
 | 
				
			||||||
 | 
					                max = NDArrayFactory::create_('c', {}, dtype);
 | 
				
			||||||
 | 
					                min->assign(T_ARG(0));
 | 
				
			||||||
 | 
					                max->assign(T_ARG(1));
 | 
				
			||||||
 | 
					                disposable = true;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            auto output = OUTPUT_VARIABLE(0);
 | 
					            auto output = OUTPUT_VARIABLE(0);
 | 
				
			||||||
            REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
 | 
					            REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
 | 
					            helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (disposable) {
 | 
				
			||||||
 | 
					                delete min;
 | 
				
			||||||
 | 
					                delete max;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
            return Status::OK();
 | 
					            return Status::OK();
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -150,10 +150,6 @@ namespace helpers {
 | 
				
			|||||||
    void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
 | 
					    void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
 | 
				
			||||||
        BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
 | 
					        BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
 | 
					 | 
				
			||||||
            graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user