terminator.models.layers.gvp.GVPConvLayer

class terminator.models.layers.gvp.GVPConvLayer(node_dims, edge_dims, n_message=3, n_feedforward=2, drop_rate=0.1, autoregressive=False, activations=(torch.nn.functional.relu, torch.sigmoid), vector_gate=False)

Bases: Module

Full graph convolution / message-passing layer built from Geometric Vector Perceptrons (GVPs). It residually updates node embeddings with the aggregated incoming messages, applies a pointwise feedforward network to the node embeddings, and returns the updated node embeddings.
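The two residual steps described above can be sketched framework-free, as below. This is a minimal sketch under stated assumptions, not the layer's implementation: each node carries only a plain list of scalar features (no vector channels), incoming messages are mean-aggregated, and `message_fn` / `feedforward_fn` are hypothetical stand-ins for the GVP message and feedforward stacks. Dropout and any normalization the real layer may apply are omitted.

```python
from typing import Callable, List, Sequence, Tuple

Vec = List[float]

def conv_layer_update(
    x: Sequence[Vec],
    edge_index: Tuple[Sequence[int], Sequence[int]],
    message_fn: Callable[[Vec, Vec], Vec],   # hypothetical: (x_src, x_dst) -> message
    feedforward_fn: Callable[[Vec], Vec],    # hypothetical: pointwise per-node network
) -> List[Vec]:
    """Scalar-only sketch: residual message update, then residual feedforward."""
    src, dst = edge_index
    n, d = len(x), len(x[0])
    # Sum incoming messages per destination node and count them.
    sums = [[0.0] * d for _ in range(n)]
    counts = [0] * n
    for j, i in zip(src, dst):
        m = message_fn(x[j], x[i])
        sums[i] = [a + b for a, b in zip(sums[i], m)]
        counts[i] += 1
    # Mean aggregation (assumed); nodes with no incoming edges get a zero update.
    dh = [[v / max(c, 1) for v in row] for row, c in zip(sums, counts)]
    # Residual update with the aggregated messages.
    h = [[a + b for a, b in zip(xi, di)] for xi, di in zip(x, dh)]
    # Residual pointwise feedforward, applied to each node independently.
    return [[a + b for a, b in zip(hi, feedforward_fn(hi))] for hi in h]
```

For example, with `x = [[1.0], [2.0], [3.0]]`, edges `([0, 1], [2, 2])`, a message function that returns the source embedding, and a feedforward that halves its input, node 2 receives the mean message 1.5, becomes 4.5, and ends at 6.75; nodes 0 and 1 receive no messages and are changed only by the feedforward.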

To only compute the aggregated messages, see GVPConv.

Parameters:
  • node_dims – node embedding dimensions (n_scalar, n_vector)

  • edge_dims – input edge embedding dimensions (n_scalar, n_vector)

  • n_message – number of GVPs to use in message function

  • n_feedforward – number of GVPs to use in feedforward function

  • drop_rate – drop probability in all dropout layers

  • autoregressive – if True, this GVPConvLayer will be used with a different set of input node embeddings for messages where src >= dst

  • activations – tuple of functions (scalar_act, vector_act) to use in GVPs

  • vector_gate – whether to use vector gating; if True, vector_act is used as σ⁺ in the vector gate
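The `activations` pair acts on the two halves of an `(s, V)` embedding differently. The sketch below shows the usual non-gated GVP nonlinearity for a single node, assuming `s` is a list of scalar channels and `V` a list of 3-vectors: `scalar_act` is applied elementwise to `s`, and each vector is rescaled by `vector_act` of its own norm, which preserves its direction. The defaults mirror the `(relu, sigmoid)` pair in the signature; the helper name `apply_gvp_activations` is hypothetical, and the `vector_gate=True` path is not covered here.

```python
import math
from typing import Callable, List, Tuple

Scalars = List[float]          # n_scalar channels for one node
Vectors = List[List[float]]    # n_vector channels, each a 3-vector

def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))

def relu(t: float) -> float:
    return max(t, 0.0)

def apply_gvp_activations(
    s: Scalars,
    V: Vectors,
    scalar_act: Callable[[float], float] = relu,
    vector_act: Callable[[float], float] = sigmoid,
) -> Tuple[Scalars, Vectors]:
    """Non-gated case: scalar_act elementwise; each vector scaled by vector_act(norm)."""
    s_out = [scalar_act(t) for t in s]
    V_out = []
    for v in V:
        norm = math.sqrt(sum(c * c for c in v))
        # Scaling by a function of the norm keeps each vector's direction,
        # so the operation commutes with rotations.
        scale = vector_act(norm)
        V_out.append([scale * c for c in v])
    return s_out, V_out
```

Because the vector channels are only ever rescaled by a norm-dependent factor, rotating the input vectors rotates the outputs in the same way, which is the property the vector activation is meant to preserve.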




forward(x, edge_index, edge_attr, autoregressive_x=None, node_mask=None)
Parameters:
  • x – tuple (s, V) of torch.Tensor

  • edge_index – array of shape [2, n_edges]

  • edge_attr – tuple (s, V) of torch.Tensor

  • autoregressive_x – tuple (s, V) of torch.Tensor. If not None, used as the source node embeddings when forming messages along edges where src >= dst. The current node embeddings x are still the basis of the residual update and the pointwise feedforward.

  • node_mask – boolean array indexing the first dimension of the node embeddings (s, V). If not None, only the selected nodes are updated.
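The per-edge choice of source embeddings and the effect of `node_mask` described above can be sketched as follows. This is a simplified illustration, not the layer's code: each node holds a plain list of features instead of an `(s, V)` tuple of tensors, and the helpers `select_message_sources` and `masked_update` are hypothetical names. The edge rule follows the parameter description: edges with src < dst read from `x`, edges with src >= dst read from `autoregressive_x` when it is given.

```python
from typing import List, Optional, Sequence, Tuple

Vec = List[float]

def select_message_sources(
    x: Sequence[Vec],
    edge_index: Tuple[Sequence[int], Sequence[int]],
    autoregressive_x: Optional[Sequence[Vec]] = None,
) -> List[Vec]:
    """Return the source embedding used for each edge (column of edge_index)."""
    src, dst = edge_index
    sources = []
    for j, i in zip(src, dst):
        # Edges with src >= dst read from autoregressive_x when it is provided.
        use_ar = autoregressive_x is not None and j >= i
        table = autoregressive_x if use_ar else x
        sources.append(list(table[j]))
    return sources

def masked_update(
    x: Sequence[Vec],
    x_new: Sequence[Vec],
    node_mask: Optional[Sequence[bool]] = None,
) -> List[Vec]:
    """Take x_new where node_mask is True and keep x elsewhere (all nodes if None)."""
    if node_mask is None:
        return [list(v) for v in x_new]
    return [list(n if m else o) for o, n, m in zip(x, x_new, node_mask)]
```

With `x = [[0.0], [1.0], [2.0]]`, `autoregressive_x = [[10.0], [11.0], [12.0]]` and edges `([0, 2, 1], [1, 1, 1])`, the edge 0→1 uses `x[0]`, while 2→1 and 1→1 use `autoregressive_x[2]` and `autoregressive_x[1]`.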