A general purpose runner for TF-GNN.
class ContextLabelFn: Reads out a tfgnn.Field from the GraphTensor context.
class DatasetProvider: Abstract base class (ABC) that dataset providers inherit from to supply a tf.data.Dataset.
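As a framework-free sketch of the inheritance pattern only: concrete providers subclass an abstract base class and implement its abstract method, and the base class itself cannot be instantiated. The names `ToyDatasetProvider` and `RangeProvider` are hypothetical, plain lists stand in for tf.data.Dataset, and the TF-GNN abstract method is assumed here to be `get_dataset(context)`, taking a tf.distribute.InputContext.

```python
import abc
from typing import List, Optional


class ToyDatasetProvider(abc.ABC):
    """Hypothetical stand-in for an ABC-style dataset provider."""

    @abc.abstractmethod
    def get_dataset(self, context: Optional[object] = None) -> List[int]:
        """Returns the examples for one input pipeline (a list, not tf.data)."""


class RangeProvider(ToyDatasetProvider):
    """Concrete provider: it must implement every abstract method."""

    def __init__(self, n: int):
        self._n = n

    def get_dataset(self, context: Optional[object] = None) -> List[int]:
        return list(range(self._n))


try:
    ToyDatasetProvider()  # Abstract classes cannot be instantiated.
except TypeError as e:
    print("TypeError:", e)

print(RangeProvider(3).get_dataset())  # [0, 1, 2]
```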
class DotProductLinkPrediction: Implements edge score as dot product of features of endpoint nodes.
class FitOrSkipPadding: Calculates fit or skip SizeConstraints for GraphTensor padding.
class GraphBinaryClassification: Graph binary (or multi-label) classification from pooled node states.
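A minimal sketch of classification "from pooled node states", assuming mean pooling (one possible choice) and made-up weights: the node states of each graph component are averaged, a linear map produces one logit per label, and an independent sigmoid turns each logit into a probability. One label gives binary classification; several labels give multi-label classification. `mean_pool` and `predict_labels` are hypothetical helpers, not TF-GNN API.

```python
import math
from typing import List, Sequence


def mean_pool(node_states: Sequence[Sequence[float]],
              sizes: Sequence[int]) -> List[List[float]]:
    """Averages node state vectors per graph component.

    Nodes are stored contiguously, component by component, and sizes[i] is
    the (positive) number of nodes in component i.
    """
    pooled, start = [], 0
    for size in sizes:
        rows = node_states[start:start + size]
        dim = len(rows[0])
        pooled.append([sum(r[d] for r in rows) / size for d in range(dim)])
        start += size
    return pooled


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def predict_labels(pooled: Sequence[Sequence[float]],
                   weights: Sequence[Sequence[float]],
                   bias: Sequence[float]) -> List[List[float]]:
    """Linear map to one logit per label, then an independent sigmoid each."""
    return [[sigmoid(sum(w * x for w, x in zip(w_row, p)) + b)
             for w_row, b in zip(weights, bias)]
            for p in pooled]


# Two components with 2 and 1 nodes; 2-dimensional node states.
states = [[1.0, 0.0], [3.0, 2.0], [0.0, -4.0]]
pooled = mean_pool(states, sizes=[2, 1])
print(pooled)  # [[2.0, 1.0], [0.0, -4.0]]
# One label (binary): weights has one row and bias one entry.
probs = predict_labels(pooled, weights=[[0.5, -0.5]], bias=[0.0])
print(probs)   # one probability per component
```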
class GraphMeanAbsoluteError: Regression from pooled node states with mean absolute error.
class GraphMeanAbsolutePercentageError: Regression from pooled node states with mean absolute percentage error.
class GraphMeanSquaredError: Regression from pooled node states with mean squared error.
class GraphMeanSquaredLogScaledError: Regression from pooled node states with mean squared log scaled error.
class GraphMeanSquaredLogarithmicError: Regression from pooled node states with mean squared logarithmic error.
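For reference, the standard definitions behind the MAE, MAPE, MSE and MSLE variants are sketched below in plain Python. They follow the usual Keras loss formulas but omit Keras' clipping of negative inputs in MSLE, and the TF-GNN-specific "mean squared log scaled error" is not covered because its definition is not given here.

```python
import math
from typing import Sequence


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def mape(y_true: Sequence[float], y_pred: Sequence[float],
         eps: float = 1e-7) -> float:
    """Mean absolute percentage error, in percent; eps guards y_true == 0."""
    return 100.0 * sum(abs(t - p) / max(abs(t), eps)
                       for t, p in zip(y_true, y_pred)) / len(y_true)


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def msle(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared logarithmic error; assumes all values are > -1."""
    return sum((math.log1p(p) - math.log1p(t)) ** 2
               for t, p in zip(y_true, y_pred)) / len(y_true)


y_true, y_pred = [1.0, 2.0, 4.0], [2.0, 2.0, 2.0]
print(mae(y_true, y_pred))   # 1.0
print(mse(y_true, y_pred))   # 1.6666666666666667
print(mape(y_true, y_pred))  # 50.0
print(msle(y_true, y_pred))
```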
class GraphMulticlassClassification: Graph multiclass classification from pooled node states.
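Unlike the binary or multi-label case, multiclass classification picks exactly one class per graph. A common formulation, sketched here in plain Python with made-up logits, is a softmax over per-class logits combined with a cross-entropy loss on an integer class label; whether the TF-GNN task expects integer or one-hot labels is an assumption not settled by this page.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one graph's class logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(label: int,
                                    logits: Sequence[float]) -> float:
    """Negative log-probability of the true class index `label`."""
    return -math.log(softmax(logits)[label])


logits = [2.0, 0.0, 0.0]  # made-up logits for one pooled graph, 3 classes
probs = softmax(logits)
print(probs.index(max(probs)))                    # 0 (predicted class)
print(sparse_categorical_crossentropy(0, logits))
```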
class GraphTensorPadding: Collects GraphTensor padding helpers.
class GraphTensorProcessorFn: A class for GraphTensor processing.
class HadamardProductLinkPrediction: Implements edge score as Hadamard product of features of endpoint nodes.
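The two link-prediction scorers differ in what they compute from the endpoint node features. In this hypothetical plain-Python sketch, the dot product yields one scalar per candidate edge, while the Hadamard (elementwise) product yields a vector of the same width as the features. How TF-GNN turns that vector into a final score is not described on this page, so the sketch stops at the vector. Summing the Hadamard product recovers the dot product.

```python
from typing import List, Sequence


def dot_product_score(src: Sequence[float], tgt: Sequence[float]) -> float:
    """One scalar score per candidate edge: sum_i src[i] * tgt[i]."""
    return sum(s * t for s, t in zip(src, tgt))


def hadamard_product_features(src: Sequence[float],
                              tgt: Sequence[float]) -> List[float]:
    """Elementwise product of the endpoint features (a vector, not a scalar)."""
    return [s * t for s, t in zip(src, tgt)]


src, tgt = [1.0, 2.0, 3.0], [4.0, -1.0, 0.5]  # made-up endpoint features
print(dot_product_score(src, tgt))          # 3.5
print(hadamard_product_features(src, tgt))  # [4.0, -2.0, 1.5]
```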
class IntegratedGradientsExporter: Exports a Keras model with an additional integrated gradients signature.
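Integrated gradients attribute a model output f to each input feature i by integrating the gradient along a straight path from a baseline x' to the input x: IG_i = (x_i - x'_i) times the integral over alpha in [0, 1] of df/dx_i(x' + alpha * (x - x')). The sketch below approximates that integral with a midpoint Riemann sum on a toy function with a hand-written gradient. TF-GNN applies the idea to GraphTensor features with TensorFlow autodiff; its baseline and step count are not assumed here.

```python
from typing import Callable, List, Sequence


def integrated_gradients(grad_fn: Callable[[List[float]], Sequence[float]],
                         x: Sequence[float],
                         baseline: Sequence[float],
                         steps: int = 50) -> List[float]:
    """Midpoint Riemann-sum approximation of integrated gradients.

    grad_fn(z) must return the gradient of the scalar model output f at z.
    """
    total = [0.0] * len(x)
    for k in range(1, steps + 1):
        alpha = (k - 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        z = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(z)):
            total[i] += g
    # Scale the averaged path gradient by the input-minus-baseline difference.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


def f(z: Sequence[float]) -> float:
    """Toy model: f(z) = z0**2 + 3*z1."""
    return z[0] ** 2 + 3.0 * z[1]


def grad_f(z: Sequence[float]) -> List[float]:
    """Hand-written gradient of f."""
    return [2.0 * z[0], 3.0]


x, baseline = [2.0, 1.0], [0.0, 0.0]
attributions = integrated_gradients(grad_f, x, baseline)
print(attributions)  # approximately [4.0, 3.0]
# Completeness: the attributions sum to f(x) - f(baseline) = 7.0.
```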
class KerasModelExporter: Exports a Keras model (with Keras API) via tf.keras.models.save_model.
class KerasTrainer: Trains using the tf.keras.Model.fit training loop.
class KerasTrainerCheckpointOptions: Provides Keras checkpointing-related configuration options.
class KerasTrainerOptions: Provides Keras training-related options.
class ModelExporter: Saves a Keras model.
class NodeBinaryClassification: Node binary (or multi-label) classification via structured readout.
class NodeMeanAbsoluteError: Node regression with mean absolute error via structured readout.
class NodeMeanAbsolutePercentageError: Node regression with mean absolute percentage error via structured readout.
class NodeMeanSquaredError: Node regression with mean squared error via structured readout.
class NodeMeanSquaredLogScaledError: Node regression with mean squared log scaled error via structured readout.
class NodeMeanSquaredLogarithmicError: Node regression with mean squared logarithmic error via structured readout.
class NodeMulticlassClassification: Node multiclass classification via structured readout.
class ParameterServerStrategy: A ParameterServerStrategy convenience wrapper.
class PassthruDatasetProvider: Builds a tf.data.Dataset from a pass-through dataset.
class PassthruSampleDatasetsProvider: Builds a sampled tf.data.Dataset from multiple pass-through datasets.
class RootNodeBinaryClassification: Root node binary (or multi-label) classification.
class RootNodeLabelFn: Reads out a tfgnn.Field from the GraphTensor root (i.e. first) node.
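The root-node label function and tasks rely on the root being the first node of each graph component. Assuming the usual batched layout in which each component's nodes are stored contiguously, the root of component i sits at the sum of the sizes of components 0 to i-1. This plain-Python sketch, with hypothetical helper names and made-up labels, computes those indices and reads one label per component.

```python
from itertools import accumulate
from typing import List, Sequence


def root_node_indices(sizes: Sequence[int]) -> List[int]:
    """Index of the first node of each component (exclusive cumulative sum)."""
    if not sizes:
        return []
    return [0] + list(accumulate(sizes))[:-1]


def read_root_feature(feature: Sequence, sizes: Sequence[int]) -> list:
    """Gathers one value per component: the feature value of its root node."""
    return [feature[i] for i in root_node_indices(sizes)]


# Three components with 3, 1 and 2 nodes, stored contiguously.
labels = [7, 0, 0, 4, 9, 1]
print(root_node_indices([3, 1, 2]))          # [0, 3, 4]
print(read_root_feature(labels, [3, 1, 2]))  # [7, 4, 9]
```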
class RootNodeMeanAbsoluteError: Root node regression with mean absolute error.
class RootNodeMeanAbsolutePercentageError: Root node regression with mean absolute percentage error.
class RootNodeMeanSquaredError: Root node regression with mean squared error.
class RootNodeMeanSquaredLogScaledError: Root node regression with mean squared log scaled error.
class RootNodeMeanSquaredLogarithmicError: Root node regression with mean squared logarithmic error.
class RootNodeMulticlassClassification: Root node multiclass classification.
class RunResult: Holds the return values of run(...).
class SampleTFRecordDatasetsProvider: Builds a sampling tf.data.Dataset from multiple filenames.
class SimpleDatasetProvider: Builds a tf.data.Dataset from a list of files.
class SimpleSampleDatasetsProvider: Builds a sampling tf.data.Dataset from multiple filenames.
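The sampling providers draw elements from several input sources at random. In tf.data, the corresponding primitive is tf.data.Dataset.sample_from_datasets(datasets, weights, seed); whether these providers use it, and with which stopping behavior, is an assumption not confirmed here. The framework-free sketch below (hypothetical name `sample_from_sources`) shows weighted sampling that drops a source once it is exhausted, similar to that primitive's default stop_on_empty_dataset=False.

```python
import random
from typing import Iterable, Iterator, List, Optional, Sequence


def sample_from_sources(sources: Sequence[Iterable],
                        weights: Optional[Sequence[float]] = None,
                        seed: Optional[int] = None) -> Iterator:
    """Interleaves items by repeatedly drawing a source at random, by weight.

    An exhausted source is dropped and sampling continues over the rest,
    until all sources are empty. Weights are assumed to be positive.
    """
    rng = random.Random(seed)
    iterators: List[Iterator] = [iter(s) for s in sources]
    w = list(weights) if weights is not None else [1.0] * len(iterators)
    while iterators:
        i = rng.choices(range(len(iterators)), weights=w, k=1)[0]
        try:
            yield next(iterators[i])
        except StopIteration:
            del iterators[i]
            del w[i]


a = ["a0", "a1", "a2", "a3"]
b = ["b0", "b1", "b2", "b3"]
print(list(sample_from_sources([a, b], weights=[0.75, 0.25], seed=0)))
```

Within each source the original order is preserved; only the interleaving across sources is random, and a fixed seed makes it reproducible.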
class SubmoduleExporter: Exports a Keras submodule.
class TFDataServiceConfig: Provides tf.data service-related configuration options.
class TFRecordDatasetProvider: Builds a tf.data.Dataset from a list of files.
class TPUStrategy: A TPUStrategy convenience wrapper.
class Task: Defines a learning objective for a GNN.
class TightPadding: Calculates tight SizeConstraints for GraphTensor padding.
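Padding a batched GraphTensor to fixed shapes requires size constraints (for example, total nodes per node set and total edges per edge set) that every batch fits into. In a simplified view, "tight" constraints are the per-batch totals maximized over all batches, as in this plain-Python sketch with hypothetical names and data. It ignores any extra room the real TightPadding may reserve for the padding itself. A fit-or-skip strategy, by contrast, can choose smaller constraints and skip batches that do not fit.

```python
from typing import Dict, List, Sequence


def batch_totals(graphs: Sequence[Dict[str, int]],
                 batch_size: int) -> List[Dict[str, int]]:
    """Sums per-graph item counts (e.g. nodes per node set) within each batch."""
    totals = []
    for start in range(0, len(graphs), batch_size):
        batch = graphs[start:start + batch_size]
        keys = sorted(set().union(*batch))
        totals.append({k: sum(g.get(k, 0) for g in batch) for k in keys})
    return totals


def tight_constraints(graphs: Sequence[Dict[str, int]],
                      batch_size: int) -> Dict[str, int]:
    """Smallest per-key size that every batch fits into: max over batches."""
    result: Dict[str, int] = {}
    for totals in batch_totals(graphs, batch_size):
        for k, v in totals.items():
            result[k] = max(result.get(k, 0), v)
    return result


# Hypothetical per-graph sizes: nodes in set "paper", edges in set "cites".
graphs = [{"paper": 3, "cites": 2}, {"paper": 5, "cites": 6},
          {"paper": 1, "cites": 0}, {"paper": 2, "cites": 1}]
print(tight_constraints(graphs, batch_size=2))  # {'cites': 8, 'paper': 8}
```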
class Trainer: A class for training and validation of a Keras model.
export_model(...): Exports a Keras model without traces so that it can be loaded without TF-GNN.
incrementing_model_dir(...): Creates an incrementing model directory under the given dirname.
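One plausible reading of "an incrementing model directory" is a numbered subdirectory: dirname/0 for the first run, then one more than the largest existing number. The sketch below (hypothetical name `next_model_dir`) implements that reading with the local os module; whether TF-GNN creates the directory itself or only returns its path, and whether it uses tf.io.gfile for remote filesystems, is not assumed here. This version only returns the path.

```python
import os
import tempfile


def next_model_dir(dirname: str, start: int = 0) -> str:
    """Returns dirname/<n>, where n is one more than the largest
    integer-named entry already in dirname (or `start` if there is none)."""
    if not os.path.isdir(dirname):
        return os.path.join(dirname, str(start))
    numbers = [int(name) for name in os.listdir(dirname) if name.isdecimal()]
    return os.path.join(dirname, str(max(numbers) + 1 if numbers else start))


with tempfile.TemporaryDirectory() as root:
    print(next_model_dir(root))  # <root>/0
    os.makedirs(os.path.join(root, "0"))
    os.makedirs(os.path.join(root, "1"))
    print(next_model_dir(root))  # <root>/2
```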
integrated_gradients(...): Integrated gradients.
one_node_per_component(...): Returns a Mapping node_set_name: 1 for every node set in gtspec.
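The returned mapping has a simple shape, sketched below with a hypothetical helper that takes node set names directly; with TF-GNN the names would come from the keys of a GraphTensorSpec's node_sets. One plausible use, which is an assumption not stated on this page, is as a per-node-set minimum when computing padding size constraints.

```python
from typing import Dict, Iterable


def one_node_per_component_sketch(node_set_names: Iterable[str]) -> Dict[str, int]:
    """Maps every node set name to 1 (one node per graph component)."""
    return {name: 1 for name in node_set_names}


# Hypothetical node set names.
print(one_node_per_component_sketch(["author", "paper"]))  # {'author': 1, 'paper': 1}
```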
run(...): Runs training (and validation) of a model on task(s) with the given data.