terminator.models.layers.gvp.GVPConvLayer
- class terminator.models.layers.gvp.GVPConvLayer(node_dims, edge_dims, n_message=3, n_feedforward=2, drop_rate=0.1, autoregressive=False, activations=(torch.nn.functional.relu, torch.sigmoid), vector_gate=False)
Bases: torch.nn.Module
Full graph convolution / message-passing layer built from Geometric Vector Perceptrons (GVPs). It residually updates the node embeddings with the aggregated incoming messages, applies a pointwise feedforward network to the node embeddings, and returns the updated node embeddings.
To compute only the aggregated messages, see GVPConv.
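The update order described above (aggregate messages, residual update, then a residual pointwise feedforward) can be sketched on scalar features in plain Python. This is a simplified, hypothetical illustration, not the layer's implementation: it assumes mean aggregation over incoming edges, assumes row 0 of edge_index holds source indices and row 1 destination indices (the usual PyTorch Geometric convention), replaces the GVP stacks with the hypothetical callables message_fn and feedforward_fn, and omits edge features, vector channels, dropout, and layer normalization. It also shows the effect of node_mask: messages are computed from all nodes, but only masked-in nodes are updated.

```python
from typing import Callable, List, Optional, Sequence


def conv_layer_sketch(
    x: List[float],
    edge_index: Sequence[Sequence[int]],
    message_fn: Callable[[float, float], float],
    feedforward_fn: Callable[[float], float],
    node_mask: Optional[Sequence[bool]] = None,
) -> List[float]:
    """Hypothetical scalar-only analogue of a residual message-passing layer.

    Omits edge features, vector channels, dropout and layer normalization.
    """
    src, dst = edge_index
    n = len(x)
    totals = [0.0] * n  # summed incoming messages per destination node
    counts = [0] * n    # in-degree per destination node
    for i, j in zip(src, dst):
        totals[j] += message_fn(x[i], x[j])
        counts[j] += 1
    # Assumed mean aggregation; nodes with no incoming edges get a zero update.
    agg = [t / max(c, 1) for t, c in zip(totals, counts)]
    out = list(x)
    for k in range(n):
        if node_mask is not None and not node_mask[k]:
            continue  # nodes outside node_mask keep their input embedding
        h = x[k] + agg[k]               # residual update with aggregated messages
        out[k] = h + feedforward_fn(h)  # residual pointwise feedforward
    return out


# Usage: copy the source feature as the message and use a zero feedforward.
x = [1.0, 2.0, 3.0]
edges = ([0, 1, 2], [1, 2, 1])  # edges 0->1, 1->2, 2->1
print(conv_layer_sketch(x, edges, lambda s, d: s, lambda h: 0.0))
# [1.0, 4.0, 5.0]
```

Node 1 receives messages from nodes 0 and 2, so its update is the mean of 1.0 and 3.0; node 0 has no incoming edges and is left unchanged.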
- Parameters:
node_dims – node embedding dimensions (n_scalar, n_vector)
edge_dims – input edge embedding dimensions (n_scalar, n_vector)
n_message – number of GVPs to use in message function
n_feedforward – number of GVPs to use in feedforward function
drop_rate – drop probability in all dropout layers
autoregressive – if True, the layer expects a separate set of input node embeddings (passed to forward as autoregressive_x) for messages where src >= dst.
activations – tuple of functions (scalar_act, vector_act) to use in GVPs
vector_gate – whether to use vector gating. If True, vector_act is used as σ⁺ in vector gating.
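The difference between vector_gate=True and vector_gate=False can be sketched for a single node in plain Python. This follows the GVP formulation as we understand it and is not the library's code: with gating, vector_act (σ⁺) is applied to the scalar channels, a linear map produces one gate per vector channel, and a sigmoid squashes it; without gating, each vector is scaled by vector_act of its own norm. The weight matrix w_gate and the helper names are hypothetical, and the linear bias is omitted.

```python
import math
from typing import Callable, List, Sequence

Vec3 = List[float]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def gated_vectors(
    s: Sequence[float],
    v: Sequence[Vec3],
    w_gate: Sequence[Sequence[float]],
    vector_act: Callable[[float], float] = sigmoid,
) -> List[Vec3]:
    """vector_gate=True (sketch): gates are computed from the scalar channels.

    w_gate is a hypothetical weight matrix with one row per vector channel
    (n_vector x n_scalar); the bias term is omitted.
    """
    act_s = [vector_act(z) for z in s]  # vector_act plays the role of sigma^+
    gates = [sigmoid(sum(w * a for w, a in zip(row, act_s))) for row in w_gate]
    # Each vector is scaled by a scalar gate, so its direction is unchanged.
    return [[g * c for c in vec] for g, vec in zip(gates, v)]


def norm_activated_vectors(
    v: Sequence[Vec3],
    vector_act: Callable[[float], float] = sigmoid,
) -> List[Vec3]:
    """vector_gate=False (sketch): each vector is scaled by vector_act(||vec||)."""
    out = []
    for vec in v:
        norm = math.sqrt(sum(c * c for c in vec))
        out.append([vector_act(norm) * c for c in vec])
    return out


print(gated_vectors([0.0], [[1.0, 2.0, 2.0]], [[0.0]]))  # [[0.5, 1.0, 1.0]]
print(norm_activated_vectors([[0.0, 3.0, 4.0]], lambda n: n / 10))  # [[0.0, 1.5, 2.0]]
```

In both variants the vectors are only rescaled by scalars, which is what keeps the vector channels rotation-equivariant; gating lets the scalar channels control that scale.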
- __init__(node_dims, edge_dims, n_message=3, n_feedforward=2, drop_rate=0.1, autoregressive=False, activations=(torch.nn.functional.relu, torch.sigmoid), vector_gate=False)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Methods
__init__(node_dims, edge_dims[, n_message, ...]) – Initializes internal Module state, shared by both nn.Module and ScriptModule.
forward(x, edge_index, edge_attr[, ...]) – x: tuple (s, V) of torch.Tensor
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
- forward(x, edge_index, edge_attr, autoregressive_x=None, node_mask=None)
- Parameters:
x – tuple (s, V) of torch.Tensor
edge_index – array of shape [2, n_edges]
edge_attr – tuple (s, V) of torch.Tensor
autoregressive_x – tuple (s, V) of torch.Tensor. If not None, it is used as the source node embeddings when forming messages where src >= dst. The current node embeddings x are still the base of the residual update and the input to the pointwise feedforward.
node_mask – boolean array used to index the first dimension of the node embeddings (s, V). If not None, only the selected nodes are updated.
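The src >= dst rule for autoregressive_x can be sketched in plain Python as an edge partition. This only illustrates which source embeddings the description above implies for each edge; the helper names split_edges and message_sources are hypothetical, and string labels stand in for (s, V) tensor tuples. Because the comparison is >=, self-loops (src == dst) fall on the autoregressive_x side.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_edges(
    edge_index: Tuple[Sequence[int], Sequence[int]],
) -> Tuple[Tuple[List[int], List[int]], Tuple[List[int], List[int]]]:
    """Partition edges into (src < dst) and (src >= dst) groups."""
    src, dst = edge_index
    fwd: Tuple[List[int], List[int]] = ([], [])
    bwd: Tuple[List[int], List[int]] = ([], [])
    for i, j in zip(src, dst):
        group = fwd if i < j else bwd  # self-loops (i == j) fall in bwd
        group[0].append(i)
        group[1].append(j)
    return fwd, bwd


def message_sources(
    x: Sequence[T],
    autoregressive_x: Sequence[T],
    edge_index: Tuple[Sequence[int], Sequence[int]],
) -> List[T]:
    """Source embedding used for each edge's message, in edge order."""
    src, dst = edge_index
    return [autoregressive_x[i] if i >= j else x[i] for i, j in zip(src, dst)]


edges = ([0, 2, 1, 1], [1, 0, 1, 2])
print(split_edges(edges))  # (([0, 1], [1, 2]), ([2, 1], [0, 1]))
print(message_sources(["x0", "x1", "x2"], ["a0", "a1", "a2"], edges))
# ['x0', 'a2', 'a1', 'x1']
```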