commit
0c81654567
|
@ -0,0 +1,614 @@
|
|||
# ND4J Op Definitions and Code Generation
|
||||
This project contains the ND4J Op definitions, the DSL (Domain Specific Language) that is used for those definitions and
|
||||
code generators that use those definitions to create the actual Java code that is used to use the defined operations.
|
||||
|
||||
|
||||
## Why define ops externally?
|
||||
As we started to support SameDiff, we also started to introduce inconsistencies between SameDiff and ND4J. Even though
|
||||
both of those libraries use the same underlying implementations for operations, there are both small and large
|
||||
differences in the API that we provide for them. Sometimes, we have provided an official API only for one usage, and not
|
||||
the other. And very often the documentation for a single op is in many different places.
|
||||
|
||||
In the future we want to support other programming languages with libnd4j, and provide more ways to use our C++ backend.
|
||||
This would only increase the aforementioned problems.
|
||||
|
||||
The root of all of those problems, is that Ops are used across different environments, and there is no single way of
|
||||
defining them with an enforced interface.
|
||||
|
||||
|
||||
## How does this project help with enforcing a single consistent interface for ops?
|
||||
The solution we propose, is to define the operations separately, and then generate the necessary API code for them. All
|
||||
of the generated code is to be considered untouchable, editing it will result in the changes being overwritten sooner
|
||||
rather than later.
|
||||
|
||||
The combination of external op definition and code generation, opens up many opportunities for us. The first one being
|
||||
that we can easily create consistent APIs for both ND4J and SameDiff in Java. But, looking into the future, we can also
|
||||
create those APIs for other programming languages like Python, Swift, or even C#. We can even go beyond programming
|
||||
languages, and use the op definitions to create better documentation than what JavaDoc or similar might support out of
|
||||
the box.
|
||||
|
||||
## Maintenance
|
||||
This project is currently maintained by Paul Dubs, with feedback often collected from raver119 and Alex Black.
|
||||
|
||||
## Current Status
|
||||
At the moment we still focus on nailing down an easily readable and contribution friendly DSL for op definition and code
|
||||
generation that can replace namespace definitions. This means that at the moment we still rely on the pre-existing Op
|
||||
definition classes that already exist in ND4J.
|
||||
|
||||
## Roadmap
|
||||
* Replace Bitwise and Random namespaces with autogenerated code – In progress.
|
||||
* Implement a convenient CLI tool.
|
||||
* Define all Ops using the DSL.
|
||||
* Automatically generate derivative op declarations from existing ops
|
||||
* Replace all namespace definitions in ND4J / SameDiff with automatically generated ones
|
||||
* Replace all Op classes with automatically generated ones.
|
||||
|
||||
# Usage
|
||||
Pre-requisites:
|
||||
* JDK 8 or higher
|
||||
* Maven 3.3 or higher
|
||||
|
||||
TODO: Show usage output of the project itself
|
||||
|
||||
TODO: Show how to use from mvn
|
||||
|
||||
|
||||
## Generating Code - ND4J Namespaces
|
||||
|
||||
A script - `generate.sh` - is provided in the project root. This can be used (at present) to generate ND4J namespace classes.
|
||||
It is assumed that the deeplearning4j mono repo and the dl4j-dev-tools repo both exist and have a common parent directory
|
||||
i.e., `somedir/deeplearning4j` and `somedir/dl4j-dev-tools` both exist.
|
||||
|
||||
The script takes as argument the name (or names) of the ND4J namespaces to generate (not case sensitive) and projects (supported
|
||||
projects are nd4j, sd and both by default).
|
||||
|
||||
As of 26/11, namespaces names (and hence valid args) include: `bitwise`, `neuralnetwork`, `random`, and `math`
|
||||
Note also that `all` may be passed to the script to generate all namespaces.
|
||||
|
||||
For example, to generate both bitwise and random namespaces for both nd4j and SameDiff:
|
||||
```
|
||||
./generate.sh bitwise,random
|
||||
```
|
||||
Or to generate all namespaces for both nd4j and SameDiff, use:
|
||||
```
|
||||
./generate.sh all
|
||||
```
|
||||
To generate namespaces for one project only, use:
|
||||
```
|
||||
./generate.sh linalg -projects sd
|
||||
```
|
||||
or:
|
||||
```
|
||||
./generate.sh linalg -projects nd4j
|
||||
```
|
||||
The script will first compile the project, before running.
|
||||
Internally, the `org.nd4j.codegen.cli.CLI` class is used.
|
||||
Classes are written to `deeplearning4j/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/`
|
||||
|
||||
## Generating documentation.
|
||||
It is possible to use generate.sh for generation of code only, docs in markdown format only, or both docs and code.
|
||||
To generate docs only and store them to new folder "docs" for all namespaces:
|
||||
```
|
||||
./generate.sh all -docsdir ../../docs
|
||||
```
|
||||
Generation for selected namespaces works in the same way as for code:
|
||||
```
|
||||
./generate.sh -docsdir ../../docs bitwise,linalg
|
||||
```
|
||||
|
||||
# Code structure
|
||||
The project is implemented using a mix of Java and Kotlin. The DSL definition and the accompanying data structures are
|
||||
implemented in Kotlin. At the moment the code generators are implemented in Java, in order to allow people who are not
|
||||
fluent in Kotlin, but know Java to be able to contribute to the code generators.
|
||||
|
||||
The source code for this project is structured a bit different that what you would typically see in a Java or Kotlin
|
||||
project. When you take a look inside the `src/main` directory, you will find 4 sub-directories.
|
||||
|
||||
The `java` and `kotlin` directories contain Java and Kotlin code respectively.
|
||||
|
||||
In order to not confuse op definitions with the machinery that allows them to be defined in that way, ops are kept in a
|
||||
separate folder called `ops`.
|
||||
|
||||
Because we use JavaPoet for Java code generator implementation, we also have yet another folder called `stubs`. That
|
||||
folder contains stub classes, that are used to reference other classes available in ND4J. These stub classes are
|
||||
intentionally left empty, as JavaPoet only requires them for naming and automatically creating proper imports. We use
|
||||
stub classes instead of depending on the actual nd4j API in order to break a cyclic dependency that would otherwise be
|
||||
created (i.e. in order to be able to generate code for ND4J, we would need an already compiled nd4j to be available).
|
||||
**Note:** If something is stubbed here and is moved in ND4J, then it also has to be moved to the appropriate place here,
|
||||
otherwise the generated code will be wrong.
|
||||
|
||||
The `adr` folder contains "Architecture Decision Records". These files give you more insight into the "why" of some of
|
||||
the bigger decisions within this project.
|
||||
|
||||
# DSL for Op Definition
|
||||
Ops are defined using a DSL that is implemented in Kotlin. This means that other than the DSL, as defined in the
|
||||
following, you can also use all of Kotlin when defining Ops. However, doing things the obvious and clearly
|
||||
understandable way is better than coming up with a clever way, so prefer to use the DSL as described if unsure.
|
||||
|
||||
```kotlin
|
||||
val mathNs = Namespace("math") {
|
||||
Op("add") {
|
||||
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic"
|
||||
|
||||
Input(NUMERIC, "x") { description = "First input to add" }
|
||||
Input(NUMERIC,"y") { count = AtLeast(1); description = "Second input to add" }
|
||||
Arg(INT,"shape") { count = AtLeast(1); description = "shape" }
|
||||
|
||||
|
||||
Output(NUMERIC, "z") { description = "Output (x+y)" }
|
||||
|
||||
Doc(Language.ANY, DocScope.ALL) {
|
||||
"""
|
||||
(From AddOp) Add op doc text that will appear everywhere - classes, constructors, op creators
|
||||
""".trimIndent()
|
||||
}
|
||||
Doc(Language.ANY, DocScope.CLASS_DOC_ONLY) {
|
||||
"Add op doc text that will appear in all class docs (javadoc etc)"
|
||||
}
|
||||
Doc(Language.ANY, DocScope.CONSTRUCTORS_ONLY) {
|
||||
"Add op doc text for constructors only"
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This example shows how a namespace is defined. Namespaces are at the top layer, and ops can only be defined within the
|
||||
context of a namespace. This example namespace contains only a single op, called "add". If we wanted to add another op,
|
||||
we would simply add it below the first.
|
||||
|
||||
As you can see, every op has to have a name, if you try to create one without a name, you will get a compile error.
|
||||
Within the context of the op, we first set in which java package the op class can be found in, then define its inputs,
|
||||
arguments and outputs and finally add some free form documentation about what that op is doing.
|
||||
|
||||
Like with the op itself, the inputs, arguments and outputs all have to have a name, but unlike the op, they also require
|
||||
a type. Within their context, you can set a description and a count of how many parameters they can take respectively.
|
||||
|
||||
If an input, argument or output take anything else than exactly 1, they will be treated as arrays. Typically you would
|
||||
use this to define ops like `concat` which can take multiple input tensors or ops that might take shape arguments.
|
||||
|
||||
## Examples
|
||||
The following shows how a typical op definition looks like and how the generated Java code may look.
|
||||
|
||||
An op might be defined like this:
|
||||
|
||||
```kotlin
|
||||
Op("binomial") {
|
||||
javaPackage = "org.nd4j.linalg.api.ops.random.custom"
|
||||
Arg(INT, "nTrials") { description = "Number of trials parameter for the binomial distribution" }
|
||||
Arg(FLOATING_POINT, "p") { description = "Probability of success for each trial" }
|
||||
Arg(INT, "shape") { count = AtLeast(1); description = "Shape of the new random SDVariable, as a 1D array" }
|
||||
|
||||
Output(NUMERIC, "output") { description = "new random SDVariable, where values are randomly sampled according to a Binomial distribution" }
|
||||
|
||||
Doc(Language.ANY, DocScope.ALL) {
|
||||
"""
|
||||
Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution,
|
||||
with the specified number of trials and probability.
|
||||
""".trimIndent()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The java code generator will create a method like the following for it:
|
||||
```java
|
||||
/**
|
||||
* Generate a new random SDVariable, where values are randomly sampled according to a Binomial distribution,
|
||||
* with the specified number of trials and probability.
|
||||
*
|
||||
* @param nTrials Number of trials parameter for the binomial distribution
|
||||
* @param p Probability of success for each trial
|
||||
* @param shape Shape of the new random SDVariable, as a 1D array (Size: AtLeast(min=1))
|
||||
* @return output new random SDVariable, where values are randomly sampled according to a Binomial distribution (NUMERIC type)
|
||||
*/
|
||||
public static INDArray binomial(long nTrials, double p, long... shape) {
|
||||
Preconditions.checkArgument(shape.length >= 1, "shape has incorrect count. Expected: AtLeast(min=1)");
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.random.custom.BinomialOp(nTrials, p, shape))[0];
|
||||
}
|
||||
```
|
||||
|
||||
Or an op with some more constraints:
|
||||
|
||||
```kotlin
|
||||
Op("and") {
|
||||
javaPackage = "org.nd4j.linalg.api.ops.impl.transforms.custom"
|
||||
val x = Input(INT, "x") { description = "First input array" }
|
||||
val y = Input(INT, "y") { description = "Second input array" }
|
||||
Constraint("Must be same types"){ sameType(x, y) }
|
||||
Constraint("Must have broadcastable shapes"){ broadcastableShapes(x, y) }
|
||||
|
||||
Output(INT, "output"){ description = "Bitwise AND array" }
|
||||
|
||||
Doc(Language.ANY, DocScope.ALL){
|
||||
"""
|
||||
Bitwise AND operation. Supports broadcasting.
|
||||
""".trimIndent()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
will be converted to java like this:
|
||||
|
||||
```java
|
||||
/**
|
||||
* Bitwise AND operation. Supports broadcasting.
|
||||
*
|
||||
* Inputs must satisfy the following constraints:
|
||||
* Must be same types: isSameType(x, y)
|
||||
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
|
||||
*
|
||||
* @param x First input array (INT type)
|
||||
* @param y Second input array (INT type)
|
||||
* @return output Bitwise AND array (INT type)
|
||||
*/
|
||||
public static INDArray and(INDArray x, INDArray y) {
|
||||
NDValidation.validateInteger("and", x);
|
||||
NDValidation.validateInteger("and", y);
|
||||
Preconditions.checkArgument(isSameType(x, y), "Must be same types");
|
||||
Preconditions.checkArgument(isBroadcastableShapes(x, y), "Must have broadcastable shapes");
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.AndOp(x, y))[0];
|
||||
}
|
||||
```
|
||||
|
||||
# Full DSL Description
|
||||
## Namespace
|
||||
|
||||
fun NamespaceName() = Namespace("name"){ /* Op definitions in namespace context */}
|
||||
|
||||
Defines a namespace.
|
||||
|
||||
## Op
|
||||
Only available within a namespace context
|
||||
|
||||
Op("opName") { /* op properties in op context */ }
|
||||
Op("anotherOp", mixin) { /* op properties in op context */ }
|
||||
Op("anotherOp2", mixin, keepInputs=false) { /* op properties in op context */ }
|
||||
|
||||
Every op requires a namespace unique op name.
|
||||
|
||||
When defining an op, you can also pass a mixin that it should inherit initial properties from. This has the same effect
|
||||
as using `useMixin(mixin)` as the very first thing in the op definition. If you don't want to inherit all of the
|
||||
parameters of the mixin, you can pass the same additional configuration as you would pass to
|
||||
`useMixin(mixin, ...options..)`. See [Mixin](#mixin) for more information.
|
||||
|
||||
### Op properties
|
||||
* `javaPackage` (String): Package where the op is to be found in the java implementation.
|
||||
* `javaOpClass` (String): Name of java op class if inconsistent with opName. Default: same as opName
|
||||
* `libnd4jName` (String): The name the op has in libnd4j. Default: same as opName
|
||||
|
||||
|
||||
## Mixin
|
||||
Available in global context.
|
||||
|
||||
val mixin = Mixin("name"){ /* op properties in op context */ }
|
||||
// within an op context:
|
||||
useMixin(mixin)
|
||||
useMixin(mixin, ...options...)
|
||||
|
||||
// When needing to access something from within the mixin
|
||||
mixin.input("name")
|
||||
mixin.arg("name")
|
||||
mixin.config("name")
|
||||
mixin.output("name")
|
||||
|
||||
Mixins provide the facility to share commonalities between Ops. You can think of it like inheritance, especially when
|
||||
you declare the use of a mixin on Op definition. In contrast to normal (single) inheritance where only a single super
|
||||
class is possible, the mixin mechanism allows to "inherit" from multiple sources.
|
||||
|
||||
You can define almost all the same things within a mixin that you can within an Op. The only things that *can not* be
|
||||
configured within a mixin are Op `name`, `libnd4jName` and `javaOpClass`.
|
||||
|
||||
As mixins can be configured within the global context, you can share them across namespaces by defining them in their
|
||||
own file. If a mixin is namespace specific, you can also define it within the namespace context.
|
||||
|
||||
Mixins are used either on definition as a parameter `Op("opname", mixin){...}`, or with `useMixin(mixin)` within the op
|
||||
definition. While the former version only supports a single mixin, the latter version allows you to use as many mixins
|
||||
as are required.
|
||||
|
||||
You can also build up mixins by using `useMixin(mixin)` inside a Mixin itself.
|
||||
|
||||
`useMixin(mixin, ...options...)` supports a few additional options: `keepInputs`, `keepArgs`, `keepConfigs`,
|
||||
`keepOutputs`, `keepSignatures`, `keepDoc`, `keepConstraints`. They default to `true`. If you want to skip including
|
||||
some of them, you simply set the parameter for it to `false`, e.g. `useMixin(mixin, keepDoc=false)`.
|
||||
|
||||
When using `useMixin(mixin)`, all definitions within the mixin are applied as if this invocation was replaced with the
|
||||
content of the mixin itself. This means, that if you have already defined anything prior to using a mixin, the mixin's
|
||||
definitions will be **after** the previously defined things. This can be very useful if the commonality between ops is
|
||||
that they have a few trailing options.
|
||||
|
||||
If a named property or section is defined in both a mixin (or multiple mixins) and the op, then the **last** to define it will
|
||||
win. Named properties are `legacy`, `javaPackage`, named sections are `Input`, `Arg`, `Output`, `Config`.
|
||||
|
||||
For example, assume you have `javaPackage` defined in both an op and a mixin. Then you can have the following two
|
||||
cases:
|
||||
|
||||
First case:
|
||||
```kotlin
|
||||
Op("foo"){
|
||||
useMixin(exampleMixin)
|
||||
javaPackage = "some.example.package"
|
||||
}
|
||||
```
|
||||
|
||||
Second case:
|
||||
```kotlin
|
||||
Op("foo"){
|
||||
javaPackage = "some.example.package"
|
||||
useMixin(exampleMixin)
|
||||
}
|
||||
```
|
||||
|
||||
In the first case, the op will have the `javaPackage` value that is defined within the op. In the second case it will
|
||||
have the `javaPackage` value defined in the mixin.
|
||||
|
||||
For inputs, args, outputs, it works similarly. Assume you have `Input(dataType, "a")` defined in both the mixin and the
|
||||
op. Again you can have two cases:
|
||||
|
||||
First case:
|
||||
```kotlin
|
||||
Op("foo"){
|
||||
useMixin(exampleMixin)
|
||||
Input(NUMERIC, "a")
|
||||
}
|
||||
```
|
||||
|
||||
Second case:
|
||||
```kotlin
|
||||
Op("foo"){
|
||||
Input(NUMERIC, "a")
|
||||
useMixin(exampleMixin)
|
||||
}
|
||||
```
|
||||
|
||||
In the first case, it will overwrite the input from the mixin. In the second case, the mixin will overwrite that the
|
||||
input from the op.
|
||||
|
||||
## Config
|
||||
Only available within a namespace context
|
||||
|
||||
val nameConfig = Config("Name"){
|
||||
/* input, arg, constraint, doc properties */
|
||||
}
|
||||
|
||||
Every config requires a namespace unique name.
|
||||
|
||||
A config allows to define a configuration class, that can be used as a holder for complex properties of specific ops
|
||||
which will be passed to an op as a parameter.
|
||||
|
||||
Similar to an op itself, it supports `Input`, `Arg`, `Constraint` and `Doc` definitions.
|
||||
|
||||
in order to use the config within an op you either use `useConfig(cfg)` or `val configRef = useConfig(cfg)`. The second
|
||||
form allows you to reference the config.
|
||||
|
||||
Referencing the config allows to you reference its inputs and args by name: `configRef.input("name")` and
|
||||
`configRef.arg("name")`. Also it allows you to use a config in a signature `Signature(a, b, c, configRef)`.
|
||||
|
||||
When default and shorthand signatures are used, configs will be always placed at the end.
|
||||
|
||||
If a config is defined but not used, an `IllegalStateException` will be thrown.
|
||||
|
||||
See also [ADR 0007 "Configuration Objects"](adr/0007-configuration_objects.md).
|
||||
|
||||
|
||||
## Input
|
||||
Available within an op, mixin and a config context
|
||||
|
||||
Input(FLOATING_POINT, "b"){ /* input properties in input context */ }
|
||||
val a = Input(INT, "a"){ /* input properties in input context */ }
|
||||
|
||||
Inputs represent tensors. They are what the op will work on.
|
||||
|
||||
Every input requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name.
|
||||
|
||||
When defining an input, you can assign it to a variable in order to be able to reference it later on. You might want to
|
||||
do this when defining constraints.
|
||||
|
||||
If you want an input to represent an array, you will have to set a count accordingly. If no count is set, it is assumed
|
||||
that the count is meant to be `Exactly(1)`.
|
||||
|
||||
### Input properties
|
||||
* `description` (String): A short description what this input represents. Setting this is recommended.
|
||||
* `count` (Count): Can take one of `Exactly(n)`; `AtLeast(n)`; `AtMost(n)`; `Range(from, to)`
|
||||
* `defaultValue` (Input): use another input as the default if this isn't set explicitly. The data type of the other
|
||||
input has to match the data type of this input. The other input may also have a default value.
|
||||
|
||||
## Argument
|
||||
Available within an op, mixin and config context
|
||||
|
||||
Arg(FLOATING_POINT, "b"){ /* Arg properties in arg context */ }
|
||||
val a = Arg(INT, "a"){ /* Arg properties in arg context */ }
|
||||
|
||||
Args represent arguments. They modify how the op works on its inputs.
|
||||
|
||||
Every arg requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name.
|
||||
|
||||
When defining an arg, you can assign it to a variable in order to be able to reference it later on. You might want to do
|
||||
this when defining constraints.
|
||||
|
||||
If you want an arg to represent an array, you will have to set a count accordingly. If no count is set, it is assumed
|
||||
that the count is meant to be `Exactly(1)`.
|
||||
|
||||
Note (Java specific): If the last arg is defined to represent an array, it will be translated to a vararg parameter, e.g.
|
||||
`Arg(INT, "a"){ count = AtLeast(1); description = "..." }` will be turned into `long... a`.
|
||||
|
||||
### Argument properties
|
||||
* `description` (String): A short description what this argument represents. Setting this is recommended.
|
||||
* `count` (Count): Can take one of `Exactly(n)`; `AtLeast(n)`; `AtMost(n)`; `Range(from, to)`
|
||||
* `defaultValue` (null|Number|Boolean|int[]|double[]|boolean[]|Arg|TensorShapeValue|TensorDataTypeValue|String):
|
||||
Use given value as default value, if this isn't explicitly set. Can refer to *inputs* and *outputs* using `x.shape()`
|
||||
and `x.dataType()`. The given default values has to match the data type for this argument. May also refer to another
|
||||
Arg, and that Arg may also have a default value. Default values based on outputs are treated like without a default
|
||||
in SameDiff mode.
|
||||
* `possibleValues` (String[]): only available when ENUM data type is used for the argument. Takes a list of possible
|
||||
values for the Enum. If used in in abstract base op, the enum will only be created once. See also
|
||||
[ADR 0006 "Op specific enums"](adr/0006-op_specific_enums.md).
|
||||
|
||||
|
||||
## Output
|
||||
Only available within an op and mixin context
|
||||
|
||||
Output(FLOATING_POINT, "b"){ /* Arg properties in arg context */ }
|
||||
|
||||
Every output requires a data type (either `INT`, `FLOATING_POINT`, `NUMERIC` or `BOOLEAN`) and an op unique name.
|
||||
|
||||
While outputs can be assigned to a variable, there is no intended use-case for it. In contrast to inputs and args,
|
||||
outputs can not be used in constraints.
|
||||
|
||||
### Output properties
|
||||
* `description` (String): A short description what this argument represents. Setting this is recommended.
|
||||
|
||||
|
||||
## Signature
|
||||
Only available within an op and mixin context
|
||||
|
||||
Signature(a,b,c)
|
||||
Signature(a,b,c) { "Some Documentation" }
|
||||
AllParamSignature()
|
||||
AllDefaultParamSignature()
|
||||
|
||||
For some ops only specific signatures make sense, as for example some optional parameters may become required in the
|
||||
presence of other optional parameters. This feature is mainly meant to help with the fact that not all programming
|
||||
languages (e.g. Java) support default parameters. Each signature is meant to describe one overload in those languages.
|
||||
|
||||
See also [ADR 0005 "Optional parameters and signatures"](adr/0005-optional_parameters_and_signatures.md).
|
||||
|
||||
Signatures can also reference the output(s) of an op. Those signatures are only relevant in NDArray programming mode.
|
||||
They are not to be generated in SameDiff mode.
|
||||
|
||||
`AllParamSignature()` and `AllDefaultParamSignature()` are short hands for `Signature(...all parameters...)` and
|
||||
`Signature(...only parameters with no default values...)`. Their parameters include references to outputs unless
|
||||
disabled using `withOutput=false` (e.g. `AllParamSignature(withOutput=false)`).
|
||||
|
||||
If no signature is specified for an op, it is treated as if `AllParamSignature()` and `AllDefaultParamSignature()` are
|
||||
both specified.
|
||||
|
||||
Each signature must satisfy the condition, that all required parameters are listed there. If this condition is not
|
||||
satisfied, an `IllegalStateException` will be thrown on construction.
|
||||
|
||||
|
||||
## Documentation
|
||||
Only available within an op and mixin context
|
||||
|
||||
Doc(Language.ANY, DocScope.ALL){
|
||||
""" Some documentation
|
||||
It can be multiline. And indented.
|
||||
""".trimIndent()
|
||||
}
|
||||
|
||||
Documentation can be language specific, and can be set to be only given at specific places. The documentation itself is
|
||||
given as a string. Because Kotlin supports multiline strings along with proper indentation, we are using them directly
|
||||
here.
|
||||
|
||||
Note: At the moment we are only creating java code, so the documentation can use JavaDoc syntax.
|
||||
|
||||
You can have multiple Doc definitions; they are treated as additive.
|
||||
|
||||
Any instances of the following values will be replaced when generating code:
|
||||
|
||||
* `%OPNAME%` -> operation name ("Add", "Sub", etc)
|
||||
* `%LIBND4J_OPNAME%` -> libnd4j op name ("add", "sub", etc)
|
||||
* `%INPUT_TYPE%` -> input / output type depending on the generated api, i.e. `SDVariable` for SameDiff and `INDArray`
|
||||
for ND4J
|
||||
|
||||
See `DocTokens` class for more details.
|
||||
|
||||
## Constraints
|
||||
Available within an op, mixin and a config context.
|
||||
|
||||
Constraint("Error Message if constraint isn't satisfied"){ /* constraint definition */ }
|
||||
BackendConstraint("Error Message if constraint isn't satisfied"){ /* constraint definition */ }
|
||||
|
||||
Many ops expect their inputs and arguments to satisfy some specific rules. Those rules can be expressed with the
|
||||
constraint system.
|
||||
|
||||
Constraints are to be enforced within the frontend language, while BackendConstraints are currently only to be used as
|
||||
a part of the documentation. They will be enforced within the C++ backend, so there is no point in double checking them.
|
||||
|
||||
There is a system in place to define even complex constraints for inputs and arguments.
|
||||
|
||||
In a constraint definition, you can reference inputs and arguments directly, if they are previously assigned to
|
||||
a variable using `val name = Input(...)`. Inside the Constraint block, you can use the following operations:
|
||||
|
||||
* `eq`: Compare equality (applicable to numbers and booleans), e.g. `x eq 7`, `x eq true`
|
||||
* `neq`: Compare inequality (applicable to numbers and booleans), e.g. `x neq 3`, `x neq true`
|
||||
* `lt`, `lte`: less than, less than equal (applicable to numbers), e.g. `x lt 3`, `x lte 4`
|
||||
* `gt`, `gte`: greater than, grater than equal (applicable to numbers), e.g. `x gt 5`, `x gte 6`
|
||||
* `and`: combine two comparisons where both have to be true, e.g. `(x eq 8) and (y lt 3)`
|
||||
* `or`: combine two comparisons where one has to be true, e.g. `(x eq 8) or (y eq true)`
|
||||
* `all`: combine N comparisons where all have to be true, e.g. `all(x eq 8, y lt 3, z eq true)`
|
||||
* `some`: combine N comparisons where at least one has to be true, e.g. `some(x eq 8, y lt 3, z eq true)`
|
||||
* `not`: negates a comparison, e.g. `not(x eq 3)`
|
||||
|
||||
In addition to those operations, you also get access to some more complex constraints:
|
||||
* `sameType(...)`: true if all given inputs are the same type, e.g. `sameType(x,y,z)`
|
||||
* `sameShape(...)`: true if all given inputs have the same shape, e.g. `sameShape(x,y,z)`
|
||||
* `broadcastableShapes(...)`: true if all given inputs have broadcast compatible shapes, e.g. `broadcastableShapes(x,y,z)`
|
||||
|
||||
Inputs also get some additional methods on them to define useful constraints:
|
||||
* `input.rank()`: Rank of the given input
|
||||
* `input.sizeAt(i)`: size of the given input at the i-th dimension
|
||||
* `input.isScalar()`: Short hand for `x.rank() == 1`
|
||||
|
||||
### Examples
|
||||
Some examples of constraints, and what they evaluate to. The example code contains a little bit of context.
|
||||
|
||||
```kotlin
|
||||
val x = Input(INT, "x") { description = "First input array" }
|
||||
val y = Input(INT, "y") { description = "Second input array" }
|
||||
Constraint("foo bar"){
|
||||
x.sizeAt(7) eq 7 and y.isScalar()
|
||||
}
|
||||
```
|
||||
|
||||
will evaluate to:
|
||||
```java
|
||||
Preconditions.checkArgument((x.sizeAt(7) == 7) && (y.rank() == 1), "foo bar");
|
||||
```
|
||||
|
||||
More examples (only the constraint itself, without context code):
|
||||
|
||||
#### Some
|
||||
```kotlin
|
||||
some(input.rank() eq 3, input.sizeAt(2) gte 7, input.sizeAt(4) lt 5)
|
||||
```
|
||||
turns to:
|
||||
```java
|
||||
((x.rank() == 3) || (x.sizeAt(2) >= 7)) || (x.sizeAt(4) < 5)
|
||||
```
|
||||
|
||||
# Contributing to this project
|
||||
If you want to contribute to this project other than by adding or improving op definitions, the following sections might
|
||||
be of special interest to you.
|
||||
|
||||
## Extending the DSL
|
||||
The DSL is implemented using Kotlin’s type-safe builders feature
|
||||
(see https://kotlinlang.org/docs/reference/type-safe-builders.html). The basic principle is that functions calls can
|
||||
receive blocks that can be executed in a specified context. When combined with the fact that we are just looking to
|
||||
create an object graph that is then going to be used as input to the code generators, this allows us to create a very
|
||||
feature rich DSL without actually having to write a lot of code to support it.
|
||||
|
||||
Most of the DSL specific code can be found in `src/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt`. The actual class
|
||||
definitions for the object graph we are building, can be found in `src/kotlin/org/nd4j/codegen/api`.
|
||||
|
||||
If you want to add just a simple field to one of the objects, it is usually enough to directly add it to the particular
|
||||
class.
|
||||
|
||||
If you want to add a specific section to the op definition, i.e. a section like Input or Doc, you will have to add both
|
||||
the class for the object that it is going to be creating, as well as a function within OpBuilder.kt to create and
|
||||
register that section within the op.
|
||||
|
||||
**Note:** When you extend the DSL you will most likely also have to update all code generators to support the feature
|
||||
you have added.
|
||||
|
||||
## Adding / extending code generators
|
||||
Code generators can be written in either Java or Kotlin. Java has the advantage that more people will have experience in
|
||||
using it. Kotlin has the advantage of more convenient syntax, especially for plain string manipulation and when dealing
|
||||
with Enums and fixed sets of subclasses (called sealed classes in Kotlin).
|
||||
|
||||
All generators have to implement the `org.nd4j.codegen.api.generator.Generator` interface. For automatic detection by
|
||||
the CLI tool, they should also be within the `org.nd4j.codegen.impl.LANGUAGE` package, where `LANGUAGE` is the actual
|
||||
language that they generate.
|
||||
|
||||
Code generators can also use an auxiliary generator for constraint generation. Those auxiliary generators, have to
|
||||
implement ` org.nd4j.codegen.api.generator.ConstraintCodeGenerator` interface.
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
#!/bin/bash
|
||||
|
||||
if test "$#" -eq 0; then
|
||||
echo "No namespaces were specified. One or more namespaces must be provided as an argument"
|
||||
echo "Usage example 1 (single namespace): ./generate.sh math"
|
||||
echo "Usage example 2 (multiple namespaces): ./generate.sh math,random"
|
||||
echo "Usage example 2 (all namespaces): ./generate.sh all"
|
||||
else
|
||||
mvn clean package -DskipTests
|
||||
java -cp target/codegen-1.0.0-SNAPSHOT-shaded.jar org.nd4j.codegen.cli.CLI -dir ../../ -namespaces "$@"
|
||||
fi
|
|
@ -0,0 +1,258 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>codegen</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
|
||||
<properties>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<commonsio.version>2.5</commonsio.version>
|
||||
<commons.lang.version>3.12.0</commons.lang.version>
|
||||
<commons.dbutils.version>1.7</commons.dbutils.version>
|
||||
<lombok.version>1.18.8</lombok.version>
|
||||
<logback.version>1.1.7</logback.version>
|
||||
<junit.version>4.12</junit.version>
|
||||
<junit-jupiter.version>5.4.2</junit-jupiter.version>
|
||||
<java.version>1.8</java.version>
|
||||
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>
|
||||
<kotlin.version>1.3.50</kotlin.version>
|
||||
<kotlin.compiler.jvmTarget>1.8</kotlin.compiler.jvmTarget>
|
||||
<kotlin.compiler.incremental>true</kotlin.compiler.incremental>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
<version>1.7.28</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ch.qos.logback</groupId>
|
||||
<artifactId>logback-classic</artifactId>
|
||||
<version>${logback.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>commons-io</groupId>
|
||||
<artifactId>commons-io</artifactId>
|
||||
<version>${commonsio.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
<version>${lombok.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
<version>${commons.lang.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.squareup</groupId>
|
||||
<artifactId>javapoet</artifactId>
|
||||
<version>1.11.1</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<!-- Test Dependencies -->
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-api</artifactId>
|
||||
<version>${junit-jupiter.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
<version>${junit-jupiter.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.jetbrains.kotlin</groupId>
|
||||
<artifactId>kotlin-stdlib-jdk8</artifactId>
|
||||
<version>${kotlin.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains.kotlin</groupId>
|
||||
<artifactId>kotlin-test</artifactId>
|
||||
<version>${kotlin.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.module</groupId>
|
||||
<artifactId>jackson-module-kotlin</artifactId>
|
||||
<version>2.9.9</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.beust</groupId>
|
||||
<artifactId>jcommander</artifactId>
|
||||
<version>1.78</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.codehaus.mojo</groupId>
|
||||
<artifactId>build-helper-maven-plugin</artifactId>
|
||||
<version>3.0.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>add-source</id>
|
||||
<phase>generate-sources</phase>
|
||||
<goals><goal>add-source</goal></goals>
|
||||
<configuration>
|
||||
<sources>
|
||||
<source>src/main/stubs</source>
|
||||
</sources>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-shade-plugin</artifactId>
|
||||
<version>${maven-shade-plugin.version}</version>
|
||||
<configuration>
|
||||
<shadedArtifactAttached>true</shadedArtifactAttached>
|
||||
<createDependencyReducedPom>false</createDependencyReducedPom>
|
||||
<filters>
|
||||
<filter>
|
||||
<artifact>*:*</artifact>
|
||||
<excludes>
|
||||
<exclude>org/datanucleus/**</exclude>
|
||||
<exclude>META-INF/*.SF</exclude>
|
||||
<exclude>META-INF/*.DSA</exclude>
|
||||
<exclude>META-INF/*.RSA</exclude>
|
||||
</excludes>
|
||||
</filter>
|
||||
</filters>
|
||||
|
||||
</configuration>
|
||||
|
||||
<executions>
|
||||
<execution>
|
||||
<phase>package</phase>
|
||||
<goals>
|
||||
<goal>shade</goal>
|
||||
</goals>
|
||||
<configuration>
|
||||
<transformers>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
|
||||
<resource>reference.conf</resource>
|
||||
</transformer>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
|
||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer" />
|
||||
</transformers>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<plugin>
|
||||
<groupId>org.jetbrains.kotlin</groupId>
|
||||
<artifactId>kotlin-maven-plugin</artifactId>
|
||||
<version>${kotlin.version}</version>
|
||||
<configuration>
|
||||
<args>
|
||||
<arg>-Xjsr305=strict</arg>
|
||||
</args>
|
||||
<compilerPlugins>
|
||||
<plugin>spring</plugin>
|
||||
<plugin>jpa</plugin>
|
||||
</compilerPlugins>
|
||||
</configuration>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains.kotlin</groupId>
|
||||
<artifactId>kotlin-maven-allopen</artifactId>
|
||||
<version>${kotlin.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.jetbrains.kotlin</groupId>
|
||||
<artifactId>kotlin-maven-noarg</artifactId>
|
||||
<version>${kotlin.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>compile</id>
|
||||
<goals> <goal>compile</goal> </goals>
|
||||
<configuration>
|
||||
<sourceDirs>
|
||||
<sourceDir>${project.basedir}/src/main/stubs</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/main/kotlin</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/main/java</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/main/ops</sourceDir>
|
||||
</sourceDirs>
|
||||
</configuration>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>test-compile</id>
|
||||
<goals> <goal>test-compile</goal> </goals>
|
||||
<configuration>
|
||||
<sourceDirs>
|
||||
<sourceDir>${project.basedir}/src/test/stubs</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/test/kotlin</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/test/java</sourceDir>
|
||||
<sourceDir>${project.basedir}/src/test/ops</sourceDir>
|
||||
</sourceDirs>
|
||||
</configuration>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
|
||||
<!-- https://kotlinlang.org/docs/reference/using-maven.html -->
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<version>3.5.1</version>
|
||||
<executions>
|
||||
<!-- Replacing default-compile as it is treated specially by maven -->
|
||||
<execution>
|
||||
<id>default-compile</id>
|
||||
<phase>none</phase>
|
||||
</execution>
|
||||
<!-- Replacing default-testCompile as it is treated specially by maven -->
|
||||
<execution>
|
||||
<id>default-testCompile</id>
|
||||
<phase>none</phase>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>java-compile</id>
|
||||
<phase>compile</phase>
|
||||
<goals> <goal>compile</goal> </goals>
|
||||
</execution>
|
||||
<execution>
|
||||
<id>java-test-compile</id>
|
||||
<phase>test-compile</phase>
|
||||
<goals> <goal>testCompile</goal> </goals>
|
||||
</execution>
|
||||
</executions>
|
||||
<configuration>
|
||||
<source>${java.version}</source>
|
||||
<target>${java.version}</target>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
|
||||
</project>
|
|
@ -19,17 +19,13 @@
|
|||
*/
|
||||
|
||||
package org.nd4j.codegen.impl.java;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.nd4j.codegen.api.Language;
|
||||
import org.nd4j.codegen.api.Namespace;
|
||||
import org.nd4j.codegen.api.NamespaceOps;
|
||||
import org.nd4j.codegen.api.Op;
|
||||
import org.nd4j.codegen.api.generator.Generator;
|
||||
import org.nd4j.codegen.api.generator.GeneratorConfig;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
|
||||
public class JavaPoetGenerator implements Generator {
|
||||
|
||||
|
@ -40,12 +36,12 @@ public class JavaPoetGenerator implements Generator {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void generateNamespaceNd4j(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws IOException {
|
||||
public void generateNamespaceNd4j(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws java.io.IOException {
|
||||
Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.linalg.factory", StringUtils.EMPTY);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void generateNamespaceSameDiff(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws IOException {
|
||||
public void generateNamespaceSameDiff(NamespaceOps namespace, GeneratorConfig config, File directory, String className) throws java.io.IOException {
|
||||
//throw new UnsupportedOperationException("Not yet implemented");
|
||||
Nd4jNamespaceGenerator.generate(namespace, config, directory, className, "org.nd4j.autodiff.samediff", StringUtils.EMPTY);
|
||||
}
|
||||
|
|
|
@ -47,6 +47,23 @@ fun SDLoss() = Namespace("Loss"){
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
Op("ctcLoss") {
|
||||
javaPackage = "org.nd4j.linalg.api.ops.impl.loss"
|
||||
javaOpClass = "CtcLoss"
|
||||
Input(NUMERIC, "targetLabels") { description = "Label array" }
|
||||
Input(NUMERIC, "logitInput") { description = "Inputs" }
|
||||
Input(NUMERIC, "targetLabelLengths") { description = "Length of the target label" }
|
||||
Input(NUMERIC, "logitInputLengths") { description = "Length of the input"}
|
||||
Output(NUMERIC, "output"){ description = "Ctc loss " }
|
||||
Doc(Language.ANY, DocScope.ALL){
|
||||
"""
|
||||
CTC Loss: Connectionist Temporal Classification Loss. See:
|
||||
https://dl.acm.org/citation.cfm?id=1143891
|
||||
""".trimIndent()
|
||||
}
|
||||
}
|
||||
|
||||
Op("cosineDistance") {
|
||||
javaPackage = "org.nd4j.linalg.api.ops.impl.loss"
|
||||
javaOpClass = "CosineDistanceLoss"
|
||||
|
|
|
@ -1,32 +1,26 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
|
@ -307,7 +301,7 @@ public class SDBaseOps {
|
|||
* @param transposeB Whether to transpose B arrays or not
|
||||
*/
|
||||
public SDVariable[] batchMmul(SDVariable[] inputsA, SDVariable[] inputsB, boolean transposeA,
|
||||
boolean transposeB) {
|
||||
boolean transposeB) {
|
||||
SDValidation.validateNumerical("batchMmul", "inputsA", inputsA);
|
||||
Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length);
|
||||
SDValidation.validateNumerical("batchMmul", "inputsB", inputsB);
|
||||
|
@ -331,7 +325,7 @@ public class SDBaseOps {
|
|||
* @param transposeB Whether to transpose B arrays or not
|
||||
*/
|
||||
public SDVariable[] batchMmul(String[] names, SDVariable[] inputsA, SDVariable[] inputsB,
|
||||
boolean transposeA, boolean transposeB) {
|
||||
boolean transposeA, boolean transposeB) {
|
||||
SDValidation.validateNumerical("batchMmul", "inputsA", inputsA);
|
||||
Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length);
|
||||
SDValidation.validateNumerical("batchMmul", "inputsB", inputsB);
|
||||
|
@ -482,7 +476,7 @@ public class SDBaseOps {
|
|||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse,
|
||||
int... axis) {
|
||||
int... axis) {
|
||||
SDValidation.validateNumerical("cumprod", "in", in);
|
||||
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(sd,in, exclusive, reverse, axis).outputVariable();
|
||||
|
@ -563,7 +557,7 @@ public class SDBaseOps {
|
|||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse,
|
||||
int... axis) {
|
||||
int... axis) {
|
||||
SDValidation.validateNumerical("cumsum", "in", in);
|
||||
Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(sd,in, exclusive, reverse, axis).outputVariable();
|
||||
|
@ -680,7 +674,7 @@ public class SDBaseOps {
|
|||
* @param numPartitions Number of partitions, >= 1
|
||||
*/
|
||||
public SDVariable[] dynamicPartition(String[] names, SDVariable x, SDVariable partitions,
|
||||
int numPartitions) {
|
||||
int numPartitions) {
|
||||
SDValidation.validateNumerical("dynamicPartition", "x", x);
|
||||
SDValidation.validateInteger("dynamicPartition", "partitions", partitions);
|
||||
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(sd,x, partitions, numPartitions).outputVariables();
|
||||
|
@ -1189,7 +1183,7 @@ public class SDBaseOps {
|
|||
* @return output INDArray with linearly spaced elements (NUMERIC type)
|
||||
*/
|
||||
public SDVariable linspace(String name, DataType dataType, double start, double stop,
|
||||
long number) {
|
||||
long number) {
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Linspace(sd,dataType, start, stop, number).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
@ -1205,7 +1199,7 @@ public class SDBaseOps {
|
|||
* @return output INDArray with linearly spaced elements (NUMERIC type)
|
||||
*/
|
||||
public SDVariable linspace(SDVariable start, SDVariable stop, SDVariable number,
|
||||
DataType dataType) {
|
||||
DataType dataType) {
|
||||
SDValidation.validateNumerical("linspace", "start", start);
|
||||
SDValidation.validateNumerical("linspace", "stop", stop);
|
||||
SDValidation.validateInteger("linspace", "number", number);
|
||||
|
@ -1224,7 +1218,7 @@ public class SDBaseOps {
|
|||
* @return output INDArray with linearly spaced elements (NUMERIC type)
|
||||
*/
|
||||
public SDVariable linspace(String name, SDVariable start, SDVariable stop, SDVariable number,
|
||||
DataType dataType) {
|
||||
DataType dataType) {
|
||||
SDValidation.validateNumerical("linspace", "start", start);
|
||||
SDValidation.validateNumerical("linspace", "stop", stop);
|
||||
SDValidation.validateInteger("linspace", "number", number);
|
||||
|
@ -1445,7 +1439,7 @@ public class SDBaseOps {
|
|||
* @return output Number of elements that the condition is satisfied for (NUMERIC type)
|
||||
*/
|
||||
public SDVariable matchConditionCount(SDVariable in, Condition condition, boolean keepDim,
|
||||
int... dimensions) {
|
||||
int... dimensions) {
|
||||
SDValidation.validateNumerical("matchConditionCount", "in", in);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
return new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable();
|
||||
|
@ -1469,7 +1463,7 @@ public class SDBaseOps {
|
|||
* @return output Number of elements that the condition is satisfied for (NUMERIC type)
|
||||
*/
|
||||
public SDVariable matchConditionCount(String name, SDVariable in, Condition condition,
|
||||
boolean keepDim, int... dimensions) {
|
||||
boolean keepDim, int... dimensions) {
|
||||
SDValidation.validateNumerical("matchConditionCount", "in", in);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, keepDim, dimensions).outputVariable();
|
||||
|
@ -1514,7 +1508,7 @@ public class SDBaseOps {
|
|||
* @return output Number of elements that the condition is satisfied for (NUMERIC type)
|
||||
*/
|
||||
public SDVariable matchConditionCount(String name, SDVariable in, Condition condition,
|
||||
int... dimensions) {
|
||||
int... dimensions) {
|
||||
SDValidation.validateNumerical("matchConditionCount", "in", in);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(sd,in, condition, false, dimensions).outputVariable();
|
||||
|
@ -1895,7 +1889,7 @@ public class SDBaseOps {
|
|||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable mmul(SDVariable x, SDVariable y, boolean transposeX, boolean transposeY,
|
||||
boolean transposeZ) {
|
||||
boolean transposeZ) {
|
||||
SDValidation.validateNumerical("mmul", "x", x);
|
||||
SDValidation.validateNumerical("mmul", "y", y);
|
||||
return new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable();
|
||||
|
@ -1914,7 +1908,7 @@ public class SDBaseOps {
|
|||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable mmul(String name, SDVariable x, SDVariable y, boolean transposeX,
|
||||
boolean transposeY, boolean transposeZ) {
|
||||
boolean transposeY, boolean transposeZ) {
|
||||
SDValidation.validateNumerical("mmul", "x", x);
|
||||
SDValidation.validateNumerical("mmul", "y", y);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.Mmul(sd,x, y, transposeX, transposeY, transposeZ).outputVariable();
|
||||
|
@ -2304,14 +2298,14 @@ public class SDBaseOps {
|
|||
*
|
||||
* @param indices Indices - value 0 to depth-1 (NUMERIC type)
|
||||
* @param depth Number of classes
|
||||
* @param axis
|
||||
* @param on
|
||||
* @param off
|
||||
* @param axis
|
||||
* @param on
|
||||
* @param off
|
||||
* @param dataType Output data type
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off,
|
||||
DataType dataType) {
|
||||
DataType dataType) {
|
||||
SDValidation.validateNumerical("oneHot", "indices", indices);
|
||||
return new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable();
|
||||
}
|
||||
|
@ -2324,14 +2318,14 @@ public class SDBaseOps {
|
|||
* @param name name May be null. Name for the output variable
|
||||
* @param indices Indices - value 0 to depth-1 (NUMERIC type)
|
||||
* @param depth Number of classes
|
||||
* @param axis
|
||||
* @param on
|
||||
* @param off
|
||||
* @param axis
|
||||
* @param on
|
||||
* @param off
|
||||
* @param dataType Output data type
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on,
|
||||
double off, DataType dataType) {
|
||||
double off, DataType dataType) {
|
||||
SDValidation.validateNumerical("oneHot", "indices", indices);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, dataType).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
|
@ -2344,9 +2338,9 @@ public class SDBaseOps {
|
|||
*
|
||||
* @param indices Indices - value 0 to depth-1 (NUMERIC type)
|
||||
* @param depth Number of classes
|
||||
* @param axis
|
||||
* @param on
|
||||
* @param off
|
||||
* @param axis
|
||||
* @param on
|
||||
* @param off
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) {
|
||||
|
@ -2362,13 +2356,13 @@ public class SDBaseOps {
|
|||
* @param name name May be null. Name for the output variable
|
||||
* @param indices Indices - value 0 to depth-1 (NUMERIC type)
|
||||
* @param depth Number of classes
|
||||
* @param axis
|
||||
* @param on
|
||||
* @param off
|
||||
* @param axis
|
||||
* @param on
|
||||
* @param off
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on,
|
||||
double off) {
|
||||
double off) {
|
||||
SDValidation.validateNumerical("oneHot", "indices", indices);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.OneHot(sd,indices, depth, axis, on, off, DataType.FLOAT).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
|
@ -2436,7 +2430,7 @@ public class SDBaseOps {
|
|||
* As per onesLike(String, SDVariable) but the output datatype may be specified<br>
|
||||
*
|
||||
* @param input (NUMERIC type)
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable onesLike(SDVariable input, DataType dataType) {
|
||||
|
@ -2449,7 +2443,7 @@ public class SDBaseOps {
|
|||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input (NUMERIC type)
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable onesLike(String name, SDVariable input, DataType dataType) {
|
||||
|
@ -2612,7 +2606,7 @@ public class SDBaseOps {
|
|||
* @param from Initial/smallest value
|
||||
* @param to Largest value (exclusive)
|
||||
* @param step Step size
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output INDArray with the specified values (NUMERIC type)
|
||||
*/
|
||||
public SDVariable range(double from, double to, double step, DataType dataType) {
|
||||
|
@ -2628,7 +2622,7 @@ public class SDBaseOps {
|
|||
* @param from Initial/smallest value
|
||||
* @param to Largest value (exclusive)
|
||||
* @param step Step size
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output INDArray with the specified values (NUMERIC type)
|
||||
*/
|
||||
public SDVariable range(String name, double from, double to, double step, DataType dataType) {
|
||||
|
@ -2644,7 +2638,7 @@ public class SDBaseOps {
|
|||
* @param from Initial/smallest value (NUMERIC type)
|
||||
* @param to Largest value (exclusive) (NUMERIC type)
|
||||
* @param step Step size (NUMERIC type)
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output INDArray with the specified values (NUMERIC type)
|
||||
*/
|
||||
public SDVariable range(SDVariable from, SDVariable to, SDVariable step, DataType dataType) {
|
||||
|
@ -2663,11 +2657,11 @@ public class SDBaseOps {
|
|||
* @param from Initial/smallest value (NUMERIC type)
|
||||
* @param to Largest value (exclusive) (NUMERIC type)
|
||||
* @param step Step size (NUMERIC type)
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output INDArray with the specified values (NUMERIC type)
|
||||
*/
|
||||
public SDVariable range(String name, SDVariable from, SDVariable to, SDVariable step,
|
||||
DataType dataType) {
|
||||
DataType dataType) {
|
||||
SDValidation.validateNumerical("range", "from", from);
|
||||
SDValidation.validateNumerical("range", "to", to);
|
||||
SDValidation.validateNumerical("range", "step", step);
|
||||
|
@ -2727,7 +2721,7 @@ public class SDBaseOps {
|
|||
* @return output New array with values replaced where condition is satisfied (NUMERIC type)
|
||||
*/
|
||||
public SDVariable replaceWhere(String name, SDVariable update, SDVariable from,
|
||||
Condition condition) {
|
||||
Condition condition) {
|
||||
SDValidation.validateNumerical("replaceWhere", "update", update);
|
||||
SDValidation.validateNumerical("replaceWhere", "from", from);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(sd,update, from, condition).outputVariable();
|
||||
|
@ -2761,7 +2755,7 @@ public class SDBaseOps {
|
|||
* @return output New array with values replaced where condition is satisfied (NUMERIC type)
|
||||
*/
|
||||
public SDVariable replaceWhere(String name, SDVariable update, double value,
|
||||
Condition condition) {
|
||||
Condition condition) {
|
||||
SDValidation.validateNumerical("replaceWhere", "update", update);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(sd,update, value, condition).outputVariable();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
|
@ -2799,47 +2793,6 @@ public class SDBaseOps {
|
|||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Split the input in to a list of sub tensors
|
||||
* @param input the input to split
|
||||
* @param numSizeSplits the number of splits
|
||||
* @param splitDim the dimension to split along
|
||||
* @return the set of output variables
|
||||
*/
|
||||
public SDVariable[] split(SDVariable input,int numSizeSplits,int splitDim) {
|
||||
SDValidation.validateNumerical("split",input);
|
||||
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.Split(sd,input,numSizeSplits,splitDim).outputVariables();
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Split the input in to a list of sub tensors
|
||||
* @param name the potential name of the input
|
||||
* @param input the input to split
|
||||
* @param numSizeSplits the number of splits
|
||||
* @param splitDim the dimension to split along
|
||||
* @return the set of output variables
|
||||
*/
|
||||
public SDVariable[] split(String name,SDVariable input,int numSizeSplits,int splitDim) {
|
||||
SDValidation.validateNumerical("split",input);
|
||||
SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.Split(sd,input,numSizeSplits,splitDim).outputVariables();
|
||||
SDVariable[] ret = new SDVariable[out.length];
|
||||
AtomicInteger index = new AtomicInteger(0);
|
||||
Arrays.stream(out).forEach(output -> {
|
||||
if(index.get() < 1) {
|
||||
ret[index.get()] = sd.updateVariableNameAndReference(output,name);
|
||||
index.incrementAndGet();
|
||||
}
|
||||
else {
|
||||
ret[index.get()] = sd.updateVariableNameAndReference(output,name + ":" + index.get());
|
||||
index.incrementAndGet();
|
||||
}
|
||||
});
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the<br>
|
||||
* input, but with the specified shape.<br>
|
||||
|
@ -2930,7 +2883,7 @@ public class SDBaseOps {
|
|||
* @return output Reversed sequences (NUMERIC type)
|
||||
*/
|
||||
public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim,
|
||||
int batchDim) {
|
||||
int batchDim) {
|
||||
SDValidation.validateNumerical("reverseSequence", "x", x);
|
||||
SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths);
|
||||
return new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable();
|
||||
|
@ -2947,7 +2900,7 @@ public class SDBaseOps {
|
|||
* @return output Reversed sequences (NUMERIC type)
|
||||
*/
|
||||
public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim,
|
||||
int batchDim) {
|
||||
int batchDim) {
|
||||
SDValidation.validateNumerical("reverseSequence", "x", x);
|
||||
SDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(sd,x, seq_lengths, seqDim, batchDim).outputVariable();
|
||||
|
@ -3123,7 +3076,7 @@ public class SDBaseOps {
|
|||
* @return output The updated variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices,
|
||||
SDVariable updates) {
|
||||
SDVariable updates) {
|
||||
SDValidation.validateNumerical("scatterAdd", "ref", ref);
|
||||
SDValidation.validateNumerical("scatterAdd", "indices", indices);
|
||||
SDValidation.validateNumerical("scatterAdd", "updates", updates);
|
||||
|
@ -3166,7 +3119,7 @@ public class SDBaseOps {
|
|||
* @return output The updated variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices,
|
||||
SDVariable updates) {
|
||||
SDVariable updates) {
|
||||
SDValidation.validateNumerical("scatterDiv", "ref", ref);
|
||||
SDValidation.validateNumerical("scatterDiv", "indices", indices);
|
||||
SDValidation.validateNumerical("scatterDiv", "updates", updates);
|
||||
|
@ -3209,7 +3162,7 @@ public class SDBaseOps {
|
|||
* @return output The updated variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable scatterMax(String name, SDVariable ref, SDVariable indices,
|
||||
SDVariable updates) {
|
||||
SDVariable updates) {
|
||||
SDValidation.validateNumerical("scatterMax", "ref", ref);
|
||||
SDValidation.validateNumerical("scatterMax", "indices", indices);
|
||||
SDValidation.validateNumerical("scatterMax", "updates", updates);
|
||||
|
@ -3252,7 +3205,7 @@ public class SDBaseOps {
|
|||
* @return output The updated variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable scatterMin(String name, SDVariable ref, SDVariable indices,
|
||||
SDVariable updates) {
|
||||
SDVariable updates) {
|
||||
SDValidation.validateNumerical("scatterMin", "ref", ref);
|
||||
SDValidation.validateNumerical("scatterMin", "indices", indices);
|
||||
SDValidation.validateNumerical("scatterMin", "updates", updates);
|
||||
|
@ -3295,7 +3248,7 @@ public class SDBaseOps {
|
|||
* @return output The updated variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices,
|
||||
SDVariable updates) {
|
||||
SDVariable updates) {
|
||||
SDValidation.validateNumerical("scatterMul", "ref", ref);
|
||||
SDValidation.validateNumerical("scatterMul", "indices", indices);
|
||||
SDValidation.validateNumerical("scatterMul", "updates", updates);
|
||||
|
@ -3338,7 +3291,7 @@ public class SDBaseOps {
|
|||
* @return output The updated variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices,
|
||||
SDVariable updates) {
|
||||
SDVariable updates) {
|
||||
SDValidation.validateNumerical("scatterSub", "ref", ref);
|
||||
SDValidation.validateNumerical("scatterSub", "indices", indices);
|
||||
SDValidation.validateNumerical("scatterSub", "updates", updates);
|
||||
|
@ -3381,7 +3334,7 @@ public class SDBaseOps {
|
|||
* @return output The updated variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable scatterUpdate(String name, SDVariable ref, SDVariable indices,
|
||||
SDVariable updates) {
|
||||
SDVariable updates) {
|
||||
SDValidation.validateNumerical("scatterUpdate", "ref", ref);
|
||||
SDValidation.validateNumerical("scatterUpdate", "indices", indices);
|
||||
SDValidation.validateNumerical("scatterUpdate", "updates", updates);
|
||||
|
@ -3595,7 +3548,7 @@ public class SDBaseOps {
|
|||
*
|
||||
* @param lengths Lengths of the sequences (NUMERIC type)
|
||||
* @param maxLen Maximum sequence length
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) {
|
||||
|
@ -3610,7 +3563,7 @@ public class SDBaseOps {
|
|||
* @param name name May be null. Name for the output variable
|
||||
* @param lengths Lengths of the sequences (NUMERIC type)
|
||||
* @param maxLen Maximum sequence length
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen, DataType dataType) {
|
||||
|
@ -3625,7 +3578,7 @@ public class SDBaseOps {
|
|||
*
|
||||
* @param lengths Lengths of the sequences (NUMERIC type)
|
||||
* @param maxLen Maximum sequence length (INT type)
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) {
|
||||
|
@ -3641,11 +3594,11 @@ public class SDBaseOps {
|
|||
* @param name name May be null. Name for the output variable
|
||||
* @param lengths Lengths of the sequences (NUMERIC type)
|
||||
* @param maxLen Maximum sequence length (INT type)
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen,
|
||||
DataType dataType) {
|
||||
DataType dataType) {
|
||||
SDValidation.validateNumerical("sequenceMask", "lengths", lengths);
|
||||
SDValidation.validateInteger("sequenceMask", "maxLen", maxLen);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(sd,lengths, maxLen, dataType).outputVariable();
|
||||
|
@ -3656,7 +3609,7 @@ public class SDBaseOps {
|
|||
* see sequenceMask(String, SDVariable, SDVariable, DataType)<br>
|
||||
*
|
||||
* @param lengths (NUMERIC type)
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sequenceMask(SDVariable lengths, DataType dataType) {
|
||||
|
@ -3669,7 +3622,7 @@ public class SDBaseOps {
|
|||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param lengths (NUMERIC type)
|
||||
* @param dataType
|
||||
* @param dataType
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable sequenceMask(String name, SDVariable lengths, DataType dataType) {
|
||||
|
@ -3857,7 +3810,7 @@ public class SDBaseOps {
|
|||
* keepDims = false: [a,c]<br>
|
||||
*
|
||||
* @param x (NUMERIC type)
|
||||
* @param keepDims
|
||||
* @param keepDims
|
||||
* @param dimensions (Size: AtLeast(min=0))
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
|
@ -3879,7 +3832,7 @@ public class SDBaseOps {
|
|||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x (NUMERIC type)
|
||||
* @param keepDims
|
||||
* @param keepDims
|
||||
* @param dimensions (Size: AtLeast(min=0))
|
||||
* @return output (NUMERIC type)
|
||||
*/
|
||||
|
@ -4015,7 +3968,7 @@ public class SDBaseOps {
|
|||
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
|
||||
*/
|
||||
public SDVariable standardDeviation(SDVariable x, boolean biasCorrected, boolean keepDims,
|
||||
int... dimensions) {
|
||||
int... dimensions) {
|
||||
SDValidation.validateNumerical("standardDeviation", "x", x);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
return new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable();
|
||||
|
@ -4039,7 +3992,7 @@ public class SDBaseOps {
|
|||
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
|
||||
*/
|
||||
public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected,
|
||||
boolean keepDims, int... dimensions) {
|
||||
boolean keepDims, int... dimensions) {
|
||||
SDValidation.validateNumerical("standardDeviation", "x", x);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, keepDims, dimensions).outputVariable();
|
||||
|
@ -4084,7 +4037,7 @@ public class SDBaseOps {
|
|||
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
|
||||
*/
|
||||
public SDVariable standardDeviation(String name, SDVariable x, boolean biasCorrected,
|
||||
int... dimensions) {
|
||||
int... dimensions) {
|
||||
SDValidation.validateNumerical("standardDeviation", "x", x);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(sd,x, biasCorrected, false, dimensions).outputVariable();
|
||||
|
@ -4113,7 +4066,7 @@ public class SDBaseOps {
|
|||
* @return output A subset of the input array (NUMERIC type)
|
||||
*/
|
||||
public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides,
|
||||
int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
|
||||
int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) {
|
||||
SDValidation.validateNumerical("stridedSlice", "in", in);
|
||||
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
|
||||
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
|
||||
|
@ -4144,8 +4097,8 @@ public class SDBaseOps {
|
|||
* @return output A subset of the input array (NUMERIC type)
|
||||
*/
|
||||
public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end,
|
||||
long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask,
|
||||
int shrinkAxisMask) {
|
||||
long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask,
|
||||
int shrinkAxisMask) {
|
||||
SDValidation.validateNumerical("stridedSlice", "in", in);
|
||||
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
|
||||
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
|
||||
|
@ -4196,7 +4149,7 @@ public class SDBaseOps {
|
|||
* @return output A subset of the input array (NUMERIC type)
|
||||
*/
|
||||
public SDVariable stridedSlice(String name, SDVariable in, long[] begin, long[] end,
|
||||
long... strides) {
|
||||
long... strides) {
|
||||
SDValidation.validateNumerical("stridedSlice", "in", in);
|
||||
Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length);
|
||||
Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length);
|
||||
|
@ -4330,7 +4283,7 @@ public class SDBaseOps {
|
|||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable tensorMmul(SDVariable x, SDVariable y, int[] dimensionsX, int[] dimensionsY,
|
||||
boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
SDValidation.validateNumerical("tensorMmul", "x", x);
|
||||
SDValidation.validateNumerical("tensorMmul", "y", y);
|
||||
Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length);
|
||||
|
@ -4352,7 +4305,7 @@ public class SDBaseOps {
|
|||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX,
|
||||
int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) {
|
||||
SDValidation.validateNumerical("tensorMmul", "x", x);
|
||||
SDValidation.validateNumerical("tensorMmul", "y", y);
|
||||
Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length);
|
||||
|
@ -4389,7 +4342,7 @@ public class SDBaseOps {
|
|||
* @return output Output variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[] dimensionsX,
|
||||
int... dimensionsY) {
|
||||
int... dimensionsY) {
|
||||
SDValidation.validateNumerical("tensorMmul", "x", x);
|
||||
SDValidation.validateNumerical("tensorMmul", "y", y);
|
||||
Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length);
|
||||
|
@ -4522,7 +4475,7 @@ public class SDBaseOps {
|
|||
* @return output Unsorted segment output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable unsortedSegmentMax(String name, SDVariable data, SDVariable segmentIds,
|
||||
int numSegments) {
|
||||
int numSegments) {
|
||||
SDValidation.validateNumerical("unsortedSegmentMax", "data", data);
|
||||
SDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(sd,data, segmentIds, numSegments).outputVariable();
|
||||
|
@ -4561,7 +4514,7 @@ public class SDBaseOps {
|
|||
* @return output Unsorted segment output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable unsortedSegmentMean(String name, SDVariable data, SDVariable segmentIds,
|
||||
int numSegments) {
|
||||
int numSegments) {
|
||||
SDValidation.validateNumerical("unsortedSegmentMean", "data", data);
|
||||
SDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(sd,data, segmentIds, numSegments).outputVariable();
|
||||
|
@ -4600,7 +4553,7 @@ public class SDBaseOps {
|
|||
* @return output Unsorted segment output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable unsortedSegmentMin(String name, SDVariable data, SDVariable segmentIds,
|
||||
int numSegments) {
|
||||
int numSegments) {
|
||||
SDValidation.validateNumerical("unsortedSegmentMin", "data", data);
|
||||
SDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(sd,data, segmentIds, numSegments).outputVariable();
|
||||
|
@ -4639,7 +4592,7 @@ public class SDBaseOps {
|
|||
* @return output Unsorted segment output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable unsortedSegmentProd(String name, SDVariable data, SDVariable segmentIds,
|
||||
int numSegments) {
|
||||
int numSegments) {
|
||||
SDValidation.validateNumerical("unsortedSegmentProd", "data", data);
|
||||
SDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(sd,data, segmentIds, numSegments).outputVariable();
|
||||
|
@ -4676,7 +4629,7 @@ public class SDBaseOps {
|
|||
* @return output Unsorted segment output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable unsortedSegmentSqrtN(String name, SDVariable data, SDVariable segmentIds,
|
||||
int numSegments) {
|
||||
int numSegments) {
|
||||
SDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data);
|
||||
SDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(sd,data, segmentIds, numSegments).outputVariable();
|
||||
|
@ -4715,7 +4668,7 @@ public class SDBaseOps {
|
|||
* @return output Unsorted segment output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable unsortedSegmentSum(String name, SDVariable data, SDVariable segmentIds,
|
||||
int numSegments) {
|
||||
int numSegments) {
|
||||
SDValidation.validateNumerical("unsortedSegmentSum", "data", data);
|
||||
SDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(sd,data, segmentIds, numSegments).outputVariable();
|
||||
|
@ -4771,7 +4724,7 @@ public class SDBaseOps {
|
|||
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
|
||||
*/
|
||||
public SDVariable variance(SDVariable x, boolean biasCorrected, boolean keepDims,
|
||||
int... dimensions) {
|
||||
int... dimensions) {
|
||||
SDValidation.validateNumerical("variance", "x", x);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
return new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable();
|
||||
|
@ -4795,7 +4748,7 @@ public class SDBaseOps {
|
|||
* @return output reduced array of rank (input rank - num dimensions) (NUMERIC type)
|
||||
*/
|
||||
public SDVariable variance(String name, SDVariable x, boolean biasCorrected, boolean keepDims,
|
||||
int... dimensions) {
|
||||
int... dimensions) {
|
||||
SDValidation.validateNumerical("variance", "x", x);
|
||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.summarystats.Variance(sd,x, biasCorrected, keepDims, dimensions).outputVariable();
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
|
@ -250,7 +248,7 @@ public class SDBitwise extends SDOps {
|
|||
|
||||
/**
|
||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||
*
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
|
@ -265,7 +263,7 @@ public class SDBitwise extends SDOps {
|
|||
|
||||
/**
|
||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
|
@ -348,7 +346,7 @@ public class SDBitwise extends SDOps {
|
|||
|
||||
/**
|
||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||
*
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
|
@ -363,7 +361,7 @@ public class SDBitwise extends SDOps {
|
|||
|
||||
/**
|
||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -42,8 +42,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* 2D Convolution layer operation - average pooling 2d<br>
|
||||
*
|
||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling2DConfig Configuration Object
|
||||
* @return output Result after applying average pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -56,8 +55,7 @@ public class SDCNN extends SDOps {
|
|||
* 2D Convolution layer operation - average pooling 2d<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling2DConfig Configuration Object
|
||||
* @return output Result after applying average pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -70,9 +68,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* 3D convolution layer operation - average pooling 3d <br>
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling3DConfig Configuration Object
|
||||
* @return output after applying average pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -85,9 +81,7 @@ public class SDCNN extends SDOps {
|
|||
* 3D convolution layer operation - average pooling 3d <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling3DConfig Configuration Object
|
||||
* @return output after applying average pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -302,9 +296,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* Convolution 3D operation with optional bias <br>
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||
* @param Conv3DConfig Configuration Object
|
||||
|
@ -322,9 +314,7 @@ public class SDCNN extends SDOps {
|
|||
* Convolution 3D operation with optional bias <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||
* @param Conv3DConfig Configuration Object
|
||||
|
@ -342,9 +332,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* Convolution 3D operation with optional bias <br>
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||
* @param Conv3DConfig Configuration Object
|
||||
* @return output Conv3d output variable (NUMERIC type)
|
||||
|
@ -359,9 +347,7 @@ public class SDCNN extends SDOps {
|
|||
* Convolution 3D operation with optional bias <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||
* @param Conv3DConfig Configuration Object
|
||||
* @return output Conv3d output variable (NUMERIC type)
|
||||
|
@ -377,8 +363,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* 2D deconvolution operation with optional bias<br>
|
||||
*
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||
* @param DeConv2DConfig Configuration Object
|
||||
|
@ -396,8 +381,7 @@ public class SDCNN extends SDOps {
|
|||
* 2D deconvolution operation with optional bias<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||
* @param DeConv2DConfig Configuration Object
|
||||
|
@ -415,8 +399,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* 2D deconvolution operation with optional bias<br>
|
||||
*
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||
* @param DeConv2DConfig Configuration Object
|
||||
* @return output result of deconv2d op (NUMERIC type)
|
||||
|
@ -432,8 +415,7 @@ public class SDCNN extends SDOps {
|
|||
* 2D deconvolution operation with optional bias<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||
* @param DeConv2DConfig Configuration Object
|
||||
* @return output result of deconv2d op (NUMERIC type)
|
||||
|
@ -519,8 +501,7 @@ public class SDCNN extends SDOps {
|
|||
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
||||
* = [mb, 2, 4, 4]<br>
|
||||
*
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param blockSize Block size, in the height/width dimension
|
||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||
* @return output Output variable (NUMERIC type)
|
||||
|
@ -537,8 +518,7 @@ public class SDCNN extends SDOps {
|
|||
* = [mb, 2, 4, 4]<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param blockSize Block size, in the height/width dimension
|
||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||
* @return output Output variable (NUMERIC type)
|
||||
|
@ -756,8 +736,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
|
||||
*
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling2DConfig Configuration Object
|
||||
*/
|
||||
public SDVariable[] maxPoolWithArgmax(SDVariable input, Pooling2DConfig Pooling2DConfig) {
|
||||
|
@ -769,8 +748,7 @@ public class SDCNN extends SDOps {
|
|||
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
|
||||
*
|
||||
* @param names names May be null. Arrays of names for the output variables.
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling2DConfig Configuration Object
|
||||
*/
|
||||
public SDVariable[] maxPoolWithArgmax(String[] names, SDVariable input,
|
||||
|
@ -783,8 +761,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* 2D Convolution layer operation - max pooling 2d <br>
|
||||
*
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling2DConfig Configuration Object
|
||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -797,8 +774,7 @@ public class SDCNN extends SDOps {
|
|||
* 2D Convolution layer operation - max pooling 2d <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling2DConfig Configuration Object
|
||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -811,9 +787,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* 3D convolution layer operation - max pooling 3d operation.<br>
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling3DConfig Configuration Object
|
||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -826,9 +800,7 @@ public class SDCNN extends SDOps {
|
|||
* 3D convolution layer operation - max pooling 3d operation.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling3DConfig Configuration Object
|
||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -841,8 +813,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* Separable 2D convolution operation with optional bias <br>
|
||||
*
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
|
||||
|
@ -862,8 +833,7 @@ public class SDCNN extends SDOps {
|
|||
* Separable 2D convolution operation with optional bias <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
|
||||
|
@ -883,8 +853,7 @@ public class SDCNN extends SDOps {
|
|||
/**
|
||||
* Separable 2D convolution operation with optional bias <br>
|
||||
*
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||
* @param Conv2DConfig Configuration Object
|
||||
|
@ -902,8 +871,7 @@ public class SDCNN extends SDOps {
|
|||
* Separable 2D convolution operation with optional bias <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||
* @param Conv2DConfig Configuration Object
|
||||
|
@ -964,8 +932,7 @@ public class SDCNN extends SDOps {
|
|||
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
||||
* = [mb, 2, 4, 4] <br>
|
||||
*
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param blockSize Block size, in the height/width dimension
|
||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||
* @return output Output variable (NUMERIC type)
|
||||
|
@ -982,8 +949,7 @@ public class SDCNN extends SDOps {
|
|||
* = [mb, 2, 4, 4] <br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param blockSize Block size, in the height/width dimension
|
||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||
* @return output Output variable (NUMERIC type)
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -36,7 +36,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output loss variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights,
|
||||
|
@ -56,7 +56,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output loss variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions,
|
||||
|
@ -116,7 +116,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param dimension Dimension to perform the cosine distance over
|
||||
* @return output Cosine distance loss (NUMERIC type)
|
||||
*/
|
||||
|
@ -141,7 +141,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param dimension Dimension to perform the cosine distance over
|
||||
* @return output Cosine distance loss (NUMERIC type)
|
||||
*/
|
||||
|
@ -202,6 +202,49 @@ public class SDLoss extends SDOps {
|
|||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* CTC Loss: Connectionist Temporal Classification Loss. See:<br>
|
||||
* https://dl.acm.org/citation.cfm?id=1143891<br>
|
||||
*
|
||||
* @param targetLabels Label array (NUMERIC type)
|
||||
* @param logitInput Inputs (NUMERIC type)
|
||||
* @param targetLabelLengths Length of the target label (NUMERIC type)
|
||||
* @param logitInputLengths Length of the input (NUMERIC type)
|
||||
* @return output Ctc loss (NUMERIC type)
|
||||
*/
|
||||
public SDVariable ctcLoss(SDVariable targetLabels, SDVariable logitInput,
|
||||
SDVariable targetLabelLengths, SDVariable logitInputLengths) {
|
||||
SDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
|
||||
SDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
|
||||
SDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
|
||||
SDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(sd,targetLabels, logitInput, targetLabelLengths, logitInputLengths).outputVariable();
|
||||
out.markAsLoss();
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* CTC Loss: Connectionist Temporal Classification Loss. See:<br>
|
||||
* https://dl.acm.org/citation.cfm?id=1143891<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param targetLabels Label array (NUMERIC type)
|
||||
* @param logitInput Inputs (NUMERIC type)
|
||||
* @param targetLabelLengths Length of the target label (NUMERIC type)
|
||||
* @param logitInputLengths Length of the input (NUMERIC type)
|
||||
* @return output Ctc loss (NUMERIC type)
|
||||
*/
|
||||
public SDVariable ctcLoss(String name, SDVariable targetLabels, SDVariable logitInput,
|
||||
SDVariable targetLabelLengths, SDVariable logitInputLengths) {
|
||||
SDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
|
||||
SDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
|
||||
SDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
|
||||
SDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
|
||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(sd,targetLabels, logitInput, targetLabelLengths, logitInputLengths).outputVariable();
|
||||
out.markAsLoss();
|
||||
return sd.updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Hinge loss: a loss function used for training classifiers.<br>
|
||||
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
|
||||
|
@ -210,7 +253,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights,
|
||||
|
@ -232,7 +275,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions,
|
||||
|
@ -297,7 +340,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param delta Loss function delta value
|
||||
* @return output Huber loss (NUMERIC type)
|
||||
*/
|
||||
|
@ -324,7 +367,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param delta Loss function delta value
|
||||
* @return output Huber loss (NUMERIC type)
|
||||
*/
|
||||
|
@ -423,7 +466,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param epsilon epsilon
|
||||
* @return output Log loss (NUMERIC type)
|
||||
*/
|
||||
|
@ -445,7 +488,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param epsilon epsilon
|
||||
* @return output Log loss (NUMERIC type)
|
||||
*/
|
||||
|
@ -499,7 +542,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
|
||||
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
|
@ -521,7 +564,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
|
||||
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
|
@ -585,7 +628,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output Loss variable, scalar output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions,
|
||||
|
@ -608,7 +651,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output Loss variable, scalar output (NUMERIC type)
|
||||
*/
|
||||
public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions,
|
||||
|
@ -666,13 +709,13 @@ public class SDLoss extends SDOps {
|
|||
|
||||
/**
|
||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
||||
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||
* this is the mean squared error loss function.<br>
|
||||
*
|
||||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights,
|
||||
|
@ -687,14 +730,14 @@ public class SDLoss extends SDOps {
|
|||
|
||||
/**
|
||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
||||
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||
* this is the mean squared error loss function.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions,
|
||||
|
@ -709,7 +752,7 @@ public class SDLoss extends SDOps {
|
|||
|
||||
/**
|
||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
||||
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||
* this is the mean squared error loss function.<br>
|
||||
*
|
||||
* @param label Label array (NUMERIC type)
|
||||
|
@ -728,7 +771,7 @@ public class SDLoss extends SDOps {
|
|||
|
||||
/**
|
||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
||||
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||
* this is the mean squared error loss function.<br>
|
||||
*
|
||||
* @param name name May be null. Name for the output variable
|
||||
|
@ -764,7 +807,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictionLogits Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
|
@ -796,7 +839,7 @@ public class SDLoss extends SDOps {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictionLogits Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
|
@ -872,7 +915,7 @@ public class SDLoss extends SDOps {
|
|||
/**
|
||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* otherwise, the output is a scalar.<br>
|
||||
* <p><br>
|
||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||
|
@ -884,7 +927,7 @@ public class SDLoss extends SDOps {
|
|||
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
|
||||
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
|
@ -901,7 +944,7 @@ public class SDLoss extends SDOps {
|
|||
/**
|
||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* otherwise, the output is a scalar.<br>
|
||||
* <p><br>
|
||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||
|
@ -914,7 +957,7 @@ public class SDLoss extends SDOps {
|
|||
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
|
||||
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
|
@ -932,7 +975,7 @@ public class SDLoss extends SDOps {
|
|||
/**
|
||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* otherwise, the output is a scalar.<br>
|
||||
* <p><br>
|
||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||
|
@ -959,7 +1002,7 @@ public class SDLoss extends SDOps {
|
|||
/**
|
||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* otherwise, the output is a scalar.<br>
|
||||
* <p><br>
|
||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -144,22 +144,22 @@ public class SDRNN extends SDOps {
|
|||
|
||||
/**
|
||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||
* SUPPORTS following data formats:\n<br>
|
||||
* for unidirectional: \n" +<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
||||
* SUPPORTS following data formats:<br>
|
||||
* for unidirectional:<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||
* for bidirectional:\n<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
||||
* SUPPORTS following direction modes:\n<br>
|
||||
* for bidirectional:<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||
* SUPPORTS following direction modes:<br>
|
||||
* FWD: forward<br>
|
||||
* BWD: backward<br>
|
||||
* BIDIR_SUM: bidirectional sum\n<br>
|
||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
||||
* BIDIR_SUM: bidirectional sum<br>
|
||||
* BIDIR_CONCAT: bidirectional concat<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||
* You may use different gate configurations:<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||
*
|
||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||
|
@ -180,22 +180,22 @@ public class SDRNN extends SDOps {
|
|||
|
||||
/**
|
||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||
* SUPPORTS following data formats:\n<br>
|
||||
* for unidirectional: \n" +<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
||||
* SUPPORTS following data formats:<br>
|
||||
* for unidirectional:<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||
* for bidirectional:\n<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
||||
* SUPPORTS following direction modes:\n<br>
|
||||
* for bidirectional:<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||
* SUPPORTS following direction modes:<br>
|
||||
* FWD: forward<br>
|
||||
* BWD: backward<br>
|
||||
* BIDIR_SUM: bidirectional sum\n<br>
|
||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
||||
* BIDIR_SUM: bidirectional sum<br>
|
||||
* BIDIR_CONCAT: bidirectional concat<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||
* You may use different gate configurations:<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||
*
|
||||
* @param names names May be null. Arrays of names for the output variables.
|
||||
|
@ -218,22 +218,22 @@ public class SDRNN extends SDOps {
|
|||
|
||||
/**
|
||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||
* SUPPORTS following data formats:\n<br>
|
||||
* for unidirectional: \n" +<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
||||
* SUPPORTS following data formats:<br>
|
||||
* for unidirectional:<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||
* for bidirectional:\n<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
||||
* SUPPORTS following direction modes:\n<br>
|
||||
* for bidirectional:<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||
* SUPPORTS following direction modes:<br>
|
||||
* FWD: forward<br>
|
||||
* BWD: backward<br>
|
||||
* BIDIR_SUM: bidirectional sum\n<br>
|
||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
||||
* BIDIR_SUM: bidirectional sum<br>
|
||||
* BIDIR_CONCAT: bidirectional concat<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||
* You may use different gate configurations:<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||
*
|
||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||
|
@ -248,22 +248,22 @@ public class SDRNN extends SDOps {
|
|||
|
||||
/**
|
||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||
* SUPPORTS following data formats:\n<br>
|
||||
* for unidirectional: \n" +<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
||||
* SUPPORTS following data formats:<br>
|
||||
* for unidirectional:<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||
* for bidirectional:\n<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
||||
* SUPPORTS following direction modes:\n<br>
|
||||
* for bidirectional:<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||
* SUPPORTS following direction modes:<br>
|
||||
* FWD: forward<br>
|
||||
* BWD: backward<br>
|
||||
* BIDIR_SUM: bidirectional sum\n<br>
|
||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
||||
* BIDIR_SUM: bidirectional sum<br>
|
||||
* BIDIR_CONCAT: bidirectional concat<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||
* You may use different gate configurations:<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||
*
|
||||
* @param names names May be null. Arrays of names for the output variables.
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.autodiff.samediff.ops;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType;
|
||||
|
||||
import java.lang.String;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* Activations */
|
||||
public enum CellAct {
|
||||
TANH,
|
||||
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* Data format: "NCHW" or "NHWC" */
|
||||
public enum DataFormat {
|
||||
NCHW,
|
||||
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* Activations */
|
||||
public enum GateAct {
|
||||
TANH,
|
||||
|
||||
|
|
|
@ -1,32 +1,43 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling.
|
||||
* ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
|
||||
* ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling.
|
||||
* ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0.
|
||||
* ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation.
|
||||
* ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases.
|
||||
* ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. */
|
||||
public enum ImageResizeMethod {
|
||||
ResizeBilinear, // as java require
|
||||
ResizeNearest,
|
||||
ResizeBilinear,
|
||||
|
||||
ResizeBicubic,
|
||||
ResizeArea,
|
||||
|
||||
ResizeNearest,
|
||||
|
||||
ResizeGaussian,
|
||||
ResizeLanczos3,
|
||||
|
||||
ResizeLanczos5,
|
||||
ResizeMitchellcubic;
|
||||
|
||||
ResizeMitchelcubic,
|
||||
|
||||
ResizeArea
|
||||
}
|
||||
|
|
|
@ -1,25 +1,28 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* for unidirectional: TNS: shape [timeLength, numExamples, inOutSize] - sometimes referred to as "time major"<br>
|
||||
* NST: shape [numExamples, inOutSize, timeLength]<br>
|
||||
* NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout<br> for bidirectional:
|
||||
* T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) */
|
||||
public enum LSTMDataFormat {
|
||||
TNS,
|
||||
|
||||
|
|
|
@ -1,25 +1,30 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* direction <br>
|
||||
* FWD: 0 = fwd
|
||||
* BWD: 1 = bwd
|
||||
* BIDIR_SUM: 2 = bidirectional sum
|
||||
* BIDIR_CONCAT: 3 = bidirectional concat
|
||||
* BIDIR_EXTRA_DIM: 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) */
|
||||
public enum LSTMDirectionMode {
|
||||
FWD,
|
||||
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* Activations */
|
||||
public enum OutAct {
|
||||
TANH,
|
||||
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
|
|
|
@ -1,25 +1,28 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.enums;
|
||||
|
||||
/**
|
||||
* The data format of the input. Input shape depends on data format (in config):<br>
|
||||
* TNS -> [timeSteps, batchSize, inSize]<br>
|
||||
* NST -> [batchSize, inSize, timeSteps]<br>
|
||||
* NTS -> [batchSize, timeSteps, inSize]<br> */
|
||||
public enum RnnDataFormat {
|
||||
TNS,
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ public abstract class BaseLoss extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
protected static INDArray getWeights(INDArray weights, INDArray predictions){
|
||||
protected static INDArray getWeights(INDArray weights, INDArray predictions) {
|
||||
return (weights != null) ? weights : Nd4j.scalar(predictions.dataType(), 1.0);
|
||||
}
|
||||
|
||||
|
|
|
@ -20,29 +20,20 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.loss;
|
||||
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.loss.bp.CtcLossBp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class CtcLoss extends BaseLoss {
|
||||
public class CtcLoss extends DynamicCustomOp {
|
||||
|
||||
|
||||
public CtcLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
|
||||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
public CtcLoss(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
|
||||
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
|
||||
}
|
||||
|
||||
public CtcLoss(SameDiff sameDiff, SDVariable label, SDVariable predictions, SDVariable weights,
|
||||
LossReduce lossReduce) {
|
||||
this(sameDiff, lossReduce, predictions, weights, label);
|
||||
}
|
||||
|
||||
public CtcLoss(INDArray labels, INDArray predictions, INDArray weights, LossReduce lossReduce){
|
||||
super(lossReduce, predictions, weights, labels);
|
||||
}
|
||||
|
||||
public CtcLoss(){ }
|
||||
|
||||
|
@ -52,9 +43,9 @@ public class CtcLoss extends BaseLoss {
|
|||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||
public List<SDVariable> doDiff(List<SDVariable> grad) {
|
||||
//No external gradient
|
||||
//Args are: predictions, weights, label
|
||||
return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs();
|
||||
return new CtcLossBp(sameDiff, arg(0), arg(1), arg(2),arg(3)).outputs();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,17 +20,17 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.loss.bp;
|
||||
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class CtcLossBp extends BaseLossBp {
|
||||
public class CtcLossBp extends DynamicCustomOp {
|
||||
|
||||
|
||||
public CtcLossBp(SameDiff sameDiff, LossReduce lossReduce, SDVariable predictions, SDVariable weights, SDVariable labels){
|
||||
super(sameDiff, lossReduce, predictions, weights, labels);
|
||||
public CtcLossBp(SameDiff sameDiff, SDVariable targetLabels,SDVariable logitInputs,SDVariable targetLabelLengths,SDVariable logitInputLengths){
|
||||
super(sameDiff,new SDVariable[]{targetLabels,logitInputs,targetLabelLengths,logitInputLengths});
|
||||
}
|
||||
|
||||
public CtcLossBp(){ }
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
|
@ -134,7 +132,7 @@ public class NDBitwise {
|
|||
|
||||
/**
|
||||
* Bitwise left cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* Unlike #leftShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||
* {@code leftShiftCyclic(01110000, 2) -> 11000001}<br>
|
||||
*
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
|
@ -180,7 +178,7 @@ public class NDBitwise {
|
|||
|
||||
/**
|
||||
* Bitwise right cyclical shift operation. Supports broadcasting.<br>
|
||||
* Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
|
||||
* Unlike rightShift(INDArray, INDArray) the bits will "wrap around":<br>
|
||||
* {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
|
||||
*
|
||||
* @param x Input to be bit shifted (INT type)
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.enums.DataFormat;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -41,8 +41,7 @@ public class NDCNN {
|
|||
/**
|
||||
* 2D Convolution layer operation - average pooling 2d<br>
|
||||
*
|
||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling2DConfig Configuration Object
|
||||
* @return output Result after applying average pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -54,9 +53,7 @@ public class NDCNN {
|
|||
/**
|
||||
* 3D convolution layer operation - average pooling 3d <br>
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling3DConfig Configuration Object
|
||||
* @return output after applying average pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -161,9 +158,7 @@ public class NDCNN {
|
|||
/**
|
||||
* Convolution 3D operation with optional bias <br>
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||
* @param Conv3DConfig Configuration Object
|
||||
|
@ -180,9 +175,7 @@ public class NDCNN {
|
|||
/**
|
||||
* Convolution 3D operation with optional bias <br>
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for conv3d. Rank 5 with shape [kernelDepth, kernelHeight, kernelWidth, inputChannels, outputChannels]. (NUMERIC type)
|
||||
* @param Conv3DConfig Configuration Object
|
||||
* @return output Conv3d output variable (NUMERIC type)
|
||||
|
@ -196,8 +189,7 @@ public class NDCNN {
|
|||
/**
|
||||
* 2D deconvolution operation with optional bias<br>
|
||||
*
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||
* @param bias Optional 1D bias array with shape [outputChannels]. May be null. (NUMERIC type)
|
||||
* @param DeConv2DConfig Configuration Object
|
||||
|
@ -214,8 +206,7 @@ public class NDCNN {
|
|||
/**
|
||||
* 2D deconvolution operation with optional bias<br>
|
||||
*
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to deconvolution 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param weights Weights for the 2d deconvolution operation. 4 dimensions with format [inputChannels, outputChannels, kernelHeight, kernelWidth] (NUMERIC type)
|
||||
* @param DeConv2DConfig Configuration Object
|
||||
* @return output result of deconv2d op (NUMERIC type)
|
||||
|
@ -263,8 +254,7 @@ public class NDCNN {
|
|||
* Example: if input has shape [mb, 8, 2, 2] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
||||
* = [mb, 2, 4, 4]<br>
|
||||
*
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param blockSize Block size, in the height/width dimension
|
||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||
* @return output Output variable (NUMERIC type)
|
||||
|
@ -373,8 +363,7 @@ public class NDCNN {
|
|||
/**
|
||||
* 2D Convolution layer operation - Max pooling on the input and outputs both max values and indices <br>
|
||||
*
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling2DConfig Configuration Object
|
||||
*/
|
||||
public INDArray[] maxPoolWithArgmax(INDArray input, Pooling2DConfig Pooling2DConfig) {
|
||||
|
@ -385,8 +374,7 @@ public class NDCNN {
|
|||
/**
|
||||
* 2D Convolution layer operation - max pooling 2d <br>
|
||||
*
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling2DConfig Configuration Object
|
||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -398,9 +386,7 @@ public class NDCNN {
|
|||
/**
|
||||
* 3D convolution layer operation - max pooling 3d operation.<br>
|
||||
*
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format
|
||||
* (shape [minibatch, channels, depth, height, width]) or NDHWC format
|
||||
* (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param input the input to average pooling 3d operation - 5d activations in NCDHW format (shape [minibatch, channels, depth, height, width]) or NDHWC format (shape [minibatch, depth, height, width, channels]) (NUMERIC type)
|
||||
* @param Pooling3DConfig Configuration Object
|
||||
* @return output Result after applying max pooling on the input (NUMERIC type)
|
||||
*/
|
||||
|
@ -412,8 +398,7 @@ public class NDCNN {
|
|||
/**
|
||||
* Separable 2D convolution operation with optional bias <br>
|
||||
*
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||
* @param bias Optional bias, rank 1 with shape [outputChannels]. May be null. (NUMERIC type)
|
||||
|
@ -432,8 +417,7 @@ public class NDCNN {
|
|||
/**
|
||||
* Separable 2D convolution operation with optional bias <br>
|
||||
*
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param layerInput the input to max pooling 2d operation - 4d CNN (image) activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param depthWeights Separable conv2d depth weights. 4 dimensions with format [kernelHeight, kernelWidth, inputChannels, depthMultiplier] (NUMERIC type)
|
||||
* @param pointWeights Point weights, rank 4 with format [1, 1, inputChannels*depthMultiplier, outputChannels]. May be null (NUMERIC type)
|
||||
* @param Conv2DConfig Configuration Object
|
||||
|
@ -471,8 +455,7 @@ public class NDCNN {
|
|||
* Example: if input has shape [mb, 2, 4, 4] and block size is 2, then output size is [mb, 8/(2*2), 2*2, 2*2]<br>
|
||||
* = [mb, 2, 4, 4] <br>
|
||||
*
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format
|
||||
* (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param x the input to depth to space pooling 2d operation - 4d activations in NCHW format (shape [minibatch, channels, height, width]) or NHWC format (shape [minibatch, height, width, channels]) (NUMERIC type)
|
||||
* @param blockSize Block size, in the height/width dimension
|
||||
* @param dataFormat Data format: "NCHW" or "NHWC"
|
||||
* @return output Output variable (NUMERIC type)
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.enums.ImageResizeMethod;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.NDValidation;
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.NDValidation;
|
||||
|
@ -35,7 +35,7 @@ public class NDLoss {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output loss variable (NUMERIC type)
|
||||
*/
|
||||
public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights,
|
||||
|
@ -71,7 +71,7 @@ public class NDLoss {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param dimension Dimension to perform the cosine distance over
|
||||
* @return output Cosine distance loss (NUMERIC type)
|
||||
*/
|
||||
|
@ -104,6 +104,25 @@ public class NDLoss {
|
|||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* CTC Loss: Connectionist Temporal Classification Loss. See:<br>
|
||||
* https://dl.acm.org/citation.cfm?id=1143891<br>
|
||||
*
|
||||
* @param targetLabels Label array (NUMERIC type)
|
||||
* @param logitInput Inputs (NUMERIC type)
|
||||
* @param targetLabelLengths Length of the target label (NUMERIC type)
|
||||
* @param logitInputLengths Length of the input (NUMERIC type)
|
||||
* @return output Ctc loss (NUMERIC type)
|
||||
*/
|
||||
public INDArray ctcLoss(INDArray targetLabels, INDArray logitInput, INDArray targetLabelLengths,
|
||||
INDArray logitInputLengths) {
|
||||
NDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
|
||||
NDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
|
||||
NDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
|
||||
NDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
|
||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.loss.CtcLoss(targetLabels, logitInput, targetLabelLengths, logitInputLengths))[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Hinge loss: a loss function used for training classifiers.<br>
|
||||
* Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
|
||||
|
@ -112,7 +131,7 @@ public class NDLoss {
|
|||
* @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights,
|
||||
|
@ -152,7 +171,7 @@ public class NDLoss {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param delta Loss function delta value
|
||||
* @return output Huber loss (NUMERIC type)
|
||||
*/
|
||||
|
@ -204,7 +223,7 @@ public class NDLoss {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param epsilon epsilon
|
||||
* @return output Log loss (NUMERIC type)
|
||||
*/
|
||||
|
@ -237,7 +256,7 @@ public class NDLoss {
|
|||
* @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type)
|
||||
* @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param full Boolean flag. true for logPoissonFull, false for logPoisson
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
|
@ -275,7 +294,7 @@ public class NDLoss {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output Loss variable, scalar output (NUMERIC type)
|
||||
*/
|
||||
public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights,
|
||||
|
@ -306,13 +325,13 @@ public class NDLoss {
|
|||
|
||||
/**
|
||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
||||
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||
* this is the mean squared error loss function.<br>
|
||||
*
|
||||
* @param label Label array (NUMERIC type)
|
||||
* @param predictions Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights,
|
||||
|
@ -325,7 +344,7 @@ public class NDLoss {
|
|||
|
||||
/**
|
||||
* Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.<br>
|
||||
* When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))<br>
|
||||
* When averaged (using LossReduce#MEAN_BY_WEIGHT or LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT (the default))<br>
|
||||
* this is the mean squared error loss function.<br>
|
||||
*
|
||||
* @param label Label array (NUMERIC type)
|
||||
|
@ -357,7 +376,7 @@ public class NDLoss {
|
|||
* @param label Label array (NUMERIC type)
|
||||
* @param predictionLogits Predictions array (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
|
@ -398,7 +417,7 @@ public class NDLoss {
|
|||
/**
|
||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* otherwise, the output is a scalar.<br>
|
||||
* <p><br>
|
||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||
|
@ -410,7 +429,7 @@ public class NDLoss {
|
|||
* @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type)
|
||||
* @param logitPredictions Predictions array (pre-softmax) (NUMERIC type)
|
||||
* @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type)
|
||||
* @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}
|
||||
* @param lossReduce Reduction type for the loss. See LossReduce for more details. Default: LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT
|
||||
* @param labelSmoothing Label smoothing value. Default value: 0
|
||||
* @return output Loss variable (NUMERIC type)
|
||||
*/
|
||||
|
@ -425,7 +444,7 @@ public class NDLoss {
|
|||
/**
|
||||
* Applies the softmax activation function to the input, then implement multi-class cross entropy:<br>
|
||||
* {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}<br>
|
||||
* If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* If LossReduce#NONE is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;<br>
|
||||
* otherwise, the output is a scalar.<br>
|
||||
* <p><br>
|
||||
* When label smoothing is > 0, the following label smoothing is used:<br>
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.enums.PartitionMode;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.enums.PadMode;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
|
||||
|
@ -85,22 +85,22 @@ public class NDRNN {
|
|||
|
||||
/**
|
||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||
* SUPPORTS following data formats:\n<br>
|
||||
* for unidirectional: \n" +<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
||||
* SUPPORTS following data formats:<br>
|
||||
* for unidirectional:<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||
* for bidirectional:\n<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
||||
* SUPPORTS following direction modes:\n<br>
|
||||
* for bidirectional:<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||
* SUPPORTS following direction modes:<br>
|
||||
* FWD: forward<br>
|
||||
* BWD: backward<br>
|
||||
* BIDIR_SUM: bidirectional sum\n<br>
|
||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
||||
* BIDIR_SUM: bidirectional sum<br>
|
||||
* BIDIR_CONCAT: bidirectional concat<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||
* You may use different gate configurations:<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||
*
|
||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||
|
@ -121,22 +121,22 @@ public class NDRNN {
|
|||
|
||||
/**
|
||||
* Long Short-Term Memory layer - Hochreiter 1997.<br>
|
||||
* SUPPORTS following data formats:\n<br>
|
||||
* for unidirectional: \n" +<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]\n<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]\n<br>
|
||||
* SUPPORTS following data formats:<br>
|
||||
* for unidirectional:<br>
|
||||
* TNS: shapes [timeLength, numExamples, inOutSize]<br>
|
||||
* NST: shapes [numExamples, inOutSize, timeLength]<br>
|
||||
* NTS: shapes [numExamples, timeLength, inOutSize]<br>
|
||||
* for bidirectional:\n<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)\n<br>
|
||||
* SUPPORTS following direction modes:\n<br>
|
||||
* for bidirectional:<br>
|
||||
* T2NS: shapes [timeLength, 2, numExamples, inOutSize] (for ONNX)<br>
|
||||
* SUPPORTS following direction modes:<br>
|
||||
* FWD: forward<br>
|
||||
* BWD: backward<br>
|
||||
* BIDIR_SUM: bidirectional sum\n<br>
|
||||
* BIDIR_CONCAT: bidirectional concat\n" +<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)"<br>
|
||||
* BIDIR_SUM: bidirectional sum<br>
|
||||
* BIDIR_CONCAT: bidirectional concat<br>
|
||||
* BIDIR_EXTRA_DIM: bidirectional extra output dim (in conjunction with format dataFormat - T2NS)<br>
|
||||
* You may use different gate configurations:<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum\n<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")\n<br>
|
||||
* specify gate/cell/out aplha/beta and numbers of activations for gate/cell/out described in activations enum<br>
|
||||
* ("RELU","SIGMOID","AFFINE","LEAKY_RELU","THRESHHOLD_RELU","SCALED_TAHN","HARD_SIGMOID","ELU","SOFTSIGN","SOFTPLUS")<br>
|
||||
* Also this layer supports MKLDNN (DNNL) and cuDNN acceleration<br>
|
||||
*
|
||||
* @param x Input, with shape dependent on the data format (in config). (NUMERIC type)
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//================== GENERATED CODE - DO NOT MODIFY THIS FILE ==================
|
||||
|
||||
package org.nd4j.linalg.factory.ops;
|
||||
|
||||
import static org.nd4j.linalg.factory.NDValidation.isSameType;
|
||||
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
|
Loading…
Reference in New Issue