terminator.models.layers.gvp.GVPConv

class terminator.models.layers.gvp.GVPConv(in_dims, out_dims, edge_dims, n_layers=3, module_list=None, aggr='mean', activations=(torch.nn.functional.relu, torch.sigmoid), vector_gate=False)

Bases: MessagePassing

Graph convolution / message passing with Geometric Vector Perceptrons. Takes in a graph with node and edge embeddings, and returns new node embeddings.

This does NOT do residual updates or pointwise feedforward layers; see GVPConvLayer for those.

Parameters:
  • in_dims – input node embedding dimensions (n_scalar, n_vector)

  • out_dims – output node embedding dimensions (n_scalar, n_vector)

  • edge_dims – input edge embedding dimensions (n_scalar, n_vector)

  • n_layers – number of GVPs in the message function

  • module_list – preconstructed message function, overrides n_layers

  • aggr – aggregation scheme; should be 'add' if some incoming edges are masked, as in a masked autoregressive decoder architecture, otherwise 'mean'

  • activations – tuple of functions (scalar_act, vector_act) to use in GVPs

  • vector_gate – whether to use vector gating (if True, vector_act is used as \(\sigma^+\) in the vector gate)
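
The note on aggr can be illustrated with a small pure-Python sketch. It assumes, purely for illustration, that masking splits the incoming edges of a node into two subsets that are aggregated separately and then combined; the subset names and the scalar messages are hypothetical, and this is not the library's own code.

```python
def aggregate(messages, aggr):
    """Reduce a list of scalar messages arriving at one node."""
    if aggr == "add":
        return sum(messages)
    if aggr == "mean":
        return sum(messages) / len(messages)
    raise ValueError(f"unknown aggr: {aggr!r}")


# Hypothetical messages arriving at a single node, split by a mask
# into two edge subsets (assumption: each subset is aggregated separately).
forward_msgs = [2.0, 4.0]
backward_msgs = [6.0]
n_incoming = len(forward_msgs) + len(backward_msgs)

# With 'add', the partial sums can be combined and normalized once.
add_then_normalize = (
    aggregate(forward_msgs, "add") + aggregate(backward_msgs, "add")
) / n_incoming

# With 'mean', each subset is normalized by its own size, so the
# combined value no longer matches the mean over all incoming edges.
mean_per_subset = aggregate(forward_msgs, "mean") + aggregate(backward_msgs, "mean")

true_mean = aggregate(forward_msgs + backward_msgs, "mean")
```

Under this assumption, summing first and dividing by the total number of incoming edges recovers the overall mean, whereas combining per-subset means does not.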

__init__(in_dims, out_dims, edge_dims, n_layers=3, module_list=None, aggr='mean', activations=(torch.nn.functional.relu, torch.sigmoid), vector_gate=False)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Methods

__init__(in_dims, out_dims, edge_dims[, ...])

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x, edge_index, edge_attr)

x – tuple (s, V) of torch.Tensor

message(s_i, v_i, s_j, v_j, edge_attr)

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index.

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

special_args

forward(x, edge_index, edge_attr)
Parameters:
  • x – tuple (s, V) of torch.Tensor

  • edge_index – array of shape [2, n_edges]

  • edge_attr – tuple (s, V) of torch.Tensor
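
A minimal sketch of inputs in the form listed above may help. The tensor layout is an assumption not stated on this page: scalar features of shape [n, n_scalar] and vector features of shape [n, n_vector, 3]. All sizes are hypothetical.

```python
import torch

n_nodes, n_edges = 4, 6
node_dims = (8, 2)  # hypothetical (n_scalar, n_vector) for nodes
edge_dims = (5, 1)  # hypothetical (n_scalar, n_vector) for edges

# Node embeddings as a tuple (s, V).
# Assumed layout: s is [n_nodes, n_scalar], V is [n_nodes, n_vector, 3].
x = (
    torch.randn(n_nodes, node_dims[0]),
    torch.randn(n_nodes, node_dims[1], 3),
)

# Edge list of shape [2, n_edges], holding integer node indices.
edge_index = torch.randint(0, n_nodes, (2, n_edges))

# Edge embeddings as a tuple (s, V), one row per edge (same assumed layout).
edge_attr = (
    torch.randn(n_edges, edge_dims[0]),
    torch.randn(n_edges, edge_dims[1], 3),
)
```

With a layer built as GVPConv(node_dims, out_dims, edge_dims), these three objects would be passed as forward(x, edge_index, edge_attr); the call is left out here so the sketch runs without the library installed.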

message(s_i, v_i, s_j, v_j, edge_attr)

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
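
The _i / _j naming can be sketched in plain Python. The example assumes the default source-to-target flow of MessagePassing, where row 0 of edge_index holds the source nodes \(j\) and row 1 holds the target nodes \(i\); the helper gather_pairs and the toy features are hypothetical.

```python
# Toy per-node features: one scalar per node, indexed by node id.
x = [10.0, 20.0, 30.0]

# edge_index of shape [2, n_edges]:
# row 0 = source nodes j, row 1 = target nodes i (assumed default flow).
edge_index = [
    [0, 1, 2, 2],
    [1, 2, 0, 1],
]


def gather_pairs(x, edge_index):
    """Return per-edge target features (x_i) and source features (x_j)."""
    src, dst = edge_index
    x_j = [x[j] for j in src]  # features of the node sending each message
    x_i = [x[i] for i in dst]  # features of the node receiving each message
    return x_i, x_j


x_i, x_j = gather_pairs(x, edge_index)
```

Each position in x_i and x_j corresponds to one edge, so message() sees one (target, source) pair per edge.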