LinkPredictContrastiveRotatEDecoder

class graphstorm.model.LinkPredictContrastiveRotatEDecoder(etypes, h_dim, gamma=12.0)

Bases: LinkPredictRotatEDecoder

forward(g, h, e_h=None)

Forward function of the link prediction decoder, using RotatE as the score function.

It computes a score for each edge of every edge type in the input graph.
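For readers unfamiliar with RotatE, the score of a single edge can be sketched as below. This is a standard-library illustration of the published RotatE formulation (the relation rotates the head embedding in the complex plane, and the score is gamma minus the distance to the tail), not GraphStorm's implementation. The function name `rotate_score`, the use of Python complex numbers, and taking the distance as the sum of per-dimension moduli are assumptions made for this sketch; only the default `gamma=12.0` comes from the class signature above.

```python
import cmath
import math


def rotate_score(head, rel_phase, tail, gamma=12.0):
    """Hypothetical helper: gamma - ||head * r - tail|| with r = exp(i * phase).

    head, tail: sequences of complex numbers (one embedding each).
    rel_phase: sequence of floats, one rotation angle per complex dimension.
    """
    distance = 0.0
    for h, theta, t in zip(head, rel_phase, tail):
        rotated = h * cmath.exp(1j * theta)  # rotate the head by angle theta
        distance += abs(rotated - t)         # modulus of the complex difference
    return gamma - distance


# A quarter-turn rotation maps 1 onto i, so this edge scores close to gamma.
score = rotate_score([1 + 0j], [math.pi / 2], [0 + 1j])
```

A higher score means the rotated head lies closer to the tail, which is why a well-matched edge approaches `gamma`.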

Parameters

g: DGLGraph

The input graph.

h: dict of Tensor

The input node embeddings in the format of {ntype: emb}.

e_h: dict of Tensor

The input edge embeddings in the format of {(src_ntype, etype, dst_ntype): emb}. Currently unused; reserved for future support of edge embeddings. Default: None.

Returns

scores: dict of Tensor

The scores for edges of all edge types in the input graph in the format of {(src_ntype, etype, dst_ntype): score}.
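The dictionary formats of `h` and `scores` can be illustrated with a small sketch in plain Python. It does not call GraphStorm or DGL: the helper `score_by_edge_type`, the toy node and edge type names, the use of lists in place of tensors, and the dot-product stand-in for the RotatE score are all hypothetical and serve only to show how scores are keyed by (src_ntype, etype, dst_ntype).

```python
def score_by_edge_type(edges, h, score_fn):
    """Hypothetical helper: score every edge, grouped by canonical edge type.

    edges: {(src_ntype, etype, dst_ntype): list of (src_id, dst_id)}
    h: {ntype: list of embeddings indexed by node id}
    score_fn: callable (src_emb, dst_emb) -> float, a stand-in for RotatE
    """
    scores = {}
    for canonical_etype, pairs in edges.items():
        src_ntype, _, dst_ntype = canonical_etype
        scores[canonical_etype] = [
            score_fn(h[src_ntype][u], h[dst_ntype][v]) for u, v in pairs
        ]
    return scores


# Toy heterogeneous graph: two node types and one edge type (assumed names).
h = {
    "user": [[1.0, 0.0], [0.0, 1.0]],
    "item": [[1.0, 1.0]],
}
edges = {("user", "buys", "item"): [(0, 0), (1, 0)]}


def dot(a, b):
    # Placeholder score function, not RotatE.
    return sum(x * y for x, y in zip(a, b))


scores = score_by_edge_type(edges, h, dot)
```

The result holds one entry per canonical edge type, with one score per edge of that type.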