terminator.models.layers.s2s_modules.NodeMPNNLayer

class terminator.models.layers.s2s_modules.NodeMPNNLayer(num_hidden, num_in, dropout=0.1, num_heads=None, scale=30)

Bases: torch.nn.Module

__init__(num_hidden, num_in, dropout=0.1, num_heads=None, scale=30)

Initializes internal Module state, shared by both nn.Module and ScriptModule.
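The docstring above is inherited from torch.nn.Module and does not say what the constructor arguments control. The sketch below is a hypothetical stand-in (`NodeMPNNLayerSketch`), assuming NodeMPNNLayer follows the MPNN layer of Ingraham et al.'s Struct2Seq, which the `s2s_modules` module name suggests. Under that assumption, `num_in` is the width of the per-edge input, so the message MLP consumes `num_hidden + num_in` features; `scale` divides the summed neighbor messages; and `num_heads` is accepted but unused. `nn.LayerNorm` is a stand-in for whatever normalization the real layer uses, so check the TERMinator source before relying on any of these details.

```python
import torch.nn as nn


class NodeMPNNLayerSketch(nn.Module):
    """Hypothetical stand-in showing what the constructor arguments likely control.

    Assumes a Struct2Seq-style MPNN layer; the real NodeMPNNLayer may differ.
    """

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super().__init__()
        self.num_hidden = num_hidden  # width of node embeddings h_V (input and output)
        self.num_in = num_in          # width of per-edge features h_E fed to the message MLP
        self.scale = scale            # summed neighbor messages are divided by this constant
        # num_heads is accepted but ignored here (assumed to exist only for
        # signature compatibility with attention-based layers)
        self.dropout = nn.Dropout(dropout)
        # Two normalizations: after the message update and after the feed-forward block
        self.norm = nn.ModuleList([nn.LayerNorm(num_hidden) for _ in range(2)])
        # Message MLP: concatenated [h_V_i, h_E_ij] -> message of width num_hidden
        self.W1 = nn.Linear(num_hidden + num_in, num_hidden)
        self.W2 = nn.Linear(num_hidden, num_hidden)
        self.W3 = nn.Linear(num_hidden, num_hidden)
        # Position-wise feed-forward block with a 4x hidden expansion
        self.dense = nn.Sequential(
            nn.Linear(num_hidden, num_hidden * 4),
            nn.ReLU(),
            nn.Linear(num_hidden * 4, num_hidden),
        )


layer = NodeMPNNLayerSketch(num_hidden=128, num_in=256)
print(layer.W1.in_features)  # 384
```

In this reading, `num_hidden` fixes both the input and output node width, which is what lets such layers be stacked with residual connections.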

Methods

__init__(num_hidden, num_in[, dropout, ...])

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(h_V, h_E[, mask_V, mask_attend])

Parallel computation of the full transformer layer

Attributes

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

forward(h_V, h_E, mask_V=None, mask_attend=None)

Parallel computation of the full transformer layer.
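forward takes node embeddings `h_V`, edge features `h_E`, and two optional masks, but the docstring does not describe their shapes or roles. The function below sketches the message-passing part of such a layer under stated assumptions: `h_V` is `[B, N, num_hidden]`, `h_E` holds the features of each node's `K` neighbors as `[B, N, K, num_in]`, `mask_attend` is `[B, N, K]` and zeroes messages from padded neighbors, and `mask_V` is `[B, N]` and zeroes padded nodes. `node_mpnn_update` and its `message_mlp` argument are hypothetical, and the normalization, dropout, and position-wise feed-forward sublayer that a full layer would likely apply are omitted.

```python
import torch
import torch.nn as nn


def node_mpnn_update(h_V, h_E, message_mlp, scale=30.0, mask_V=None, mask_attend=None):
    """Hypothetical sketch of the message-passing half of a node MPNN forward pass."""
    K = h_E.size(-2)
    # Pair every node embedding with each of its K incoming edge features
    h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, K, -1)   # [B, N, K, H]
    h_EV = torch.cat([h_V_expand, h_E], dim=-1)            # [B, N, K, H + C]
    h_message = message_mlp(h_EV)                          # [B, N, K, H]
    if mask_attend is not None:
        # Zero out messages coming from padded or invalid neighbors
        h_message = mask_attend.unsqueeze(-1) * h_message
    # Sum messages over neighbors and divide by a constant scale
    dh = h_message.sum(dim=-2) / scale                     # [B, N, H]
    h_V = h_V + dh  # residual update (norm, dropout, and FFN omitted in this sketch)
    if mask_V is not None:
        # Zero out padded nodes in the output
        h_V = mask_V.unsqueeze(-1) * h_V
    return h_V


torch.manual_seed(0)
B, N, K, H, C = 1, 3, 2, 4, 5
h_V = torch.randn(B, N, H)
h_E = torch.randn(B, N, K, C)
mlp = nn.Sequential(nn.Linear(H + C, H), nn.ReLU(), nn.Linear(H, H))
mask_attend = torch.ones(B, N, K)
mask_attend[0, :, 1] = 0.0                 # treat each node's second neighbor as padding
mask_V = torch.tensor([[1.0, 1.0, 0.0]])   # treat the third node as padding

out = node_mpnn_update(h_V, h_E, mlp, mask_V=mask_V, mask_attend=mask_attend)
print(out.shape)  # torch.Size([1, 3, 4])
```

Because masked messages are multiplied by zero before the sum, changing a padded neighbor's features leaves the output unchanged, which is the point of passing `mask_attend`.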