Merge pull request #9229 from eclipse/ag_ctc_loss_3

Update codegen, add ctc loss to samediff
master
Adam Gibson 2021-03-12 21:31:47 +09:00 committed by GitHub
commit 0c81654567
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 1793 additions and 930 deletions

View File

@ -0,0 +1,614 @@
# ND4J Op Definitions and Code Generation
This project contains the ND4J Op definitions, the DSL (Domain Specific Language) that is used for those definitions and
code generators that use those definitions to create the actual Java code that is used to use the defined operations.
## Why define ops externally?
As we started to support SameDiff, we also started to introduce inconsistencies between SameDiff and ND4J. Even though
both of those libraries use the same underlying implementations for operations, there are both small and large
differences in the API that we provide for them. Sometimes, we have provided an official API only for one usage, and not
the other. And very often the documentation for a single op is in many different places.
In the future we want to support other programming languages with libnd4j, and provide more ways to use our C++ backend.
This would only increase the aforementioned problems.
The root of all of those problems, is that Ops are used across different environments, and there is no single way of
defining them with an enforced interface.
## How does this project help with enforcing a single consistent interface for ops?
The solution we propose, is to define the operations separately, and then generate the necessary API code for them. All
of the generated code is to be considered untouchable, editing it will result in the changes being overwritten sooner
rather than later.
The combination of external op definition and code generation, opens up many opportunities for us. The first one being
that we can easily create consistent APIs for both ND4J and SameDiff in Java. But, looking into the future, we can also
create those APIs for other programming languages like Python, Swift, or even C#. We can even go beyond programming
languages, and use the op definitions to create better documentation than what JavaDoc or similar might support out of
the box.
## Maintenance
This project is currently maintained by Paul Dubs, with feedback often collected from raver119 and Alex Black.
## Current Status
At the moment we still focus on nailing down an easily readable and contribution friendly DSL for op definition and code
generation that can replace namespace definitions. This means that at the moment we still rely on the pre-existing Op
definition classes that already exist in ND4J.
## Roadmap
* Replace Bitwise and Random namespaces with autogenerated code In progress.
* Implement a convenient CLI tool.
* Define all Ops using the DSL.
* Automatically generate derivative op declarations from existing ops
* Replace all namespace definitions in ND4J / SameDiff with automatically generated ones
* Replace all Op classes with automatically generated ones.
# Usage
Pre-requisites:
* JDK 8 or higher
* Maven 3.3 or higher
TODO: Show usage output of the project itself
TODO: Show how to use from mvn
## Generating Code - ND4J Namespaces
A script - `generate.sh` - is provided in the project root. This can be used (at present) to generate ND4J namespace classes.
It is assumed that the deeplearning4j mono repo and the dl4j-dev-tools repo both exist and have a common parent directory
i.e., `somedir/deeplearning4j` and `somedir/dl4j-dev-tools` both exist.
The script takes as argument the name (or names) of the ND4J namespaces to generate (not case sensitive) and projects (supported
projects are nd4j, sd and both by default).
As of 26/11, namespaces names (and hence valid args) include: `bitwise`, `neuralnetwork`, `random`, and `math`
Note also that `all` may be passed to the script to generate all namespaces.
For example, to generate both bitwise and random namespaces for both nd4j and SameDiff:
```
./generate.sh bitwise,random
```
Or to generate all namespaces for both nd4j and SameDiff, use:
```
./generate.sh all
```
To generate namespaces for one project only, use:
```
./generate.sh linalg -projects sd
```
or:
```
./generate.sh linalg -projects nd4j
```
The script will first compile the project, before running.
Internally, the `org.nd4j.codegen.cli.CLI` class is used.
Classes are written to `deeplearning4j/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/`
## Generating documentation.
It is possible to use generate.sh for generation of code only, docs in markdown format only, or both docs and code.
To generate docs only and store them to new folder "docs" for all namespaces:
```
./generate.sh all -docsdir ../../docs
```
Generation for selected namespaces works in the same way as for code:
```
./generate.sh -docsdir ../../docs bitwise,linalg
```
# Code structure
The project is implemented using a mix of Java and Kotlin. The DSL definition and the accompanying data structures are
implemented in Kotlin. At the moment the code generators are implemented in Java, in order to allow people who are not
fluent in Kotlin, but know Java to be able to contribute to the code generators.
The source code for this project is structured a bit different that what you would typically see in a Java or Kotlin
project. When you take a look inside the `src/main` directory, you will find 4 sub-directories.
The `java` and `kotlin` directories contain Java and Kotlin code respectively.
In order to not confuse op definitions with the machinery that allows them to be defined in that way, ops are kept in a
separate folder called `ops`.
Because we use JavaPoet for Java code generator implementation, we also have yet another folder called `stubs`. That
folder contains stub classes, that are used to reference other classes available in ND4J. These stub classes are
intentionally left empty, as JavaPoet only requires them for naming and automatically creating proper imports. We use
stub classes instead of depending on the actual nd4j API in order to break a cyclic dependency that would otherwise be
created (i.e. in order to be able to generate code for ND4J, we would need an already compiled nd4j to be available).
**Note:** If something is stubbed here and is moved in ND4J, then it also has to be moved to the appropriate place here,
otherwise the generated code will be wrong.
The `adr` folder contains "Architecture Decision Records". These files give you more insight into the "why" of some of
the bigger decisions within this project.
# DSL for Op Definition
Ops are defined using a DSL that is implemented in Kotlin. This means that other than the DSL, as defined in the
following, you can also use all of Kotlin when defining Ops. However, doing things the obvious and clearly
understandable way is better than coming up with a clever way, so prefer to use the DSL as described if unsure.
```kotlin
val mathNs = Namespace("math") {
Op("add") {
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic"
Input(NUMERIC, "x") { description = "First input to add" }
Input(NUMERIC,"y") { count = AtLeast(1); description = "Second input to add" }
Arg(INT,"shape") { count = AtLeast(1); description = "shape" }
Output(NUMERIC, "z") { description = "Output (x+y)" }
Doc(Language.ANY, DocScope.ALL) {
"""
(From AddOp) Add op doc text that will appear everywhere - classes, constructors, op creators
""".trimIndent()
}
Doc(Language.ANY, DocScope.CLASS_DOC_ONLY) {
"Add op doc text that will appear in all class docs (javadoc etc)"
}
Doc(Language.ANY, DocScope.CONSTRUCTORS_ONLY) {
"Add op doc text for constructors only"
}
}
}
```
This example shows how a namespace is defined. Namespaces are at the top layer, and ops can only be defined within the
context of a namespace. This example namespace contains only a single op, called "add". If we wanted to add another op,
we would simply add it below the first.
As you can see, every op has to have a name, if you try to create one without a name, you will get a compile error.
Within the context of the op, we first set in which java package the op class can be found in, then define its inputs,
arguments and outputs and finally add some free form documentation about what that op is doing.
Like with the op itself, the inputs, arguments and outputs all have to have a name, but unlike the op, they also require
a type. Within their context, you can set a description and a count of how many parameters they can take respectively.
If an input, argument or output take anything else than exactly 1, they will be treated as arrays. Typically you would
use this to define ops like `concat` which can take multiple input tensors or ops that might take shape arguments.
## Examples
The following shows how a typical op definition looks like and how the generated Java code may look.
An op might be defined like this:
```kotlin
Op("binomial") {
javaPackage = "org.nd4j.linalg.api.ops.random.custom"
Arg(INT, "nTrials") { description = "Number of trials parameter for the binomial distribution" }
Arg(FLOATING_POINT, "p") { description = "Probability of success for each trial" }
Arg(INT, "shape") { count = AtLeast(1); description = "Shape of the new random SDVariable, as a 1D array" }
Output(NUMERIC, "output") { description = "new random SDVariable, where values are randomly sampled according to a Binomial distribution" }
Doc(Language.ANY, DocScope.ALL) {
"""
Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution,
with the specified number of trials and probability.
""".trimIndent()
}
}
```
The java code generator will create a method like the following for it:
```java
/**
* Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution,
* with the specified number of trials and probability.
*
* @param nTrials Number of trials parameter for the binomial distribution
* @param p Probability of success for each trial
* @param shape Shape of the new random SDVariable, as a 1D array (Size: AtLeast(min=1))
* @return output new random SDVariable, where values are randomly sampled according to a Binomial distribution (NUMERIC type)
*/
public static INDArray binomial(long nTrials, double p, long... shape) {
Preconditions.checkArgument(shape.length >= 1, "shape has incorrect count. Expected: AtLeast(min=1)");
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.BinomialOp(nTrials, p, shape))[0];
}
```
Or an op with some more constraints:
```kotlin
Op("and") {
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
val x = Input(INT, "x") { description = "First input array" }
val y = Input(INT, "y") { description = "Second input array" }
Constraint("Must be same types"){ sameType(x, y) }
Constraint("Must have broadcastable shapes"){ broadcastableShapes(x, y) }
Output(INT, "output"){ description = "Bitwise AND array" }
Doc(Language.ANY, DocScope.ALL){
"""
Bitwise AND operation. Supports broadcasting.
""".trimIndent()
}
}
```
will be converted to java like this:
```java
/**
* Bitwise AND operation. Supports broadcasting.
*
* Inputs must satisfy the following constraints:
* Must be same types: isSameType(x, y)
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
*
* @param x First input array (INT type)
* @param y Second input array (INT type)
* @return output Bitwise AND array (INT type)
*/
public static INDArray and(INDArray x, INDArray y) {
NDValidation.validateInteger("and", x);
NDValidation.validateInteger("and", y);
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
Preconditions.checkArgument(isBroadcastableShapes(x, y), "Must have broadcastable shapes");
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.AndOp(x, y))[0];
}
```
# Full DSL Description
## Namespace
fun NamespaceName() = Namespace("name"){ /* Op definitions in namespace context */}
Defines a namespace.
## Op
Only available within a namespace context
Op("opName") { /* op properties in op context */ }
Op("anotherOp", mixin) { /* op properties in op context */ }
Op("anotherOp2", mixin, keepInputs=false) { /* op properties in op context */ }
Every op requires a namespace unique op name.
When defining an op, you can also pass a mixin that it should inherit initial properties from. This has the same effect
as using `useMixin(mixin)` as the very first thing in the op definition. If you don't want to inherit all of the
parameters of the mixin, you can pass the same additional configuration as you would pass to
`useMixin(mixin, ...options..)`. See [Mixin](#mixin) for more information.
### Op properties
* `javaPackage` (String): Package where the op is to be found in the java implementation.
* `javaOpClass` (String): Name of java op class if inconsistent with opName. Default: same as opName
* `libnd4jName` (String): The name the op has in libnd4j. Default: same as opName
## Mixin
Available in global context.
val mixin = Mixin("name"){ /* op properties in op context */ }
// within an op context:
useMixin(mixin)
useMixin(mixin, ...options...)
// When needing to access something from within the mixin
mixin.input("name")
mixin.arg("name")
mixin.config("name")
mixin.output("name")
Mixins provide the facility to share commonalities between Ops. You can think of it like inheritance, especially when
you declare the use of a mixin on Op definition. In contrast to normal (single) inheritance where only a single super
class is possible, the mixin mechanism allows to "inherit" from multiple sources.
You can define almost all the same things within a mixin that you can within an Op. The only things that *can not* be
configured within a mixin are Op `name`, `libnd4jName` and `javaOpClass`.
As mixins can be configured within the global context, you can share them across namespaces by defining them in their
own file. If a mixin is namespace specific, you can also define it within the namespace context.
Mixins are used either on definition as a parameter `Op("opname", mixin){...}`, or with `useMixin(mixin)` within the op
definition. While the former version only supports a single mixin, the latter version allows you to use as many mixins
as are required.
You can also build up mixins by using `useMixin(mixin)` inside a Mixin itself.
`useMixin(mixin, ...options...)` supports a few additional options: `keepInputs`, `keepArgs`, `keepConfigs`,
`keepOutputs`, `keepSignatures`, `keepDoc`, `keepConstraints`. They default to `true`. If you want to skip including
some of them, you simply set the parameter for it to `false`, e.g. `useMixin(mixin, keepDoc=false)`.
When using `useMixin(mixin)`, all definitions within the mixin are applied as if this invocation was replaced with the
content of the mixin itself. This means, that if you have already defined anything prior to using a mixin, the mixin's
definitions will be **after** the previously defined things. This can be very useful if the commonality between ops is
that they have a few trailing options.
If a named property or section is defined in both a mixin (or multiple mixins) and the op, then the **last** to define it will
win. Named properties are `legacy`, `javaPackage`, named sections are `Input`, `Arg`, `Output`, `Config`.
For example, assume you have `javaPackage` defined in both an op and a mixin. Then you can have the following two
cases:
First case:
```kotlin
Op("foo"){
useMixin(exampleMixin)
javaPackage = "some.example.package"
}
```
Second case:
```kotlin
Op("foo"){
javaPackage = "some.example.package"
useMixin(exampleMixin)
}
```
In the first case, the op will have the `javaPackage` value that is defined within the op. In the second case it will
have the `javaPackage` value defined in the mixin.
For inputs, args, outputs, it works similarly. Assume you have `Input(dataType, "a")` defined in both the mixin and the
op. Again you can have two cases:
First case:
```kotlin
Op("foo"){
useMixin(exampleMixin)
Input(NUMERIC, "a")
}
```
Second case:
```kotlin
Op("foo"){
Input(NUMERIC, "a")
useMixin(exampleMixin)
}
```
In the first case, it will overwrite the input from the mixin. In the second case, the mixin will overwrite that the
input from the op.
## Config
Only available within a namespace context
val nameConfig = Config("Name"){
/* input, arg, constraint, doc properties */
}
Every config requires a namespace unique name.
A config allows to define a configuration class, that can be used as a holder for complex properties of specific ops
which will be passed to an op as a parameter.
Similar to an op itself, it supports `Input`, `Arg`, `Constraint` and `Doc` definitions.
in order to use the config within an op you either use `useConfig(cfg)` or `val configRef = useConfig(cfg)`. The second
form allows you to reference the config.
Referencing the config allows to you reference its inputs and args by name: `configRef.input("name")` and
`configRef.arg("name")`. Also it allows you to use a config in a signature `Signature(a, b, c, configRef)`.
When default and shorthand signatures are used, configs will be always placed at the end.
If a config is defined but not used, an `IllegalStateException` will be thrown.
See also [ADR 0007 "Configuration Objects"](adr/0007-configuration_objects.md).
## Input
Available within an op, mixin and a config context
Input(FLOATING_POINT, "b"){ /* input properties in input context */ }
val a = Input(INT, "a"){ /* input properties in input context */ }
Inputs represent tensors. They are what the op will work on.
Every input requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name.
When defining an input, you can assign it to a variable in order to be able to reference it later on. You might want to
do this when defining constraints.
If you want an input to represent an array, you will have to set a count accordingly. If no count is set, it is assumed
that the count is meant to be `Exactly(1)`.
### Input properties
* `description` (String): A short description what this input represents. Setting this is recommended.
* `count` (Count): Can take one of `Exactly(n)`; `AtLeast(n)`; `AtMost(n)`; `Range(from, to)`
* `defaultValue` (Input): use another input as the default if this isn't set explicitly. The data type of the other
input has to match the data type of this input. The other input may also have a default value.
## Argument
Available within an op, mixin and config context
Arg(FLOATING_POINT, "b"){ /* Arg properties in arg context */ }
val a = Arg(INT, "a"){ /* Arg properties in arg context */ }
Args represent arguments. They modify how the op works on its inputs.
Every arg requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name.
When defining an arg, you can assign it to a variable in order to be able to reference it later on. You might want to do
this when defining constraints.
If you want an arg to represent an array, you will have to set a count accordingly. If no count is set, it is assumed
that the count is meant to be `Exactly(1)`.
Note (Java specific): If the last arg is defined to represent an array, it will be translated to a vararg parameter, e.g.
`Arg(INT, "a"){ count = AtLeast(1); description = "..." }` will be turned into `long... a`.
### Argument properties
* `description` (String): A short description what this argument represents. Setting this is recommended.
* `count` (Count): Can take one of `Exactly(n)`; `AtLeast(n)`; `AtMost(n)`; `Range(from, to)`
* `defaultValue` (null|Number|Boolean|int[]|double[]|boolean[]|Arg|TensorShapeValue|TensorDataTypeValue|String):
Use given value as default value, if this isn't explicitly set. Can refer to *inputs* and *outputs* using `x.shape()`
and `x.dataType()`. The given default values has to match the data type for this argument. May also refer to another
Arg, and that Arg may also have a default value. Default values based on outputs are treated like without a default
in SameDiff mode.
* `possibleValues` (String[]): only available when ENUM data type is used for the argument. Takes a list of possible
values for the Enum. If used in in abstract base op, the enum will only be created once. See also
[ADR 0006 "Op specific enums"](adr/0006-op_specific_enums.md).
## Output
Only available within an op and mixin context
Output(FLOATING_POINT, "b"){ /* Arg properties in arg context */ }
Every output requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name.
While outputs can be assigned to a variable, there is no intended use-case for it. In contrast to inputs and args,
outputs can not be used in constraints.
### Output properties
* `description` (String): A short description what this argument represents. Setting this is recommended.
## Signature
Only available within an op and mixin context
Signature(a,b,c)
Signature(a,b,c) { "Some Documentation" }
AllParamSignature()
AllDefaultParamSignature()
For some ops only specific signatures make sense, as for example some optional parameters may become required in the
presence of other optional parameters. This feature is mainly meant to help with the fact that not all programming
languages (e.g. Java) support default parameters. Each signature is meant to describe one overload in those languages.
See also [ADR 0005 "Optional parameters and signatures"](adr/0005-optional_parameters_and_signatures.md).
Signatures can also reference the output(s) of an op. Those signatures are only relevant in NDArray programming mode.
They are not to be generated in SameDiff mode.
`AllParamSignature()` and `AllDefaultParamSignature()` are short hands for `Signature(...all parameters...)` and
`Signature(...only parameters with no default values...)`. Their parameters include references to outputs unless
disabled using `withOutput=false` (e.g. `AllParamSignature(withOutput=false)`).
If no signature is specified for an op, it is treated as if `AllParamSignature()` and `AllDefaultParamSignature()` are
both specified.
Each signature must satisfy the condition, that all required parameters are listed there. If this condition is not
satisfied, an `IllegalStateException` will be thrown on construction.
## Documentation
Only available within an op and mixin context
Doc(Language.ANY, DocScope.ALL){
""" Some documentation
It can be multiline. And indented.
""".trimIndent()
}
Documentation can be language specific, and can be set to be only given at specific places. The documentation itself is
given as a string. Because Kotlin supports multiline strings along with proper indentation, we are using them directly
here.
Note: At the moment we are only creating java code, so the documentation can use JavaDoc syntax.
You can have multiple Doc definitions; they are treated as additive.
Any instances of the following values will be replaced when generating code:
* `%OPNAME%` -> operation name ("Add", "Sub", etc)
* `%LIBND4J_OPNAME%` -> libnd4j op name ("add", "sub", etc)
* `%INPUT_TYPE%` -> input / output type depending on the generated api, i.e. `SDVariable` for SameDiff and `INDArray`
for ND4J
See `DocTokens` class for more details.
## Constraints
Available within an op, mixin and a config context.
Constraint("Error Message if constraint isn't satisfied"){ /* constraint definition */ }
BackendConstraint("Error Message if constraint isn't satisfied"){ /* constraint definition */ }
Many ops expect their inputs and arguments to satisfy some specific rules. Those rules can be expressed with the
constraint system.
Constraints are to be enforced within the frontend language, while BackendConstraints are currently only to be used as
a part of the documentation. They will be enforced within the C++ backend, so there is no point in double checking them.
There is a system in place to define even complex constraints for inputs and arguments.
In a constraint definition, you can reference inputs and arguments directly, if they are previously assigned to
a variable using `val name = Input(...)`. Inside the Constraint block, you can use the following operations:
* `eq`: Compare equality (applicable to numbers and booleans), e.g. `x eq 7`, `x eq true`
* `neq`: Compare inequality (applicable to numbers and booleans), e.g. `x neq 3`, `x neq true`
* `lt`, `lte`: less than, less than equal (applicable to numbers), e.g. `x lt 3`, `x lte 4`
* `gt`, `gte`: greater than, grater than equal (applicable to numbers), e.g. `x gt 5`, `x gte 6`
* `and`: combine two comparisons where both have to be true, e.g. `(x eq 8) and (y lt 3)`
* `or`: combine two comparisons where one has to be true, e.g. `(x eq 8) or (y eq true)`
* `all`: combine N comparisons where all have to be true, e.g. `all(x eq 8, y lt 3, z eq true)`
* `some`: combine N comparisons where at least one has to be true, e.g. `some(x eq 8, y lt 3, z eq true)`
* `not`: negates a comparison, e.g. `not(x eq 3)`
In addition to those operations, you also get access to some more complex constraints:
* `sameType(...)`: true if all given inputs are the same type, e.g. `sameType(x,y,z)`
* `sameShape(...)`: true if all given inputs have the same shape, e.g. `sameShape(x,y,z)`
* `broadcastableShapes(...)`: true if all given inputs have broadcast compatible shapes, e.g. `broadcastableShapes(x,y,z)`
Inputs also get some additional methods on them to define useful constraints:
* `input.rank()`: Rank of the given input
* `input.sizeAt(i)`: size of the given input at the i-th dimension
* `input.isScalar()`: Short hand for `x.rank() == 1`
### Examples
Some examples of constraints, and what they evaluate to. The example code contains a little bit of context.
```kotlin
val x = Input(INT, "x") { description = "First input array" }
val y = Input(INT, "y") { description = "Second input array" }
Constraint("foo bar"){
x.sizeAt(7) eq 7 and y.isScalar()
}
```
will evaluate to:
```java
Preconditions.checkArgument((x.sizeAt(7) == 7) && (y.rank() == 1), "foo bar");
```
More examples (only the constraint itself, without context code):
#### Some
```kotlin
some(input.rank() eq 3, input.sizeAt(2) gte 7, input.sizeAt(4) lt 5)
```
turns to:
```java
((x.rank() == 3) || (x.sizeAt(2) >= 7)) || (x.sizeAt(4) < 5)
```
# Contributing to this project
If you want to contribute to this project other than by adding or improving op definitions, the following sections might
be of special interest to you.
## Extending the DSL
The DSL is implemented using Kotlins type-safe builders feature
(see https://kotlinlang.org/docs/reference/type-safe-builders.html). The basic principle is that functions calls can
receive blocks that can be executed in a specified context. When combined with the fact that we are just looking to
create an object graph that is then going to be used as input to the code generators, this allows us to create a very
feature rich DSL without actually having to write a lot of code to support it.
Most of the DSL specific code can be found in `src/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt`. The actual class
definitions for the object graph we are building, can be found in `src/kotlin/org/nd4j/codegen/api`.
If you want to add just a simple field to one of the objects, it is usually enough to directly add it to the particular
class.
If you want to add a specific section to the op definition, i.e. a section like Input or Doc, you will have to add both
the class for the object that it is going to be creating, as well as a function within OpBuilder.kt to create and
register that section within the op.
**Note:** When you extend the DSL you will most likely also have to update all code generators to support the feature
you have added.
## Adding / extending code generators
Code generators can be written in either Java or Kotlin. Java has the advantage that more people will have experience in
using it. Kotlin has the advantage of more convenient syntax, especially for plain string manipulation and when dealing
with Enums and fixed sets of subclasses (called sealed classes in Kotlin).
All generators have to implement the `org.nd4j.codegen.api.generator.Generator` interface. For automatic detection by
the CLI tool, they should also be within the `org.nd4j.codegen.impl.LANGUAGE` package, where `LANGUAGE` is the actual
language that they generate.
Code generators can also use an auxiliary generator for constraint generation. Those auxiliary generators, have to
implement ` org.nd4j.codegen.api.generator.ConstraintCodeGenerator` interface.

View File

@ -0,0 +1,11 @@
#!/bin/bash
if test "$#" -eq 0; then
echo "No namespaces were specified. One or more namespaces must be provided as an argument"
echo "Usage example 1 (single namespace): ./generate.sh math"
echo "Usage example 2 (multiple namespaces): ./generate.sh math,random"
echo "Usage example 2 (all namespaces): ./generate.sh all"
else
mvn clean package -DskipTests
java -cp target/codegen-1.0.0-SNAPSHOT-shaded.jar org.nd4j.codegen.cli.CLI -dir ../../ -namespaces "$@"
fi

View File

@ -0,0 +1,258 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.nd4j</groupId>
<artifactId>codegen</artifactId>
<version>1.0.0-SNAPSHOT</version>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<commonsio.version>2.5</commonsio.version>
<commons.lang.version>3.12.0</commons.lang.version>
<commons.dbutils.version>1.7</commons.dbutils.version>
<lombok.version>1.18.8</lombok.version>
<logback.version>1.1.7</logback.version>
<junit.version>4.12</junit.version>
<junit-jupiter.version>5.4.2</junit-jupiter.version>
<java.version>1.8</java.version>
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
<kotlin.version>1.3.50</kotlin.version>
<kotlin.compiler.jvmTarget>1.8</kotlin.compiler.jvmTarget>
<kotlin.compiler.incremental>true</kotlin.compiler.incremental>
</properties>
<dependencies>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.28</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>${commonsio.version}</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>${commons.lang.version}</version>
</dependency>
<dependency>
<groupId>com.squareup</groupId>
<artifactId>javapoet</artifactId>
<version>1.11.1</version>
</dependency>
<!-- Test Dependencies -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>${junit-jupiter.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<version>${junit-jupiter.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib-jdk8</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-test</artifactId>
<version>${kotlin.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-kotlin</artifactId>
<version>2.9.9</version>
</dependency>
<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>
<version>1.78</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>1.0.0-SNAPSHOT</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<version>3.0.0</version>
<executions>
<execution>
<id>add-source</id>
<phase>generate-sources</phase>
<goals><goal>add-source</goal></goals>
<configuration>
<sources>
<source>src/main/stubs</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>${maven-shade-plugin.version}</version>
<configuration>
<shadedArtifactAttached>true</shadedArtifactAttached>
<createDependencyReducedPom>false</createDependencyReducedPom>
<filters>
<filter>
<artifact>*:*</artifact>
<excludes>
<exclude>org/datanucleus/**</exclude>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</filter>
</filters>
</configuration>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
<resource>reference.conf</resource>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer" />
</transformers>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-plugin</artifactId>
<version>${kotlin.version}</version>
<configuration>
<args>
<arg>-Xjsr305=strict</arg>
</args>
<compilerPlugins>
<plugin>spring</plugin>
<plugin>jpa</plugin>
</compilerPlugins>
</configuration>
<dependencies>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-allopen</artifactId>
<version>${kotlin.version}</version>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-maven-noarg</artifactId>
<version>${kotlin.version}</version>
</dependency>
</dependencies>
<executions>
<execution>
<id>compile</id>
<goals> <goal>compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/main/java</sourceDir>
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
<execution>
<id>test-compile</id>
<goals> <goal>test-compile</goal> </goals>
<configuration>
<sourceDirs>
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
<sourceDir>${project.basedir}/src/test/java</sourceDir>
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
</sourceDirs>
</configuration>
</execution>
</executions>
</plugin>
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.5.1</version>
<executions>
<!-- Replacing default-compile as it is treated specially by maven -->
<execution>
<id>default-compile</id>
<phase>none</phase>
</execution>
<!-- Replacing default-testCompile as it is treated specially by maven -->
<execution>
<id>default-testCompile</id>
<phase>none</phase>
</execution>
<execution>
<id>java-compile</id>
<phase>compile</phase>
<goals> <goal>compile</goal> </goals>
</execution>
<execution>
<id>java-test-compile</id>
<phase>test-compile</phase>
<goals> <goal>testCompile</goal> </goals>
</execution>
</executions>
<configuration>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@ -19,17 +19,13 @@
*/
package org.nd4j.codegen.impl.java;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.codegen.api.Language;
import org.nd4j.codegen.api.Namespace;
import org.nd4j.codegen.api.NamespaceOps;
import org.nd4j.codegen.api.Op;
import org.nd4j.codegen.api.generator.Generator;
import org.nd4j.codegen.api.generator.GeneratorConfig;
import java.io.File;
import java.io.IOException;
public class JavaPoetGenerator implements Generator {
@ -40,12 +36,12 @@ public class JavaPoetGenerator implements Generator {
}
@Override
public void generateNamespaceNd4j(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws IOException {
public void generateNamespaceNd4j(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws java.io.IOException {
Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.linalg.factory", StringUtils.EMPTY);
}
@Override
public void generateNamespaceSameDiff(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws IOException {
public void generateNamespaceSameDiff(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws java.io.IOException {
//throw new UnsupportedOperationException("Not yet implemented");
Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.autodiff.samediff", StringUtils.EMPTY);
}

View File

@ -47,6 +47,23 @@ fun SDLoss() = Namespace("Loss"){
}
}
Op("ctcLoss") {
javaPackage = "org.nd4j.linalg.api.ops.impl.loss"
javaOpClass = "CtcLoss"
Input(NUMERIC, "targetLabels") { description = "Label array" }
Input(NUMERIC, "logitInput") { description = "Inputs" }
Input(NUMERIC, "targetLabelLengths") { description = "Length of the target label" }
Input(NUMERIC, "logitInputLengths") { description = "Length of the input"}
Output(NUMERIC, "output"){ description = "Ctc loss " }
Doc(Language.ANY, DocScope.ALL){
"""
CTC Loss: Connectionist Temporal Classification Loss. See:
https://dl.acm.org/citation.cfm?id=1143891
""".trimIndent()
}
}
Op("cosineDistance") {
javaPackage = "org.nd4j.linalg.api.ops.impl.loss"
javaOpClass = "CosineDistanceLoss"

View File

@ -1,32 +1,26 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@ -307,7 +301,7 @@ public class SDBaseOps {
* @param transposeB Whether to transpose B arrays or not
*/
public SDVariable[] batchMmul(SDVariable[] inputsA, SDVariable[] inputsB, boolean transposeA,
boolean transposeB) {
boolean transposeB) {
SDValidation.validateNumerical("batchMmul", "inputsA", inputsA);
Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length);
SDValidation.validateNumerical("batchMmul", "inputsB", inputsB);
@ -331,7 +325,7 @@ public class SDBaseOps {
* @param transposeB Whether to transpose B arrays or not
*/
public SDVariable[] batchMmul(String[] names, SDVariable[] inputsA, SDVariable[] inputsB,
boolean transposeA, boolean transposeB) {
boolean transposeA, boolean transposeB) {
SDValidation.validateNumerical("batchMmul", "inputsA", inputsA);
Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length);
SDValidation.validateNumerical("batchMmul", "inputsB", inputsB);
@ -482,7 +476,7 @@ public class SDBaseOps {
* @return output Output variable (NUMERIC type)
*/
public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse,
int... axis) {
int... axis) {
SDValidation.validateNumerical("cumprod", "in", in);
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, exclusive, reverse, axis).outputVariable();
@ -563,7 +557,7 @@ public class SDBaseOps {
* @return output (NUMERIC type)
*/
public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse,
int... axis) {
int... axis) {
SDValidation.validateNumerical("cumsum", "in", in);
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, exclusive, reverse, axis).outputVariable();
@ -680,7 +674,7 @@ public class SDBaseOps {
* @param numPartitions Number of partitions, >= 1
*/
public SDVariable[] dynamicPartition(String[] names, SDVariable x, SDVariable partitions,
int numPartitions) {
int numPartitions) {
SDValidation.validateNumerical("dynamicPartition", "x", x);
SDValidation.validateInteger("dynamicPartition", "partitions", partitions);
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(sd,x, partitions, numPartitions).outputVariables();
@ -1189,7 +1183,7 @@ public class SDBaseOps {
* @return output INDArray with linearly spaced elements (NUMERIC type)
*/
public SDVariable linspace(String name, DataType dataType, double start, double stop,
long number) {
long number) {
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,dataType, start, stop, number).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
@ -1205,7 +1199,7 @@ public class SDBaseOps {
* @return output INDArray with linearly spaced elements (NUMERIC type)
*/
public SDVariable linspace(SDVariable start, SDVariable stop, SDVariable number,
DataType dataType) {
DataType dataType) {
SDValidation.validateNumerical("linspace", "start", start);
SDValidation.validateNumerical("linspace", "stop", stop);
SDValidation.validateInteger("linspace", "number", number);
@ -1224,7 +1218,7 @@ public class SDBaseOps {
* @return output INDArray with linearly spaced elements (NUMERIC type)
*/
public SDVariable linspace(String name, SDVariable start, SDVariable stop, SDVariable number,
DataType dataType) {
DataType dataType) {
SDValidation.validateNumerical("linspace", "start", start);
SDValidation.validateNumerical("linspace", "stop", stop);
SDValidation.validateInteger("linspace", "number", number);
@ -1445,7 +1439,7 @@ public class SDBaseOps {
* @return output Number of elements that the condition is satisfied for (NUMERIC type)
*/
public SDVariable matchConditionCount(SDVariable in, Condition condition, boolean keepDim,
int... dimensions) {
int... dimensions) {
SDValidation.validateNumerical("matchConditionCount", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
return new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable();
@ -1469,7 +1463,7 @@ public class SDBaseOps {
* @return output Number of elements that the condition is satisfied for (NUMERIC type)
*/
public SDVariable matchConditionCount(String name, SDVariable in, Condition condition,
boolean keepDim, int... dimensions) {
boolean keepDim, int... dimensions) {
SDValidation.validateNumerical("matchConditionCount", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable();
@ -1514,7 +1508,7 @@ public class SDBaseOps {
* @return output Number of elements that the condition is satisfied for (NUMERIC type)
*/
public SDVariable matchConditionCount(String name, SDVariable in, Condition condition,
int... dimensions) {
int... dimensions) {
SDValidation.validateNumerical("matchConditionCount", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, false, dimensions).outputVariable();
@ -1895,7 +1889,7 @@ public class SDBaseOps {
* @return output (NUMERIC type)
*/
public SDVariable mmul(SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
boolean transposeZ) {
boolean transposeZ) {
SDValidation.validateNumerical("mmul", "x", x);
SDValidation.validateNumerical("mmul", "y", y);
return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable();
@ -1914,7 +1908,7 @@ public class SDBaseOps {
* @return output (NUMERIC type)
*/
public SDVariable mmul(String name, SDVariable x, SDVariable y, boolean transposeX,
boolean transposeY, boolean transposeZ) {
boolean transposeY, boolean transposeZ) {
SDValidation.validateNumerical("mmul", "x", x);
SDValidation.validateNumerical("mmul", "y", y);
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable();
@ -2304,14 +2298,14 @@ public class SDBaseOps {
*
* @param indices Indices - value 0 to depth-1 (NUMERIC type)
* @param depth Number of classes
* @param axis
* @param on
* @param off
* @param axis
* @param on
* @param off
* @param dataType Output data type
* @return output Output variable (NUMERIC type)
*/
public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off,
DataType dataType) {
DataType dataType) {
SDValidation.validateNumerical("oneHot", "indices", indices);
return new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable();
}
@ -2324,14 +2318,14 @@ public class SDBaseOps {
* @param name name May be null. Name for the output variable
* @param indices Indices - value 0 to depth-1 (NUMERIC type)
* @param depth Number of classes
* @param axis
* @param on
* @param off
* @param axis
* @param on
* @param off
* @param dataType Output data type
* @return output Output variable (NUMERIC type)
*/
public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on,
double off, DataType dataType) {
double off, DataType dataType) {
SDValidation.validateNumerical("oneHot", "indices", indices);
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable();
return sd.updateVariableNameAndReference(out, name);
@ -2344,9 +2338,9 @@ public class SDBaseOps {
*
* @param indices Indices - value 0 to depth-1 (NUMERIC type)
* @param depth Number of classes
* @param axis
* @param on
* @param off
* @param axis
* @param on
* @param off
* @return output Output variable (NUMERIC type)
*/
public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) {
@ -2362,13 +2356,13 @@ public class SDBaseOps {
* @param name name May be null. Name for the output variable
* @param indices Indices - value 0 to depth-1 (NUMERIC type)
* @param depth Number of classes
* @param axis
* @param on
* @param off
* @param axis
* @param on
* @param off
* @return output Output variable (NUMERIC type)
*/
public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on,
double off) {
double off) {
SDValidation.validateNumerical("oneHot", "indices", indices);
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, DataType.FLOAT).outputVariable();
return sd.updateVariableNameAndReference(out, name);
@ -2436,7 +2430,7 @@ public class SDBaseOps {
* As per onesLike(String, SDVariable) but the output datatype may be specified<br>
*
* @param input (NUMERIC type)
* @param dataType
* @param dataType
* @return output (NUMERIC type)
*/
public SDVariable onesLike(SDVariable input, DataType dataType) {
@ -2449,7 +2443,7 @@ public class SDBaseOps {
*
* @param name name May be null. Name for the output variable
* @param input (NUMERIC type)
* @param dataType
* @param dataType
* @return output (NUMERIC type)
*/
public SDVariable onesLike(String name, SDVariable input, DataType dataType) {
@ -2612,7 +2606,7 @@ public class SDBaseOps {
* @param from Initial/smallest value
* @param to Largest value (exclusive)
* @param step Step size
* @param dataType
* @param dataType
* @return output INDArray with the specified values (NUMERIC type)
*/
public SDVariable range(double from, double to, double step, DataType dataType) {
@ -2628,7 +2622,7 @@ public class SDBaseOps {
* @param from Initial/smallest value
* @param to Largest value (exclusive)
* @param step Step size
* @param dataType
* @param dataType
* @return output INDArray with the specified values (NUMERIC type)
*/
public SDVariable range(String name, double from, double to, double step, DataType dataType) {
@ -2644,7 +2638,7 @@ public class SDBaseOps {
* @param from Initial/smallest value (NUMERIC type)
* @param to Largest value (exclusive) (NUMERIC type)
* @param step Step size (NUMERIC type)
* @param dataType
* @param dataType
* @return output INDArray with the specified values (NUMERIC type)
*/
public SDVariable range(SDVariable from, SDVariable to, SDVariable step, DataType dataType) {
@ -2663,11 +2657,11 @@ public class SDBaseOps {
* @param from Initial/smallest value (NUMERIC type)
* @param to Largest value (exclusive) (NUMERIC type)
* @param step Step size (NUMERIC type)
* @param dataType
* @param dataType
* @return output INDArray with the specified values (NUMERIC type)
*/
public SDVariable range(String name, SDVariable from, SDVariable to, SDVariable step,
DataType dataType) {
DataType dataType) {
SDValidation.validateNumerical("range", "from", from);
SDValidation.validateNumerical("range", "to", to);
SDValidation.validateNumerical("range", "step", step);
@ -2727,7 +2721,7 @@ public class SDBaseOps {
* @return output New array with values replaced where condition is satisfied (NUMERIC type)
*/
public SDVariable replaceWhere(String name, SDVariable update, SDVariable from,
Condition condition) {
Condition condition) {
SDValidation.validateNumerical("replaceWhere", "update", update);
SDValidation.validateNumerical("replaceWhere", "from", from);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(sd,update, from, condition).outputVariable();
@ -2761,7 +2755,7 @@ public class SDBaseOps {
* @return output New array with values replaced where condition is satisfied (NUMERIC type)
*/
public SDVariable replaceWhere(String name, SDVariable update, double value,
Condition condition) {
Condition condition) {
SDValidation.validateNumerical("replaceWhere", "update", update);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(sd,update, value, condition).outputVariable();
return sd.updateVariableNameAndReference(out, name);
@ -2799,47 +2793,6 @@ public class SDBaseOps {
return sd.updateVariableNameAndReference(out, name);
}
/**
* Split the input in to a list of sub tensors
* @param input the input to split
* @param numSizeSplits the number of splits
* @param splitDim the dimension to split along
* @return the set of output variables
*/
public SDVariable[] split(SDVariable input,int numSizeSplits,int splitDim) {
SDValidation.validateNumerical("split",input);
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.Split(sd,input,numSizeSplits,splitDim).outputVariables();
return out;
}
/**
* Split the input in to a list of sub tensors
* @param name the potential name of the input
* @param input the input to split
* @param numSizeSplits the number of splits
* @param splitDim the dimension to split along
* @return the set of output variables
*/
public SDVariable[] split(String name,SDVariable input,int numSizeSplits,int splitDim) {
SDValidation.validateNumerical("split",input);
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.Split(sd,input,numSizeSplits,splitDim).outputVariables();
SDVariable[] ret = new SDVariable[out.length];
AtomicInteger index = new AtomicInteger(0);
Arrays.stream(out).forEach(output -> {
if(index.get() < 1) {
ret[index.get()] = sd.updateVariableNameAndReference(output,name);
index.incrementAndGet();
}
else {
ret[index.get()] = sd.updateVariableNameAndReference(output,name + ":" + index.get());
index.incrementAndGet();
}
});
return ret;
}
/**
* Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br>
* input, but with the specified shape.<br>
@ -2930,7 +2883,7 @@ public class SDBaseOps {
* @return output Reversed sequences (NUMERIC type)
*/
public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim,
int batchDim) {
int batchDim) {
SDValidation.validateNumerical("reverseSequence", "x", x);
SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths);
return new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable();
@ -2947,7 +2900,7 @@ public class SDBaseOps {
* @return output Reversed sequences (NUMERIC type)
*/
public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim,
int batchDim) {
int batchDim) {
SDValidation.validateNumerical("reverseSequence", "x", x);
SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable();
@ -3123,7 +3076,7 @@ public class SDBaseOps {
* @return output The updated variable (NUMERIC type)
*/
public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices,
SDVariable updates) {
SDVariable updates) {
SDValidation.validateNumerical("scatterAdd", "ref", ref);
SDValidation.validateNumerical("scatterAdd", "indices", indices);
SDValidation.validateNumerical("scatterAdd", "updates", updates);
@ -3166,7 +3119,7 @@ public class SDBaseOps {
* @return output The updated variable (NUMERIC type)
*/
public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices,
SDVariable updates) {
SDVariable updates) {
SDValidation.validateNumerical("scatterDiv", "ref", ref);
SDValidation.validateNumerical("scatterDiv", "indices", indices);
SDValidation.validateNumerical("scatterDiv", "updates", updates);
@ -3209,7 +3162,7 @@ public class SDBaseOps {
* @return output The updated variable (NUMERIC type)
*/
public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices,
SDVariable updates) {
SDVariable updates) {
SDValidation.validateNumerical("scatterMax", "ref", ref);
SDValidation.validateNumerical("scatterMax", "indices", indices);
SDValidation.validateNumerical("scatterMax", "updates", updates);
@ -3252,7 +3205,7 @@ public class SDBaseOps {
* @return output The updated variable (NUMERIC type)
*/
public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices,
SDVariable updates) {
SDVariable updates) {
SDValidation.validateNumerical("scatterMin", "ref", ref);
SDValidation.validateNumerical("scatterMin", "indices", indices);
SDValidation.validateNumerical("scatterMin", "updates", updates);
@ -3295,7 +3248,7 @@ public class SDBaseOps {
* @return output The updated variable (NUMERIC type)
*/
public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices,
SDVariable updates) {
SDVariable updates) {
SDValidation.validateNumerical("scatterMul", "ref", ref);
SDValidation.validateNumerical("scatterMul", "indices", indices);
SDValidation.validateNumerical("scatterMul", "updates", updates);
@ -3338,7 +3291,7 @@ public class SDBaseOps {
* @return output The updated variable (NUMERIC type)
*/
public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices,
SDVariable updates) {
SDVariable updates) {
SDValidation.validateNumerical("scatterSub", "ref", ref);
SDValidation.validateNumerical("scatterSub", "indices", indices);
SDValidation.validateNumerical("scatterSub", "updates", updates);
@ -3381,7 +3334,7 @@ public class SDBaseOps {
* @return output The updated variable (NUMERIC type)
*/
public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices,
SDVariable updates) {
SDVariable updates) {
SDValidation.validateNumerical("scatterUpdate", "ref", ref);
SDValidation.validateNumerical("scatterUpdate", "indices", indices);
SDValidation.validateNumerical("scatterUpdate", "updates", updates);
@ -3595,7 +3548,7 @@ public class SDBaseOps {
*
* @param lengths Lengths of the sequences (NUMERIC type)
* @param maxLen Maximum sequence length
* @param dataType
* @param dataType
* @return output Output variable (NUMERIC type)
*/
public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) {
@ -3610,7 +3563,7 @@ public class SDBaseOps {
* @param name name May be null. Name for the output variable
* @param lengths Lengths of the sequences (NUMERIC type)
* @param maxLen Maximum sequence length
* @param dataType
* @param dataType
* @return output Output variable (NUMERIC type)
*/
public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen, DataType dataType) {
@ -3625,7 +3578,7 @@ public class SDBaseOps {
*
* @param lengths Lengths of the sequences (NUMERIC type)
* @param maxLen Maximum sequence length (INT type)
* @param dataType
* @param dataType
* @return output Output variable (NUMERIC type)
*/
public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) {
@ -3641,11 +3594,11 @@ public class SDBaseOps {
* @param name name May be null. Name for the output variable
* @param lengths Lengths of the sequences (NUMERIC type)
* @param maxLen Maximum sequence length (INT type)
* @param dataType
* @param dataType
* @return output Output variable (NUMERIC type)
*/
public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen,
DataType dataType) {
DataType dataType) {
SDValidation.validateNumerical("sequenceMask", "lengths", lengths);
SDValidation.validateInteger("sequenceMask", "maxLen", maxLen);
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable();
@ -3656,7 +3609,7 @@ public class SDBaseOps {
* see sequenceMask(String, SDVariable, SDVariable, DataType)<br>
*
* @param lengths (NUMERIC type)
* @param dataType
* @param dataType
* @return output (NUMERIC type)
*/
public SDVariable sequenceMask(SDVariable lengths, DataType dataType) {
@ -3669,7 +3622,7 @@ public class SDBaseOps {
*
* @param name name May be null. Name for the output variable
* @param lengths (NUMERIC type)
* @param dataType
* @param dataType
* @return output (NUMERIC type)
*/
public SDVariable sequenceMask(String name, SDVariable lengths, DataType dataType) {
@ -3857,7 +3810,7 @@ public class SDBaseOps {
* keepDims = false: [a,c]<br>
*
* @param x (NUMERIC type)
* @param keepDims
* @param keepDims
* @param dimensions (Size: AtLeast(min=0))
* @return output (NUMERIC type)
*/
@ -3879,7 +3832,7 @@ public class SDBaseOps {
*
* @param name name May be null. Name for the output variable
* @param x (NUMERIC type)
* @param keepDims
* @param keepDims
* @param dimensions (Size: AtLeast(min=0))
* @return output (NUMERIC type)
*/
@ -4015,7 +3968,7 @@ public class SDBaseOps {
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
*/
public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, boolean keepDims,
int... dimensions) {
int... dimensions) {
SDValidation.validateNumerical("standardDeviation", "x", x);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
return new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable();
@ -4039,7 +3992,7 @@ public class SDBaseOps {
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
*/
public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected,
boolean keepDims, int... dimensions) {
boolean keepDims, int... dimensions) {
SDValidation.validateNumerical("standardDeviation", "x", x);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable();
@ -4084,7 +4037,7 @@ public class SDBaseOps {
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
*/
public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected,
int... dimensions) {
int... dimensions) {
SDValidation.validateNumerical("standardDeviation", "x", x);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, false, dimensions).outputVariable();
@ -4113,7 +4066,7 @@ public class SDBaseOps {
* @return output A subset of the input array (NUMERIC type)
*/
public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides,
int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
SDValidation.validateNumerical("stridedSlice", "in", in);
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
@ -4144,8 +4097,8 @@ public class SDBaseOps {
* @return output A subset of the input array (NUMERIC type)
*/
public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end,
long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask,
int shrinkAxisMask) {
long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask,
int shrinkAxisMask) {
SDValidation.validateNumerical("stridedSlice", "in", in);
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
@ -4196,7 +4149,7 @@ public class SDBaseOps {
* @return output A subset of the input array (NUMERIC type)
*/
public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end,
long... strides) {
long... strides) {
SDValidation.validateNumerical("stridedSlice", "in", in);
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
@ -4330,7 +4283,7 @@ public class SDBaseOps {
* @return output Output variable (NUMERIC type)
*/
public SDVariable tensorMmul(SDVariable x, SDVariable y, int[] dimensionsX, int[] dimensionsY,
boolean transposeX, boolean transposeY, boolean transposeZ) {
boolean transposeX, boolean transposeY, boolean transposeZ) {
SDValidation.validateNumerical("tensorMmul", "x", x);
SDValidation.validateNumerical("tensorMmul", "y", y);
Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length);
@ -4352,7 +4305,7 @@ public class SDBaseOps {
* @return output Output variable (NUMERIC type)
*/
public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX,
int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) {
int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) {
SDValidation.validateNumerical("tensorMmul", "x", x);
SDValidation.validateNumerical("tensorMmul", "y", y);
Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length);
@ -4389,7 +4342,7 @@ public class SDBaseOps {
* @return output Output variable (NUMERIC type)
*/
public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX,
int... dimensionsY) {
int... dimensionsY) {
SDValidation.validateNumerical("tensorMmul", "x", x);
SDValidation.validateNumerical("tensorMmul", "y", y);
Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length);
@ -4522,7 +4475,7 @@ public class SDBaseOps {
* @return output Unsorted segment output (NUMERIC type)
*/
public SDVariable unsortedSegmentMax(String name, SDVariable data, SDVariable segmentIds,
int numSegments) {
int numSegments) {
SDValidation.validateNumerical("unsortedSegmentMax", "data", data);
SDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(sd,data, segmentIds, numSegments).outputVariable();
@ -4561,7 +4514,7 @@ public class SDBaseOps {
* @return output Unsorted segment output (NUMERIC type)
*/
public SDVariable unsortedSegmentMean(String name, SDVariable data, SDVariable segmentIds,
int numSegments) {
int numSegments) {
SDValidation.validateNumerical("unsortedSegmentMean", "data", data);
SDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(sd,data, segmentIds, numSegments).outputVariable();
@ -4600,7 +4553,7 @@ public class SDBaseOps {
* @return output Unsorted segment output (NUMERIC type)
*/
public SDVariable unsortedSegmentMin(String name, SDVariable data, SDVariable segmentIds,
int numSegments) {
int numSegments) {
SDValidation.validateNumerical("unsortedSegmentMin", "data", data);
SDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(sd,data, segmentIds, numSegments).outputVariable();
@ -4639,7 +4592,7 @@ public class SDBaseOps {
* @return output Unsorted segment output (NUMERIC type)
*/
public SDVariable unsortedSegmentProd(String name, SDVariable data, SDVariable segmentIds,
int numSegments) {
int numSegments) {
SDValidation.validateNumerical("unsortedSegmentProd", "data", data);
SDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(sd,data, segmentIds, numSegments).outputVariable();
@ -4676,7 +4629,7 @@ public class SDBaseOps {
* @return output Unsorted segment output (NUMERIC type)
*/
public SDVariable unsortedSegmentSqrtN(String name, SDVariable data, SDVariable segmentIds,
int numSegments) {
int numSegments) {
SDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data);
SDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(sd,data, segmentIds, numSegments).outputVariable();
@ -4715,7 +4668,7 @@ public class SDBaseOps {
* @return output Unsorted segment output (NUMERIC type)
*/
public SDVariable unsortedSegmentSum(String name, SDVariable data, SDVariable segmentIds,
int numSegments) {
int numSegments) {
SDValidation.validateNumerical("unsortedSegmentSum", "data", data);
SDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds);
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(sd,data, segmentIds, numSegments).outputVariable();
@ -4771,7 +4724,7 @@ public class SDBaseOps {
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
*/
public SDVariable variance(SDVariable x, boolean biasCorrected, boolean keepDims,
int... dimensions) {
int... dimensions) {
SDValidation.validateNumerical("variance", "x", x);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
return new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable();
@ -4795,7 +4748,7 @@ public class SDBaseOps {
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
*/
public SDVariable variance(String name, SDVariable x, boolean biasCorrected, boolean keepDims,
int... dimensions) {
int... dimensions) {
SDValidation.validateNumerical("variance", "x", x);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable();

View File

@ -1,22 +1,20 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
@ -250,7 +248,7 @@ public class SDBitwise extends SDOps {
/**
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
* Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br>
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
*
* @param x Input to be bit shifted (INT type)
@ -265,7 +263,7 @@ public class SDBitwise extends SDOps {
/**
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
* Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br>
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
*
* @param name name May be null. Name for the output variable
@ -348,7 +346,7 @@ public class SDBitwise extends SDOps {
/**
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
* Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br>
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
*
* @param x Input to be bit shifted (INT type)
@ -363,7 +361,7 @@ public class SDBitwise extends SDOps {
/**
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
* Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br>
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
*
* @param name name May be null. Name for the output variable

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -42,8 +42,7 @@ public class SDCNN extends SDOps {
/**
* 2D Convolution layer operation - average pooling 2d<br>
*
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
* @return output Result after applying average pooling on the input (NUMERIC type)
*/
@ -56,8 +55,7 @@ public class SDCNN extends SDOps {
* 2D Convolution layer operation - average pooling 2d<br>
*
* @param name name May be null. Name for the output variable
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
* @return output Result after applying average pooling on the input (NUMERIC type)
*/
@ -70,9 +68,7 @@ public class SDCNN extends SDOps {
/**
* 3D convolution layer operation - average pooling 3d <br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param Pooling3DConfig Configuration Object
* @return output after applying average pooling on the input (NUMERIC type)
*/
@ -85,9 +81,7 @@ public class SDCNN extends SDOps {
* 3D convolution layer operation - average pooling 3d <br>
*
* @param name name May be null. Name for the output variable
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param Pooling3DConfig Configuration Object
* @return output after applying average pooling on the input (NUMERIC type)
*/
@ -302,9 +296,7 @@ public class SDCNN extends SDOps {
/**
* Convolution 3D operation with optional bias <br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param Conv3DConfig Configuration Object
@ -322,9 +314,7 @@ public class SDCNN extends SDOps {
* Convolution 3D operation with optional bias <br>
*
* @param name name May be null. Name for the output variable
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param Conv3DConfig Configuration Object
@ -342,9 +332,7 @@ public class SDCNN extends SDOps {
/**
* Convolution 3D operation with optional bias <br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
* @param Conv3DConfig Configuration Object
* @return output Conv3d output variable (NUMERIC type)
@ -359,9 +347,7 @@ public class SDCNN extends SDOps {
* Convolution 3D operation with optional bias <br>
*
* @param name name May be null. Name for the output variable
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
* @param Conv3DConfig Configuration Object
* @return output Conv3d output variable (NUMERIC type)
@ -377,8 +363,7 @@ public class SDCNN extends SDOps {
/**
* 2D deconvolution operation with optional bias<br>
*
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param DeConv2DConfig Configuration Object
@ -396,8 +381,7 @@ public class SDCNN extends SDOps {
* 2D deconvolution operation with optional bias<br>
*
* @param name name May be null. Name for the output variable
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param DeConv2DConfig Configuration Object
@ -415,8 +399,7 @@ public class SDCNN extends SDOps {
/**
* 2D deconvolution operation with optional bias<br>
*
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
* @param DeConv2DConfig Configuration Object
* @return output result of deconv2d op (NUMERIC type)
@ -432,8 +415,7 @@ public class SDCNN extends SDOps {
* 2D deconvolution operation with optional bias<br>
*
* @param name name May be null. Name for the output variable
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
* @param DeConv2DConfig Configuration Object
* @return output result of deconv2d op (NUMERIC type)
@ -519,8 +501,7 @@ public class SDCNN extends SDOps {
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
* = [mb, 2, 4, 4]<br>
*
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return output Output variable (NUMERIC type)
@ -537,8 +518,7 @@ public class SDCNN extends SDOps {
* = [mb, 2, 4, 4]<br>
*
* @param name name May be null. Name for the output variable
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return output Output variable (NUMERIC type)
@ -756,8 +736,7 @@ public class SDCNN extends SDOps {
/**
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
*/
public SDVariable[] maxPoolWithArgmax(SDVariable input, Pooling2DConfig Pooling2DConfig) {
@ -769,8 +748,7 @@ public class SDCNN extends SDOps {
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
*
* @param names names May be null. Arrays of names for the output variables.
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
*/
public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable input,
@ -783,8 +761,7 @@ public class SDCNN extends SDOps {
/**
* 2D Convolution layer operation - max pooling 2d <br>
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
* @return output Result after applying max pooling on the input (NUMERIC type)
*/
@ -797,8 +774,7 @@ public class SDCNN extends SDOps {
* 2D Convolution layer operation - max pooling 2d <br>
*
* @param name name May be null. Name for the output variable
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
* @return output Result after applying max pooling on the input (NUMERIC type)
*/
@ -811,9 +787,7 @@ public class SDCNN extends SDOps {
/**
* 3D convolution layer operation - max pooling 3d operation.<br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param Pooling3DConfig Configuration Object
* @return output Result after applying max pooling on the input (NUMERIC type)
*/
@ -826,9 +800,7 @@ public class SDCNN extends SDOps {
* 3D convolution layer operation - max pooling 3d operation.<br>
*
* @param name name May be null. Name for the output variable
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param Pooling3DConfig Configuration Object
* @return output Result after applying max pooling on the input (NUMERIC type)
*/
@ -841,8 +813,7 @@ public class SDCNN extends SDOps {
/**
* Separable 2D convolution operation with optional bias <br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
@ -862,8 +833,7 @@ public class SDCNN extends SDOps {
* Separable 2D convolution operation with optional bias <br>
*
* @param name name May be null. Name for the output variable
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
@ -883,8 +853,7 @@ public class SDCNN extends SDOps {
/**
* Separable 2D convolution operation with optional bias <br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
* @param Conv2DConfig Configuration Object
@ -902,8 +871,7 @@ public class SDCNN extends SDOps {
* Separable 2D convolution operation with optional bias <br>
*
* @param name name May be null. Name for the output variable
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
* @param Conv2DConfig Configuration Object
@ -964,8 +932,7 @@ public class SDCNN extends SDOps {
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
* = [mb, 2, 4, 4] <br>
*
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return output Output variable (NUMERIC type)
@ -982,8 +949,7 @@ public class SDCNN extends SDOps {
* = [mb, 2, 4, 4] <br>
*
* @param name name May be null. Name for the output variable
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return output Output variable (NUMERIC type)

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
@ -36,7 +36,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output loss variable (NUMERIC type)
*/
public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights,
@ -56,7 +56,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output loss variable (NUMERIC type)
*/
public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions,
@ -116,7 +116,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param dimension Dimension to perform the cosine distance over
* @return output Cosine distance loss (NUMERIC type)
*/
@ -141,7 +141,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param dimension Dimension to perform the cosine distance over
* @return output Cosine distance loss (NUMERIC type)
*/
@ -202,6 +202,49 @@ public class SDLoss extends SDOps {
return sd.updateVariableNameAndReference(out, name);
}
/**
* CTC Loss: Connectionist Temporal Classification Loss. See:<br>
* https://dl.acm.org/citation.cfm?id=1143891<br>
*
* @param targetLabels Label array (NUMERIC type)
* @param logitInput Inputs (NUMERIC type)
* @param targetLabelLengths Length of the target label (NUMERIC type)
* @param logitInputLengths Length of the input (NUMERIC type)
* @return output Ctc loss (NUMERIC type)
*/
public SDVariable ctcLoss(SDVariable targetLabels, SDVariable logitInput,
SDVariable targetLabelLengths, SDVariable logitInputLengths) {
SDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
SDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
SDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
SDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(sd,targetLabels, logitInput, targetLabelLengths, logitInputLengths).outputVariable();
out.markAsLoss();
return out;
}
/**
* CTC Loss: Connectionist Temporal Classification Loss. See:<br>
* https://dl.acm.org/citation.cfm?id=1143891<br>
*
* @param name name May be null. Name for the output variable
* @param targetLabels Label array (NUMERIC type)
* @param logitInput Inputs (NUMERIC type)
* @param targetLabelLengths Length of the target label (NUMERIC type)
* @param logitInputLengths Length of the input (NUMERIC type)
* @return output Ctc loss (NUMERIC type)
*/
public SDVariable ctcLoss(String name, SDVariable targetLabels, SDVariable logitInput,
SDVariable targetLabelLengths, SDVariable logitInputLengths) {
SDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
SDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
SDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
SDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(sd,targetLabels, logitInput, targetLabelLengths, logitInputLengths).outputVariable();
out.markAsLoss();
return sd.updateVariableNameAndReference(out, name);
}
/**
* Hinge loss: a loss function used for training classifiers.<br>
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
@ -210,7 +253,7 @@ public class SDLoss extends SDOps {
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output Loss variable (NUMERIC type)
*/
public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights,
@ -232,7 +275,7 @@ public class SDLoss extends SDOps {
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output Loss variable (NUMERIC type)
*/
public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions,
@ -297,7 +340,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param delta Loss function delta value
* @return output Huber loss (NUMERIC type)
*/
@ -324,7 +367,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param delta Loss function delta value
* @return output Huber loss (NUMERIC type)
*/
@ -423,7 +466,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param epsilon epsilon
* @return output Log loss (NUMERIC type)
*/
@ -445,7 +488,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param epsilon epsilon
* @return output Log loss (NUMERIC type)
*/
@ -499,7 +542,7 @@ public class SDLoss extends SDOps {
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
* @return output Loss variable (NUMERIC type)
*/
@ -521,7 +564,7 @@ public class SDLoss extends SDOps {
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
* @return output Loss variable (NUMERIC type)
*/
@ -585,7 +628,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output Loss variable, scalar output (NUMERIC type)
*/
public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions,
@ -608,7 +651,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output Loss variable, scalar output (NUMERIC type)
*/
public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions,
@ -666,13 +709,13 @@ public class SDLoss extends SDOps {
/**
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
* this is the mean squared error loss function.<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output Loss variable (NUMERIC type)
*/
public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights,
@ -687,14 +730,14 @@ public class SDLoss extends SDOps {
/**
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
* this is the mean squared error loss function.<br>
*
* @param name name May be null. Name for the output variable
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output Loss variable (NUMERIC type)
*/
public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions,
@ -709,7 +752,7 @@ public class SDLoss extends SDOps {
/**
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
* this is the mean squared error loss function.<br>
*
* @param label Label array (NUMERIC type)
@ -728,7 +771,7 @@ public class SDLoss extends SDOps {
/**
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
* this is the mean squared error loss function.<br>
*
* @param name name May be null. Name for the output variable
@ -764,7 +807,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictionLogits Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param labelSmoothing Label smoothing value. Default value: 0
* @return output Loss variable (NUMERIC type)
*/
@ -796,7 +839,7 @@ public class SDLoss extends SDOps {
* @param label Label array (NUMERIC type)
* @param predictionLogits Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param labelSmoothing Label smoothing value. Default value: 0
* @return output Loss variable (NUMERIC type)
*/
@ -872,7 +915,7 @@ public class SDLoss extends SDOps {
/**
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* otherwise, the output is a scalar.<br>
* <p><br>
* When label smoothing is > 0, the following label smoothing is used:<br>
@ -884,7 +927,7 @@ public class SDLoss extends SDOps {
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param labelSmoothing Label smoothing value. Default value: 0
* @return output Loss variable (NUMERIC type)
*/
@ -901,7 +944,7 @@ public class SDLoss extends SDOps {
/**
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* otherwise, the output is a scalar.<br>
* <p><br>
* When label smoothing is > 0, the following label smoothing is used:<br>
@ -914,7 +957,7 @@ public class SDLoss extends SDOps {
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param labelSmoothing Label smoothing value. Default value: 0
* @return output Loss variable (NUMERIC type)
*/
@ -932,7 +975,7 @@ public class SDLoss extends SDOps {
/**
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* otherwise, the output is a scalar.<br>
* <p><br>
* When label smoothing is > 0, the following label smoothing is used:<br>
@ -959,7 +1002,7 @@ public class SDLoss extends SDOps {
/**
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* otherwise, the output is a scalar.<br>
* <p><br>
* When label smoothing is > 0, the following label smoothing is used:<br>

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -144,22 +144,22 @@ public class SDRNN extends SDOps {
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* SUPPORTS following data formats:<br>
* for unidirectional:<br>
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
* NST: shapes [numExamples, inOutSize, timeLength]<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* for bidirectional:<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
* SUPPORTS following direction modes:<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* BIDIR_SUM: bidirectional sum<br>
* BIDIR_CONCAT: bidirectional concat<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
@ -180,22 +180,22 @@ public class SDRNN extends SDOps {
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* SUPPORTS following data formats:<br>
* for unidirectional:<br>
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
* NST: shapes [numExamples, inOutSize, timeLength]<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* for bidirectional:<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
* SUPPORTS following direction modes:<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* BIDIR_SUM: bidirectional sum<br>
* BIDIR_CONCAT: bidirectional concat<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param names names May be null. Arrays of names for the output variables.
@ -218,22 +218,22 @@ public class SDRNN extends SDOps {
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* SUPPORTS following data formats:<br>
* for unidirectional:<br>
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
* NST: shapes [numExamples, inOutSize, timeLength]<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* for bidirectional:<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
* SUPPORTS following direction modes:<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* BIDIR_SUM: bidirectional sum<br>
* BIDIR_CONCAT: bidirectional concat<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
@ -248,22 +248,22 @@ public class SDRNN extends SDOps {
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* SUPPORTS following data formats:<br>
* for unidirectional:<br>
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
* NST: shapes [numExamples, inOutSize, timeLength]<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* for bidirectional:<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
* SUPPORTS following direction modes:<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* BIDIR_SUM: bidirectional sum<br>
* BIDIR_CONCAT: bidirectional concat<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param names names May be null. Arrays of names for the output variables.

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.autodiff.samediff.ops;
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
import java.lang.String;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* Activations */
public enum CellAct {
TANH,

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* Data format: "NCHW" or "NHWC" */
public enum DataFormat {
NCHW,

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* Activations */
public enum GateAct {
TANH,

View File

@ -1,32 +1,43 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. */
public enum ImageResizeMethod {
ResizeBilinear, // as java require
ResizeNearest,
ResizeBilinear,
ResizeBicubic,
ResizeArea,
ResizeNearest,
ResizeGaussian,
ResizeLanczos3,
ResizeLanczos5,
ResizeMitchellcubic;
ResizeMitchelcubic,
ResizeArea
}

View File

@ -1,25 +1,28 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* for unidirectional: TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
* NST: shape [numExamples, inOutSize, timeLength]<br>
* NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout<br> for bidirectional:
* T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */
public enum LSTMDataFormat {
TNS,

View File

@ -1,25 +1,30 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* direction <br>
* FWD: 0 = fwd
* BWD: 1 = bwd
* BIDIR_SUM: 2 = bidirectional sum
* BIDIR_CONCAT: 3 = bidirectional concat
* BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */
public enum LSTMDirectionMode {
FWD,

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* Activations */
public enum OutAct {
TANH,

View File

@ -1,22 +1,20 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;

View File

@ -1,22 +1,20 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;

View File

@ -1,25 +1,28 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.enums;
/**
* The data format of the input. Input shape depends on data format (in config):<br>
* TNS -> [timeSteps, batchSize, inSize]<br>
* NST -> [batchSize, inSize, timeSteps]<br>
* NTS -> [batchSize, timeSteps, inSize]<br> */
public enum RnnDataFormat {
TNS,

View File

@ -50,7 +50,7 @@ public abstract class BaseLoss extends DynamicCustomOp {
addArgs();
}
protected static INDArray getWeights(INDArray weights, INDArray predictions){
protected static INDArray getWeights(INDArray weights, INDArray predictions) {
return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0);
}

View File

@ -20,29 +20,20 @@
package org.nd4j.linalg.api.ops.impl.loss;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp;
import java.util.List;
public class CtcLoss extends BaseLoss {
public class CtcLoss extends DynamicCustomOp {
public CtcLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
super(sameDiff, lossReduce, predictions, weights, labels);
public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
}
public CtcLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights,
LossReduce lossReduce) {
this(sameDiff, lossReduce, predictions, weights, label);
}
public CtcLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
super(lossReduce, predictions, weights, labels);
}
public CtcLoss(){ }
@ -52,9 +43,9 @@ public class CtcLoss extends BaseLoss {
}
@Override
public List<SDVariable> doDiff(List<SDVariable> grad){
public List<SDVariable> doDiff(List<SDVariable> grad) {
//No external gradient
//Args are: predictions, weights, label
return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs();
return new CtcLossBp(sameDiff, arg(0), arg(1), arg(2),arg(3)).outputs();
}
}

View File

@ -20,17 +20,17 @@
package org.nd4j.linalg.api.ops.impl.loss.bp;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.List;
public class CtcLossBp extends BaseLossBp {
public class CtcLossBp extends DynamicCustomOp {
public CtcLossBp(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
super(sameDiff, lossReduce, predictions, weights, labels);
public CtcLossBp(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
}
public CtcLossBp(){ }

View File

@ -1,22 +1,20 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;

View File

@ -1,22 +1,20 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
@ -134,7 +132,7 @@ public class NDBitwise {
/**
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
* Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br>
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
*
* @param x Input to be bit shifted (INT type)
@ -180,7 +178,7 @@ public class NDBitwise {
/**
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
* Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br>
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
*
* @param x Input to be bit shifted (INT type)

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.DataFormat;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,8 +41,7 @@ public class NDCNN {
/**
* 2D Convolution layer operation - average pooling 2d<br>
*
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
* @return output Result after applying average pooling on the input (NUMERIC type)
*/
@ -54,9 +53,7 @@ public class NDCNN {
/**
* 3D convolution layer operation - average pooling 3d <br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param Pooling3DConfig Configuration Object
* @return output after applying average pooling on the input (NUMERIC type)
*/
@ -161,9 +158,7 @@ public class NDCNN {
/**
* Convolution 3D operation with optional bias <br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param Conv3DConfig Configuration Object
@ -180,9 +175,7 @@ public class NDCNN {
/**
* Convolution 3D operation with optional bias <br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
* @param Conv3DConfig Configuration Object
* @return output Conv3d output variable (NUMERIC type)
@ -196,8 +189,7 @@ public class NDCNN {
/**
* 2D deconvolution operation with optional bias<br>
*
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
* @param DeConv2DConfig Configuration Object
@ -214,8 +206,7 @@ public class NDCNN {
/**
* 2D deconvolution operation with optional bias<br>
*
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
* @param DeConv2DConfig Configuration Object
* @return output result of deconv2d op (NUMERIC type)
@ -263,8 +254,7 @@ public class NDCNN {
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
* = [mb, 2, 4, 4]<br>
*
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return output Output variable (NUMERIC type)
@ -373,8 +363,7 @@ public class NDCNN {
/**
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
*/
public INDArray[] maxPoolWithArgmax(INDArray input, Pooling2DConfig Pooling2DConfig) {
@ -385,8 +374,7 @@ public class NDCNN {
/**
* 2D Convolution layer operation - max pooling 2d <br>
*
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param Pooling2DConfig Configuration Object
* @return output Result after applying max pooling on the input (NUMERIC type)
*/
@ -398,9 +386,7 @@ public class NDCNN {
/**
* 3D convolution layer operation - max pooling 3d operation.<br>
*
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
* @param Pooling3DConfig Configuration Object
* @return output Result after applying max pooling on the input (NUMERIC type)
*/
@ -412,8 +398,7 @@ public class NDCNN {
/**
* Separable 2D convolution operation with optional bias <br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
@ -432,8 +417,7 @@ public class NDCNN {
/**
* Separable 2D convolution operation with optional bias <br>
*
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
* @param Conv2DConfig Configuration Object
@ -471,8 +455,7 @@ public class NDCNN {
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
* = [mb, 2, 4, 4] <br>
*
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
* @param blockSize Block size, in the height/width dimension
* @param dataFormat Data format: "NCHW" or "NHWC"
* @return output Output variable (NUMERIC type)

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.ImageResizeMethod;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.NDValidation;

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.NDValidation;
@ -35,7 +35,7 @@ public class NDLoss {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output loss variable (NUMERIC type)
*/
public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights,
@ -71,7 +71,7 @@ public class NDLoss {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param dimension Dimension to perform the cosine distance over
* @return output Cosine distance loss (NUMERIC type)
*/
@ -104,6 +104,25 @@ public class NDLoss {
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension))[0];
}
/**
* CTC Loss: Connectionist Temporal Classification Loss. See:<br>
* https://dl.acm.org/citation.cfm?id=1143891<br>
*
* @param targetLabels Label array (NUMERIC type)
* @param logitInput Inputs (NUMERIC type)
* @param targetLabelLengths Length of the target label (NUMERIC type)
* @param logitInputLengths Length of the input (NUMERIC type)
* @return output Ctc loss (NUMERIC type)
*/
public INDArray ctcLoss(INDArray targetLabels, INDArray logitInput, INDArray targetLabelLengths,
INDArray logitInputLengths) {
NDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
NDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
NDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
NDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(targetLabels, logitInput, targetLabelLengths, logitInputLengths))[0];
}
/**
* Hinge loss: a loss function used for training classifiers.<br>
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
@ -112,7 +131,7 @@ public class NDLoss {
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output Loss variable (NUMERIC type)
*/
public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights,
@ -152,7 +171,7 @@ public class NDLoss {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param delta Loss function delta value
* @return output Huber loss (NUMERIC type)
*/
@ -204,7 +223,7 @@ public class NDLoss {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param epsilon epsilon
* @return output Log loss (NUMERIC type)
*/
@ -237,7 +256,7 @@ public class NDLoss {
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
* @return output Loss variable (NUMERIC type)
*/
@ -275,7 +294,7 @@ public class NDLoss {
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output Loss variable, scalar output (NUMERIC type)
*/
public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights,
@ -306,13 +325,13 @@ public class NDLoss {
/**
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
* this is the mean squared error loss function.<br>
*
* @param label Label array (NUMERIC type)
* @param predictions Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @return output Loss variable (NUMERIC type)
*/
public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights,
@ -325,7 +344,7 @@ public class NDLoss {
/**
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
* this is the mean squared error loss function.<br>
*
* @param label Label array (NUMERIC type)
@ -357,7 +376,7 @@ public class NDLoss {
* @param label Label array (NUMERIC type)
* @param predictionLogits Predictions array (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param labelSmoothing Label smoothing value. Default value: 0
* @return output Loss variable (NUMERIC type)
*/
@ -398,7 +417,7 @@ public class NDLoss {
/**
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* otherwise, the output is a scalar.<br>
* <p><br>
* When label smoothing is > 0, the following label smoothing is used:<br>
@ -410,7 +429,7 @@ public class NDLoss {
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
* @param labelSmoothing Label smoothing value. Default value: 0
* @return output Loss variable (NUMERIC type)
*/
@ -425,7 +444,7 @@ public class NDLoss {
/**
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
* otherwise, the output is a scalar.<br>
* <p><br>
* When label smoothing is > 0, the following label smoothing is used:<br>

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.PadMode;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
@ -85,22 +85,22 @@ public class NDRNN {
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* SUPPORTS following data formats:<br>
* for unidirectional:<br>
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
* NST: shapes [numExamples, inOutSize, timeLength]<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* for bidirectional:<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
* SUPPORTS following direction modes:<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* BIDIR_SUM: bidirectional sum<br>
* BIDIR_CONCAT: bidirectional concat<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
@ -121,22 +121,22 @@ public class NDRNN {
/**
* Long Short-Term Memory layer - Hochreiter 1997.<br>
* SUPPORTS following data formats:\n<br>
* for unidirectional: \n" +<br>
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
* SUPPORTS following data formats:<br>
* for unidirectional:<br>
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
* NST: shapes [numExamples, inOutSize, timeLength]<br>
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
* for bidirectional:\n<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
* SUPPORTS following direction modes:\n<br>
* for bidirectional:<br>
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
* SUPPORTS following direction modes:<br>
* FWD: forward<br>
* BWD: backward<br>
* BIDIR_SUM: bidirectional sum\n<br>
* BIDIR_CONCAT: bidirectional concat\n" +<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
* BIDIR_SUM: bidirectional sum<br>
* BIDIR_CONCAT: bidirectional concat<br>
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
* You may use different gate configurations:<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
*
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)

View File

@ -1,25 +1,25 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
package org.nd4j.linalg.factory.ops;
import static org.nd4j.linalg.factory.NDValidation.isSameType;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;