terminator.models.layers.gvp.GVPConv
- class terminator.models.layers.gvp.GVPConv(in_dims, out_dims, edge_dims, n_layers=3, module_list=None, aggr='mean', activations=(torch.nn.functional.relu, torch.sigmoid), vector_gate=False)
Bases: MessagePassing
Graph convolution / message passing with Geometric Vector Perceptrons. Takes in a graph with node and edge embeddings and returns new node embeddings.
This layer does NOT apply residual updates or pointwise feedforward layers; see GVPConvLayer for those.
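To make the message-passing pattern concrete, here is a minimal pure-Python sketch on scalar node features (no GVPs, no vectors). It is not this class's implementation: `propagate`, `split_then_normalize`, and `copy_source` are hypothetical names. Each column (j, i) of `edge_index` sends one message from node j to node i, and the messages arriving at each node are combined according to `aggr`. The second helper illustrates one plausible reason for the `aggr` advice on masked edges, assuming a masked decoder splits incoming edges into two disjoint subsets, runs the convolution on each, and divides by the total in-degree afterwards (the pattern used by the autoregressive path of GVPConvLayer in the gvp-pytorch reference code; that TERMinator does exactly the same is an assumption). Summing per-subset sums and dividing once reproduces the mean over all edges, whereas adding two per-subset means does not.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Sequence

# [2, n_edges]: row 0 holds source nodes j, row 1 holds target nodes i
EdgeIndex = Sequence[Sequence[int]]


def propagate(x: Sequence[float], edge_index: EdgeIndex,
              message: Callable[[float, float], float],
              aggr: str = "mean") -> List[float]:
    """Toy scalar message passing: one message per edge j -> i, aggregated at i."""
    if aggr not in ("mean", "add"):
        raise ValueError(f"unsupported aggr: {aggr!r}")
    sums: Dict[int, float] = defaultdict(float)
    counts: Dict[int, int] = defaultdict(int)
    for j, i in zip(edge_index[0], edge_index[1]):
        sums[i] += message(x[i], x[j])  # analogue of phi(x_i, x_j)
        counts[i] += 1
    out: List[float] = []
    for i in range(len(x)):
        if counts[i] == 0:
            out.append(0.0)  # no incoming edges: this sketch emits 0.0
        elif aggr == "add":
            out.append(sums[i])
        else:
            out.append(sums[i] / counts[i])
    return out


def split_then_normalize(x: Sequence[float], edges_a: EdgeIndex, edges_b: EdgeIndex,
                         message: Callable[[float, float], float]) -> List[float]:
    """Run "add" on two disjoint edge subsets, then divide by the total in-degree."""
    part_a = propagate(x, edges_a, message, aggr="add")
    part_b = propagate(x, edges_b, message, aggr="add")
    degree = [0] * len(x)
    for targets in (edges_a[1], edges_b[1]):
        for i in targets:
            degree[i] += 1
    return [(a + b) / max(d, 1) for a, b, d in zip(part_a, part_b, degree)]


def copy_source(x_i: float, x_j: float) -> float:
    return x_j  # each message is simply the source node's value


x = [1.0, 2.0, 4.0]
full = [[0, 2, 1], [1, 1, 2]]  # edges 0->1, 2->1, 1->2
print(propagate(x, full, copy_source, aggr="mean"))  # [0.0, 2.5, 2.0]
print(split_then_normalize(x, [[0], [1]], [[2, 1], [1, 2]], copy_source))  # [0.0, 2.5, 2.0]
```

With `aggr="mean"` on each subset instead, node 1 would receive 1.0 + 4.0 = 5.0 rather than the true mean 2.5.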
- Parameters:
in_dims – input node embedding dimensions (n_scalar, n_vector)
out_dims – output node embedding dimensions (n_scalar, n_vector)
edge_dims – input edge embedding dimensions (n_scalar, n_vector)
n_layers – number of GVPs in the message function
module_list – preconstructed message function, overrides n_layers
aggr – should be “add” if some incoming edges are masked, as in a masked autoregressive decoder architecture, otherwise “mean”
activations – tuple of functions (scalar_act, vector_act) to use in GVPs
vector_gate – whether to use vector gating (if True, vector_act is used as \(\sigma^+\) in the gate)
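The `vector_gate` and `activations` options can be illustrated with a small pure-Python sketch of vector gating: each vector channel is rescaled by a sigmoid of a linear function of \(\sigma^+\) applied to the scalar features. Because each gate is a rotation-invariant scalar, the gated vectors rotate together with the input, which is what keeps the layer equivariant. Everything here is an assumption for illustration: the gate weights `W` and bias `b`, the helper names, and the use of a sigmoid for the outer gate follow the gvp-pytorch reference implementation rather than this page; the default `sigma_plus=sigmoid` mirrors the documented default vector_act.

```python
import math
from typing import Callable, List, Sequence

Vec3 = List[float]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def vector_gate(s: Sequence[float], V: Sequence[Vec3],
                W: Sequence[Sequence[float]], b: Sequence[float],
                sigma_plus: Callable[[float], float] = sigmoid) -> List[Vec3]:
    """Rescale each vector channel k by sigmoid(W[k] . sigma_plus(s) + b[k]).

    s: scalar features (length n_scalar)
    V: vector features (n_vector entries of 3 coordinates each)
    W, b: hypothetical gate weights, shapes [n_vector][n_scalar] and [n_vector]
    """
    act = [sigma_plus(value) for value in s]  # sigma^+ acts on the scalars only
    gated: List[Vec3] = []
    for k, vec in enumerate(V):
        gate = sigmoid(sum(w * a for w, a in zip(W[k], act)) + b[k])  # one scalar per channel
        gated.append([gate * c for c in vec])  # scaling preserves the direction
    return gated


def rotate_z_90(v: Vec3) -> Vec3:
    """Rotate a 3-vector by 90 degrees about the z axis."""
    return [-v[1], v[0], v[2]]


s = [0.5, -1.0]
V = [[1.0, 2.0, 3.0], [0.0, -1.0, 2.0]]
W = [[1.0, -2.0], [0.5, 0.5]]
b = [0.0, 0.1]

# Gating then rotating equals rotating then gating (equivariance).
lhs = [rotate_z_90(v) for v in vector_gate(s, V, W, b)]
rhs = vector_gate(s, [rotate_z_90(v) for v in V], W, b)
print(lhs == rhs)  # True
```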
- __init__(in_dims, out_dims, edge_dims, n_layers=3, module_list=None, aggr='mean', activations=(torch.nn.functional.relu, torch.sigmoid), vector_gate=False)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Methods
__init__(in_dims, out_dims, edge_dims[, ...]) – Initializes internal Module state, shared by both nn.Module and ScriptModule.
forward(x, edge_index, edge_attr) – Parameters: x – tuple (s, V) of torch.Tensor
message(s_i, v_i, s_j, v_j, edge_attr) – Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index.
Attributes
T_destination – alias of TypeVar('T_destination', bound=Mapping[str, Tensor])
dump_patches – This allows better BC support for load_state_dict().
special_args
- forward(x, edge_index, edge_attr)
- Parameters:
x – tuple (s, V) of torch.Tensor
edge_index – array of shape [2, n_edges]
edge_attr – tuple (s, V) of torch.Tensor
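The page does not spell out the tensor shapes inside each (s, V) tuple. A common GVP convention, assumed here, is s of shape [n, n_scalar] and V of shape [n, n_vector, 3], where n is the number of nodes for x and the number of edges for edge_attr. The hypothetical helper `check_gvpconv_inputs` below checks that contract on plain shape tuples, so it runs without PyTorch; with real tensors you would pass `tuple(t.shape)`.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_gvpconv_inputs(x_shapes: Tuple[Shape, Shape],
                         edge_index_shape: Shape,
                         edge_attr_shapes: Tuple[Shape, Shape],
                         in_dims: Tuple[int, int],
                         edge_dims: Tuple[int, int]) -> None:
    """Raise ValueError if the shapes break the assumed (s, V) convention.

    Assumed convention: s is [n, n_scalar] and V is [n, n_vector, 3], where n
    counts nodes (for x) or edges (for edge_attr).
    """
    s_shape, v_shape = x_shapes
    n_nodes = s_shape[0] if s_shape else -1  # -1 forces a mismatch for rank-0 input
    if s_shape != (n_nodes, in_dims[0]) or v_shape != (n_nodes, in_dims[1], 3):
        raise ValueError(f"x shapes {s_shape}, {v_shape} do not match in_dims={in_dims}")
    if len(edge_index_shape) != 2 or edge_index_shape[0] != 2:
        raise ValueError(f"edge_index must be [2, n_edges], got {edge_index_shape}")
    n_edges = edge_index_shape[1]
    es_shape, ev_shape = edge_attr_shapes
    if es_shape != (n_edges, edge_dims[0]) or ev_shape != (n_edges, edge_dims[1], 3):
        raise ValueError(
            f"edge_attr shapes {es_shape}, {ev_shape} do not match edge_dims={edge_dims}")


# 10 nodes, 40 edges, in_dims=(100, 16), edge_dims=(32, 1): passes silently.
check_gvpconv_inputs(((10, 100), (10, 16, 3)), (2, 40), ((40, 32), (40, 1, 3)),
                     in_dims=(100, 16), edge_dims=(32, 1))
```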
- message(s_i, v_i, s_j, v_j, edge_attr)
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take as input any argument that was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
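The _i/_j naming rule can be sketched without PyTorch. Under PyTorch Geometric's default flow="source_to_target", row 0 of edge_index holds source nodes j and row 1 holds target nodes i, so x_j and x_i are per-edge copies of node rows gathered by those indices. The helpers `gather_i_j`, `message_inputs`, and `message_in_dims` are hypothetical. The concatenation order [s_j, edge scalars, s_i], and the resulting input width of 2·n_scalar + edge n_scalar for the first GVP of the message function, follow the gvp-pytorch reference implementation and are assumptions about this class; vector channels would be concatenated the same way.

```python
from typing import List, Sequence, Tuple

Features = List[List[float]]  # one row of scalar features per node or edge


def gather_i_j(x: Sequence[List[float]],
               edge_index: Sequence[Sequence[int]]) -> Tuple[Features, Features]:
    """Return (x_i, x_j): per-edge copies of the target and source node rows."""
    src, dst = edge_index  # row 0 = sources j, row 1 = targets i
    x_j = [list(x[j]) for j in src]
    x_i = [list(x[i]) for i in dst]
    return x_i, x_j


def message_inputs(s_i: Features, s_j: Features, edge_s: Features) -> Features:
    """Concatenate [s_j, edge_s, s_i] for every edge (assumed order)."""
    return [sj + e + si for si, sj, e in zip(s_i, s_j, edge_s)]


def message_in_dims(in_dims: Tuple[int, int],
                    edge_dims: Tuple[int, int]) -> Tuple[int, int]:
    """(n_scalar, n_vector) seen by the first GVP of the message function."""
    return (2 * in_dims[0] + edge_dims[0], 2 * in_dims[1] + edge_dims[1])


s = [[1.0], [2.0], [3.0]]               # 3 nodes, n_scalar = 1
edge_index = [[0, 2], [1, 1]]           # edges 0->1 and 2->1
edge_s = [[10.0, 11.0], [20.0, 21.0]]   # 2 edges, edge n_scalar = 2
s_i, s_j = gather_i_j(s, edge_index)
print(message_inputs(s_i, s_j, edge_s))  # [[1.0, 10.0, 11.0, 2.0], [3.0, 20.0, 21.0, 2.0]]
print(message_in_dims((1, 0), (2, 0)))   # (4, 0)
```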