Root node binary (or multi-label) classification.
Inherits From: Task
runner.RootNodeBinaryClassification(
node_set_name: str,
units: int = 1,
*,
state_name: str = tfgnn.HIDDEN_STATE,
name: str = 'classification_logits',
label_fn: Optional[LabelFn] = None,
label_feature_name: Optional[str] = None
)
gather_activations(
inputs: GraphTensor
) -> Field
Gather activations from root nodes.
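As a rough illustration of what gathering from root nodes means, the sketch below uses plain Python lists rather than the library's tensors. It assumes the common TF-GNN convention that the root node is the first node of each graph component. The helper name `gather_root_states` and the sample data are hypothetical and are not part of the runner API.

```python
def gather_root_states(node_states, component_sizes):
    """Return the state of the first node of each graph component.

    Assumption: nodes are stored flat, component after component, and the
    root of each component is its first node.
    """
    roots = []
    start = 0
    for size in component_sizes:
        if size < 1:
            raise ValueError("every component needs at least one node")
        roots.append(node_states[start])  # first node of this component
        start += size                     # jump to the next component
    return roots


# Two components with 2 and 3 nodes; each node has a 2-dimensional state.
states = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8], [0.9, 1.0]]
roots = gather_root_states(states, [2, 3])
```

Here `roots` holds one state per component, which is the shape of input the classification head expects.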
losses() -> interfaces.Losses
Returns arbitrary task-specific losses.
metrics() -> interfaces.Metrics
Returns arbitrary task-specific metrics.
predict(
inputs: tfgnn.GraphTensor
) -> interfaces.Predictions
Apply a linear head for classification.
| Args | |
|---|---|
| inputs | A tfgnn.GraphTensor for classification. |
| Returns | |
|---|---|
| The classification logits. |
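A linear head maps each root node state to a logit. The sketch below shows the arithmetic in plain Python for the default `units = 1` case. It is not the library's implementation, and the weights, bias and helper names are made up for illustration. The sigmoid step is shown only as the usual way binary logits are read as probabilities.

```python
import math


def linear_head(root_states, weights, bias):
    """Compute one logit per root node: logit = state . weights + bias."""
    return [
        sum(s * w for s, w in zip(state, weights)) + bias
        for state in root_states
    ]


def sigmoid(x):
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical root node states and head parameters.
root_states = [[1.0, 2.0], [0.0, -1.0]]
weights = [0.5, -0.25]
bias = 0.0

logits = linear_head(root_states, weights, bias)
probabilities = [sigmoid(z) for z in logits]
```

With `units` greater than 1 (the multi-label case), the same computation would be repeated with one weight vector and bias per output unit.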
preprocess(
inputs: GraphTensor
) -> tuple[GraphTensor, Field]
Preprocesses a scalar (after merge_batch_to_components) GraphTensor.
This function uses the Keras functional API to define non-trainable
transformations of the symbolic input GraphTensor, which get executed during
dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two
responsibilities:
- Splitting the training label out of the input for training. It must be returned as a separate tensor or mapping of tensors.
- Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.
| Args | |
|---|---|
| inputs | A symbolic Keras GraphTensor for processing. |
| Returns | |
|---|---|
| A tuple of the processed GraphTensor(s) and a Field (or a mapping of Fields) to be used as labels. |
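The first responsibility of preprocess, splitting the label out of the input, can be pictured with a plain dictionary of features. This is only a conceptual sketch: the real method operates on a symbolic GraphTensor, and the helper `split_label` and the feature names used here are hypothetical.

```python
def split_label(features, label_feature_name):
    """Return (features without the label, label) without mutating the input."""
    if label_feature_name not in features:
        raise KeyError(f"missing label feature: {label_feature_name!r}")
    # Copy every feature except the label so the model never sees it as input.
    remaining = {
        name: value
        for name, value in features.items()
        if name != label_feature_name
    }
    return remaining, features[label_feature_name]


# Hypothetical features for a single root node.
features = {"hidden_state": [[0.1, 0.2]], "label": [1]}
model_inputs, labels = split_label(features, "label")
```

Returning the label separately keeps it out of the features passed to the base GNN, which matches the `(GraphTensor, Field)` shape of the documented return value.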