# Source code for mowl.nn.el.boxsquaredel.module

import mowl.nn.el.boxsquaredel.losses as L
from mowl.nn import ELModule
import torch as th
import torch.nn as nn


class BoxSquaredELModule(ELModule):
    """
    Implementation of Box :math:`^2` EL from [jackermeier2023]_.

    The module owns one ``nn.Embedding`` table per geometric component
    (class centers/offsets, relation head and tail centers/offsets, and
    "bump" vectors) and delegates every loss computation to the functions
    in :mod:`mowl.nn.el.boxsquaredel.losses`.

    :param nb_ont_classes: Number of rows in the class embedding tables.
    :param nb_rels: Number of rows in the relation embedding tables.
    :param nb_inds: Number of rows in the individual embedding tables. When
        ``None`` or not positive, no individual embeddings are created and
        the assertion losses raise :class:`ValueError`.
    :param embed_dim: Dimension of every embedding vector.
    :param gamma: Value forwarded as ``gamma`` to the loss functions
        (presumably a margin; confirm against the losses module).
    :param delta: Value forwarded as ``delta`` to the ``gci2`` and object
        property assertion losses.
    :param reg_factor: Value forwarded as the factor of the bump
        regularization loss.
    """

    # Names of the axiom patterns whose loss is listed as negative-capable.
    neg_capable_gcis = frozenset({"gci2", "object_property_assertion"})

    def __init__(self, nb_ont_classes, nb_rels, nb_inds=None, embed_dim=50,
                 gamma=0, delta=2, reg_factor=0.05):
        super().__init__()

        self.nb_ont_classes = nb_ont_classes
        self.nb_rels = nb_rels
        self.nb_inds = nb_inds
        self.embed_dim = embed_dim

        # Class boxes.
        self.class_center = self.init_embeddings(nb_ont_classes, embed_dim)
        self.class_offset = self.init_embeddings(nb_ont_classes, embed_dim)

        # Relation head and tail boxes.
        self.head_center = self.init_embeddings(nb_rels, embed_dim)
        self.head_offset = self.init_embeddings(nb_rels, embed_dim)
        self.tail_center = self.init_embeddings(nb_rels, embed_dim)
        self.tail_offset = self.init_embeddings(nb_rels, embed_dim)

        self.bump_classes = self.init_embeddings(nb_ont_classes, embed_dim)

        # Individual embeddings are optional; the attributes always exist so
        # that the assertion losses can test them against ``None``.
        if self.nb_inds is not None and self.nb_inds > 0:
            self.bump_individuals = self.init_embeddings(nb_inds, embed_dim)
            self.ind_center = self.init_embeddings(nb_inds, embed_dim)
            self.ind_offset = self.init_embeddings(nb_inds, embed_dim)
        else:
            self.bump_individuals = None
            self.ind_center = None
            self.ind_offset = None

        self.gamma = gamma
        self.delta = delta
        self.reg_factor = reg_factor

    def init_embeddings(self, num_entities, embed_dim, min=-1, max=1):
        """
        Create an embedding table with unit-norm rows.

        The weights are sampled uniformly from ``[min, max]`` and every row
        is then divided by its own norm.

        :param num_entities: Number of rows of the table.
        :param embed_dim: Number of columns of the table.
        :param min: Lower bound of the uniform initialization.
        :param max: Upper bound of the uniform initialization.
        :return: The initialized ``nn.Embedding``.
        """
        # NOTE: ``min``/``max`` shadow the builtins inside this method; the
        # names are kept because callers may pass them as keywords.
        embeddings = nn.Embedding(num_entities, embed_dim)
        nn.init.uniform_(embeddings.weight, a=min, b=max)

        # Normalize in place without recording the operation in autograd.
        with th.no_grad():
            row_norms = th.linalg.norm(embeddings.weight, dim=1, keepdim=True)
            embeddings.weight.div_(row_norms)
        return embeddings

    def gci0_loss(self, data, neg=False):
        """Return ``losses.gci0_loss`` for ``data`` using the class boxes."""
        return L.gci0_loss(data, self.class_center, self.class_offset,
                           self.gamma, neg=neg)

    def gci0_bot_loss(self, data, neg=False):
        """
        Return ``losses.gci0_bot_loss`` for ``data`` using the class offsets.

        ``neg`` is accepted but not forwarded to the underlying loss.
        """
        return L.gci0_bot_loss(data, self.class_offset)

    def gci1_loss(self, data, neg=False):
        """Return ``losses.gci1_loss`` for ``data`` using the class boxes."""
        return L.gci1_loss(data, self.class_center, self.class_offset,
                           self.gamma, neg=neg)

    def gci1_bot_loss(self, data, neg=False):
        """Return ``losses.gci1_bot_loss`` for ``data`` using the class boxes."""
        return L.gci1_bot_loss(data, self.class_center, self.class_offset,
                               self.gamma, neg=neg)

    def gci2_loss(self, data, neg=False):
        """
        Return ``losses.gci2_loss`` for ``data``.

        Uses the class boxes, the relation head/tail boxes, the class bumps,
        ``gamma`` and ``delta``.
        """
        return L.gci2_loss(data, self.class_center, self.class_offset,
                           self.head_center, self.head_offset,
                           self.tail_center, self.tail_offset,
                           self.bump_classes, self.gamma, self.delta,
                           neg=neg)

    def gci3_loss(self, data, neg=False):
        """
        Return ``losses.gci3_loss`` for ``data``.

        Uses the class boxes, the relation head/tail boxes, the class bumps
        and ``gamma``.
        """
        return L.gci3_loss(data, self.class_center, self.class_offset,
                           self.head_center, self.head_offset,
                           self.tail_center, self.tail_offset,
                           self.bump_classes, self.gamma, neg=neg)

    def gci3_bot_loss(self, data, neg=False):
        """
        Return ``losses.gci3_bot_loss`` for ``data`` using the head offsets.

        ``neg`` is accepted but not forwarded to the underlying loss.
        """
        return L.gci3_bot_loss(data, self.head_offset)

    def class_assertion_loss(self, data, neg=False):
        """
        Return ``losses.class_assertion_loss`` for ``data``.

        :raises ValueError: If the module was built without individuals.
        """
        if self.ind_center is None:
            raise ValueError("The number of individuals must be specified to use this loss function.")
        return L.class_assertion_loss(data, self.ind_center, self.ind_offset,
                                      self.class_center, self.class_offset,
                                      self.gamma, neg=neg)

    def object_property_assertion_loss(self, data, neg=False):
        """
        Return ``losses.object_property_assertion_loss`` for ``data``.

        :raises ValueError: If the module was built without individuals.
        """
        if self.ind_center is None:
            raise ValueError("The number of individuals must be specified to use this loss function.")
        return L.object_property_assertion_loss(
            data, self.ind_center, self.ind_offset,
            self.head_center, self.head_offset,
            self.tail_center, self.tail_offset,
            self.bump_individuals, self.gamma, self.delta, neg=neg)

    def regularization_loss(self):
        """
        Return the regularization loss of the bump embeddings.

        The class bumps always contribute; the individual bumps contribute
        only when they exist.
        """
        loss = L.reg_loss(self.bump_classes, self.reg_factor)
        if self.bump_individuals is not None:
            # Out-of-place addition: do not mutate the tensor returned by
            # ``reg_loss``.
            loss = loss + L.reg_loss(self.bump_individuals, self.reg_factor)
        return loss