LinkPredictContrastiveRotatEDecoder
- class graphstorm.model.LinkPredictContrastiveRotatEDecoder(etypes, h_dim, gamma=12.0)
Bases: LinkPredictRotatEDecoder
- forward(g, h, e_h=None)
Link prediction decoder forward function that uses RotatE as the score function.
It computes the edge scores for every edge type in the input graph.
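For reference, the RotatE score mentioned above can be written as follows for one edge (h, r, t). Here h and t are complex node embeddings in C^k, and r is a relation embedding whose components lie on the unit circle, so multiplying by r rotates h. This page does not confirm two details, which are therefore assumptions based on common RotatE implementations: that the constructor's gamma (default 12.0) is the margin, and that the distance is the sum of element-wise moduli. A higher score means the rotated head lands closer to the tail, which marks a more plausible edge.

```latex
s(h, r, t) = \gamma - \sum_{i=1}^{k} \left| h_i \, r_i - t_i \right|,
\qquad r_i = e^{\mathrm{i}\,\theta_{r,i}}, \quad |r_i| = 1
```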
Parameters
- g: DGLGraph
The input graph.
- h: dict of Tensor
The input node embeddings in the format of {ntype: emb}.
- e_h: dict of Tensor
The input edge embeddings in the format of {(src_ntype, etype, dst_ntype): emb}. Not used, but reserved for future support of edge embeddings. Default: None.
Returns
- scores: dict of Tensor
The scores for edges of all edge types in the input graph in the format of {(src_ntype, etype, dst_ntype): score}.
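The sketch below is a minimal pure-Python illustration of the dict-in/dict-out contract described above. It loops over canonical edge types and returns one list of scores per `(src_ntype, etype, dst_ntype)` key. Everything in it is hypothetical and stands in for the real inputs. The function `rotate_scores`, the `edges` dict that replaces the DGLGraph, the `phases` dict of per-edge-type rotation angles, and the lists of Python complex numbers that replace tensors are all illustrative. The real decoder works on batched tensors and may parameterize the relation phases and normalize the distance differently.

```python
import cmath
import math


def rotate_scores(edges, h, phases, gamma=12.0):
    """Hypothetical sketch: RotatE scores for every edge of every edge type.

    edges:  {(src_ntype, etype, dst_ntype): (src_ids, dst_ids)}
    h:      {ntype: [complex embedding (list of complex) per node]}
    phases: {(src_ntype, etype, dst_ntype): [rotation angle per dimension]}
    Returns {(src_ntype, etype, dst_ntype): [score per edge]}.
    """
    scores = {}
    for canonical_etype, (src_ids, dst_ids) in edges.items():
        src_ntype, _, dst_ntype = canonical_etype
        # Unit-modulus relation vector: each component is a pure rotation.
        rotation = [cmath.exp(1j * theta) for theta in phases[canonical_etype]]
        etype_scores = []
        for s, d in zip(src_ids, dst_ids):
            head, tail = h[src_ntype][s], h[dst_ntype][d]
            # Sum of moduli of (rotated head - tail) across dimensions.
            dist = sum(abs(hk * rk - tk) for hk, rk, tk in zip(head, rotation, tail))
            etype_scores.append(gamma - dist)
        scores[canonical_etype] = etype_scores
    return scores


# Toy usage: one user node, two item nodes, a 90-degree rotation for "buys".
h = {"user": [[1 + 0j]], "item": [[1j], [-1 + 0j]]}
phases = {("user", "buys", "item"): [math.pi / 2]}
edges = {("user", "buys", "item"): ([0, 0], [0, 1])}
scores = rotate_scores(edges, h, phases)
```

In the toy example, rotating the user embedding 1 by 90 degrees gives approximately 1j. The first edge, whose tail is 1j, therefore scores about 12.0. The second edge, whose tail is -1, scores about 12 - sqrt(2).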