# Source code for mowl.nn.alc.falcon.module

import torch as th
from torch.utils import checkpoint
from mowl.nn.alc.module import ALCModule
from mowl.owlapi import OWLAPIAdapter, ClassExpressionType, OWLSubClassOfAxiom, \
    OWLEquivalentClassesAxiom, OWLDisjointClassesAxiom, OWLClassAssertionAxiom, \
    OWLObjectPropertyAssertionAxiom


class FALCONModule(ALCModule):
    """Implementation of the FALCON model [falcon2022]_, a fuzzy :math:`\\mathcal{ALC}` neural
    reasoner.

    Each class expression is mapped to a *fuzzy set* over a collection of (named and
    anonymous) entity embeddings, and axioms are scored through fuzzy logical operators.
    Based on the original implementation at
    https://github.com/bio-ontology-research-group/FALCON

    :param nclasses: Number of rows of the class embedding table.
    :param nentities: Number of rows of the entity embedding table. Also the size of the
        pool from which negative entities are drawn.
    :param nrelations: Number of rows of the relation embedding table.
    :param heads_dict: Mapping ``(tail, relation) -> entity indices`` that must not be
        sampled as negative heads (see :meth:`sample_negatives`).
    :param tails_dict: Mapping ``(head, relation) -> entity indices`` that must not be
        sampled as negative tails (see :meth:`sample_negatives`).
    :param embed_dim: Dimension of every embedding. Defaults to ``128``.
    :param anon_e: Stored as ``self.anon_e``; not read by any method of this class.
    :param t_norm: Fuzzy conjunction/disjunction family: ``'product'``, ``'minmax'`` or
        ``'Łukasiewicz'``.
    :param max_measure: Aggregation used by :meth:`get_cc_loss`: ``'max'`` or
        ``'pmean<p>'`` with ``<p>`` a positive integer (e.g. ``'pmean2'``).
    :param residuum: Fuzzy implication used for universal restrictions. Only
        ``'notCorD'`` is implemented.
    :param loss_type: Loss for object property assertions: ``'c'`` (log-likelihood of
        positives vs. negatives) or ``'r'`` (pairwise ranking).
    :param num_negs: Number of negatives sampled per positive (and per corrupted side
        for object property assertions).
    :param device: Device on which the constant tensors ``nothing`` and ``zero_emb``
        are allocated.
    """

    # Added inside every logarithm so that log(0) is never evaluated.
    _EPS = 1e-10

    def __init__(
            self, nclasses, nentities, nrelations, heads_dict, tails_dict,
            embed_dim=128, anon_e=4, t_norm='product', max_measure='max',
            residuum='notCorD', loss_type='c', num_negs=4, device='cpu'):
        super().__init__()
        self.nentities = nentities
        self.anon_e = anon_e
        self.heads_dict = heads_dict
        self.tails_dict = tails_dict

        self.c_embedding = th.nn.Embedding(nclasses, embed_dim)
        self.r_embedding = th.nn.Embedding(nrelations, embed_dim)
        self.e_embedding = th.nn.Embedding(nentities, embed_dim)
        # Two-layer membership network, matching the original FALCON implementation.
        self.fc_0 = th.nn.Linear(embed_dim * 2, embed_dim)
        self.fc_1 = th.nn.Linear(embed_dim, 1)

        # The initialisation order is kept fixed so that seeded runs stay reproducible.
        for layer in (self.c_embedding, self.r_embedding, self.e_embedding,
                      self.fc_0, self.fc_1):
            th.nn.init.xavier_uniform_(layer.weight.data)

        self.max_measure = max_measure
        self.t_norm = t_norm
        self.loss_type = loss_type
        self.num_negs = num_negs
        self.nothing = th.zeros(self.nentities, device=device)
        self.zero_emb = th.zeros(embed_dim, device=device)
        self.residuum = residuum
        self.device = device
        self.adapter = OWLAPIAdapter()

    # ------------------------------------------------------------------
    # Membership network and fuzzy operators
    # ------------------------------------------------------------------

    def _mem(self, c_emb, e_emb):
        """Membership-degree network: maps a (concept, entity) embedding pair to a fuzzy
        membership degree in ``[0, 1]``.

        Both arguments must have identical shapes; the result keeps the leading
        dimensions and has a trailing dimension of size 1.
        """
        pair = th.cat([c_emb, e_emb], dim=-1)
        hidden = th.nn.functional.leaky_relu(self.fc_0(pair), negative_slope=0.1)
        return th.sigmoid(self.fc_1(hidden))

    def _logical_and(self, x, y):
        """Fuzzy conjunction (t-norm) of two broadcast-compatible membership tensors.

        :raises ValueError: If ``self.t_norm`` is not a supported name.
        """
        if self.t_norm == 'product':
            return x * y
        if self.t_norm == 'minmax':
            return th.minimum(x, y)
        if self.t_norm == 'Łukasiewicz':
            # max(0, x + y - 1)
            return th.relu(x + y - 1)
        raise ValueError(f"Unsupported t_norm: {self.t_norm!r}")

    def _logical_or(self, x, y):
        """Fuzzy disjunction (t-conorm) of two broadcast-compatible membership tensors.

        :raises ValueError: If ``self.t_norm`` is not a supported name.
        """
        if self.t_norm == 'product':
            return x + y - x * y
        if self.t_norm == 'minmax':
            return th.maximum(x, y)
        if self.t_norm == 'Łukasiewicz':
            # Dual of the Łukasiewicz t-norm: 1 - max(0, (1 - x) + (1 - y) - 1)
            return 1 - th.relu((1 - x) + (1 - y) - 1)
        raise ValueError(f"Unsupported t_norm: {self.t_norm!r}")

    def _logical_not(self, x):
        """Fuzzy negation ``1 - x``."""
        return 1 - x

    def _logical_residuum(self, r_fs, c_fs):
        """Fuzzy implication ``r_fs -> c_fs``.

        :raises ValueError: If ``self.residuum`` is not ``'notCorD'``.
        """
        if self.residuum == 'notCorD':
            return self._logical_or(self._logical_not(r_fs), c_fs)
        raise ValueError(f"Unsupported residuum: {self.residuum!r}")

    def _logical_exist(self, r_fs, c_fs):
        """Existential restriction: the maximum of ``r_fs AND c_fs`` over the last
        dimension, broadcast back to the shape of ``r_fs``."""
        best = self._logical_and(r_fs, c_fs).max(dim=-1, keepdim=True)[0]
        return best.expand_as(r_fs)

    def _logical_forall(self, r_fs, c_fs):
        """Universal restriction: the minimum of ``r_fs -> c_fs`` over the last
        dimension, broadcast back to the shape of ``r_fs``."""
        worst = self._logical_residuum(r_fs, c_fs).min(dim=-1, keepdim=True)[0]
        return worst.expand_as(r_fs)

    def _get_c_fs_batch(self, c_emb, e_emb):
        """Return the membership of every row of ``e_emb`` in every row of ``c_emb``.

        :param c_emb: 2-D tensor, one concept embedding per row.
        :param e_emb: 2-D tensor, one entity embedding per row.
        :returns: Tensor of shape ``(c_emb.size(0), e_emb.size(0))``.
        """
        # ``expand`` creates a view; the values are the same as with ``repeat`` but the
        # entity matrix is not copied once per batch row.
        entities = e_emb.unsqueeze(dim=0).expand(c_emb.size(0), -1, -1)
        concepts = c_emb.unsqueeze(dim=1).expand_as(entities)
        return self._mem(concepts, entities).squeeze(dim=-1)

    def _get_r_fs_batch(self, r_emb, e_emb):
        """Return, for every relation row, the membership of each entity in the concept
        embedding obtained by translating that same entity by the relation.

        :param r_emb: 2-D tensor, one relation embedding per row.
        :param e_emb: 2-D tensor, one entity embedding per row.
        :returns: Tensor of shape ``(r_emb.size(0), e_emb.size(0))``.
        """
        entities = e_emb.unsqueeze(dim=0).expand(r_emb.size(0), -1, -1)
        translated = entities + r_emb.unsqueeze(dim=1)
        return self._mem(translated, entities).squeeze(dim=-1)

    # ------------------------------------------------------------------
    # Negative sampling
    # ------------------------------------------------------------------

    def sample_negatives(self, e, r, used_dict):
        """Sample ``num_negs`` negative entities for each ``(e[i], r[i])`` pair.

        Entities listed in ``used_dict[(e[i], r[i])]`` are excluded. Sampling is uniform
        with replacement over the remaining indices in ``range(self.nentities)``.

        :param e: Tensor whose ``i``-th item holds a single entity index.
        :param r: Tensor whose ``i``-th item holds a single relation index.
        :param used_dict: Mapping ``(entity, relation) -> entity indices`` to exclude.
        :returns: ``int64`` tensor of shape ``(e.shape[0], self.num_negs)``, allocated
            on the default (CPU) device.
        :raises KeyError: If a pair is missing from ``used_dict``.
        :raises ValueError: If every entity is excluded for some pair.
        """
        batch_size = e.shape[0]
        negatives = th.zeros((batch_size, self.num_negs), dtype=th.int64)
        for i in range(batch_size):
            key = (e[i].item(), r[i].item())
            # An explicit integer dtype keeps an empty exclusion list usable as an index.
            used = th.tensor(used_dict[key], dtype=th.long)
            allowed = th.ones(self.nentities, dtype=th.bool)
            allowed[used] = False
            candidates = allowed.nonzero().flatten()
            if len(candidates) == 0:
                raise ValueError(f"No negative candidates left for {key}")
            picks = th.randint(len(candidates), (self.num_negs,))
            negatives[i, :] = candidates[picks]
        return negatives

    # ------------------------------------------------------------------
    # Class expressions -> fuzzy sets
    # ------------------------------------------------------------------

    def forward_fs(self, cexpr, x, e_emb, cur_index=0):
        """Recursively evaluate a class expression as a fuzzy set over ``e_emb``.

        The expression tree is walked depth-first. Every named class and every
        quantifier's property consumes one column of ``x``, starting at ``cur_index``.

        :param cexpr: OWLAPI class expression to evaluate.
        :param x: 2-D integer tensor. Column ``k`` holds, for each batch row, the index
            of the ``k``-th class/relation encountered in the walk.
        :param e_emb: 2-D tensor of entity embeddings defining the fuzzy-set domain.
        :param cur_index: First column of ``x`` not yet consumed.
        :returns: Tuple ``(fuzzy_set, next_index)`` where ``fuzzy_set`` has shape
            ``(x.shape[0], e_emb.shape[0])`` and ``next_index`` is the next unused
            column of ``x``.
        :raises NotImplementedError: For unsupported class expression types.
        """
        expr_type = cexpr.getClassExpressionType()

        if expr_type == ClassExpressionType.OWL_CLASS:
            c_emb = self.c_embedding(x[:, cur_index])
            return self._get_c_fs_batch(c_emb, e_emb), cur_index + 1

        is_exists = expr_type == ClassExpressionType.OBJECT_SOME_VALUES_FROM
        is_forall = expr_type == ClassExpressionType.OBJECT_ALL_VALUES_FROM
        if is_exists or is_forall:
            quantify = self._logical_exist if is_exists else self._logical_forall
            r_emb = self.r_embedding(x[:, cur_index])
            # Checkpointing recomputes the relation fuzzy set during the backward pass
            # instead of storing its intermediate activations.
            r_fs = checkpoint.checkpoint(
                self._get_r_fs_batch, r_emb, e_emb, use_reentrant=False)
            c_fs, next_index = self.forward_fs(
                cexpr.getFiller(), x, e_emb, cur_index=cur_index + 1)
            return quantify(r_fs, c_fs), next_index

        is_and = expr_type == ClassExpressionType.OBJECT_INTERSECTION_OF
        is_or = expr_type == ClassExpressionType.OBJECT_UNION_OF
        if is_and or is_or:
            combine = self._logical_and if is_and else self._logical_or
            operands = cexpr.getOperandsAsList()
            ret, next_index = self.forward_fs(operands[0], x, e_emb, cur_index=cur_index)
            for i in range(1, len(operands)):
                operand_fs, next_index = self.forward_fs(
                    operands[i], x, e_emb, cur_index=next_index)
                ret = combine(ret, operand_fs)
            return ret, next_index

        if expr_type == ClassExpressionType.OBJECT_COMPLEMENT_OF:
            ret, next_index = self.forward_fs(
                cexpr.getOperand(), x, e_emb, cur_index=cur_index)
            return self._logical_not(ret), next_index

        raise NotImplementedError()

    # ------------------------------------------------------------------
    # Losses
    # ------------------------------------------------------------------

    def get_cc_loss(self, fs):
        """Loss pushing a fuzzy set towards emptiness.

        The fuzzy set is aggregated over its last dimension into a single degree ``m``
        (its maximum, or its power mean of order ``p``) and ``-log(1 - m + 1e-10)`` is
        returned.

        :param fs: Tensor of membership degrees; the last dimension ranges over
            entities.
        :returns: Tensor with the last dimension of ``fs`` reduced away.
        :raises ValueError: If ``self.max_measure`` is neither ``'max'`` nor
            ``'pmean<p>'`` with an integer ``<p>``.
        """
        if self.max_measure == 'max':
            degree = fs.max(dim=-1)[0]
        elif self.max_measure.startswith('pmean'):
            # Parse the whole suffix so multi-digit orders such as 'pmean10' work.
            p = int(self.max_measure[len('pmean'):])
            degree = ((fs ** p).mean(dim=-1)) ** (1 / p)
        else:
            raise ValueError(f"Unsupported max_measure: {self.max_measure!r}")
        return - th.log(1 - degree + self._EPS)

    def _pos_neg_loss(self, dofm):
        """Average of the negative log-likelihoods of the positives (index 0 along
        dimension 1) and of the negatives (indices ``1:`` along dimension 1)."""
        pos = - th.log(dofm[:, 0] + self._EPS).mean()
        neg = - th.log(1 - dofm[:, 1:] + self._EPS).mean()
        return (pos + neg) / 2

    def _tile_negatives(self, x):
        """Repeat ``x`` ``num_negs`` times along dimension 1 (returns a fresh copy)."""
        reps = [1] * x.dim()
        reps[1] = self.num_negs
        return x.repeat(reps)

    def _pair_fs(self, first, second, x, e_emb):
        """Evaluate two class expressions one after the other on the same ``x``.

        The expressions are walked sequentially (rather than rebuilding an intersection,
        whose operands OWLAPI would reorder) so that the columns of ``x`` are consumed
        in the same order they were produced.
        """
        first_fs, next_index = self.forward_fs(first, x, e_emb)
        second_fs, _ = self.forward_fs(second, x, e_emb, cur_index=next_index)
        return first_fs, second_fs

    def _class_assertion_loss(self, axiom, x, e_emb):
        """Loss for ``e : C``. Column 0 of ``x`` holds the individual; the remaining
        columns encode the class expression."""
        x = x.unsqueeze(dim=1)
        neg_x = self._tile_negatives(x)
        # Corrupt the individual with uniformly drawn entities.
        neg_x[:, :, 0] = th.randint(self.nentities, (x.shape[0], self.num_negs))
        x = th.cat([x, neg_x], dim=1)

        cexpr = axiom.getClassExpression()
        expr_type = cexpr.getClassExpressionType()

        if expr_type == ClassExpressionType.OWL_CLASS:
            # Plain concept assertion ``e : C`` — directly push the membership
            # μ(C, e) towards 1 for the asserted individual and towards 0 for the
            # corrupted negatives (cf. ``forward_abox_ec_created`` in the original
            # FALCON implementation).
            c_emb = self.c_embedding(x[:, 0, 1]).unsqueeze(dim=1)
            ex_emb = self.e_embedding(x[:, :, 0])
            dofm = self._mem(c_emb.expand_as(ex_emb), ex_emb).squeeze(dim=-1)
            return self._pos_neg_loss(dofm)

        # Relational or complex assertion (e.g. ``e : ∃R.C``) — exist-based loss.
        if expr_type == ClassExpressionType.OBJECT_SOME_VALUES_FROM:
            r_emb = self.r_embedding(x[:, :, 1])
            cexpr = cexpr.getFiller()
            cx = x[:, 0, 2:]
        else:
            # No outer relation: the individual embedding itself plays the concept role.
            r_emb = 0
            cx = x[:, 0, 1:]

        c_fs, _ = self.forward_fs(cexpr, cx, e_emb)
        ex_emb = self.e_embedding(x[:, :, 0])
        r_fs = self._get_c_fs_batch(
            (ex_emb + r_emb).view(-1, ex_emb.shape[-1]), e_emb
        ).view(ex_emb.shape[0], ex_emb.shape[1], -1)
        dofm = self._logical_exist(r_fs, c_fs.unsqueeze(dim=1))
        return self._pos_neg_loss(dofm)

    def _property_assertion_loss(self, x, stage):
        """Loss (``stage='train'``) or raw scores (``stage='test'``) for ``R(h, t)``.
        Columns 0, 1 and 2 of ``x`` hold head, relation and tail indices."""
        x = x.unsqueeze(dim=1)

        neg_h = self._tile_negatives(x)
        neg_h[:, :, 0] = self.sample_negatives(x[:, :, 2], x[:, :, 1], self.heads_dict)
        neg_t = self._tile_negatives(x)
        neg_t[:, :, 2] = self.sample_negatives(x[:, :, 0], x[:, :, 1], self.tails_dict)
        # Layout along dimension 1: [positive, corrupted heads..., corrupted tails...]
        x = th.cat([x, neg_h, neg_t], dim=1)

        if stage not in ('train', 'test'):
            raise NotImplementedError()
        if stage == 'train' and self.loss_type not in ('c', 'r'):
            raise NotImplementedError()

        e_1_emb = self.e_embedding(x[:, :, 0])
        r_emb = self.r_embedding(x[:, :, 1])
        e_2_emb = self.e_embedding(x[:, :, 2])
        dofm = self._mem(e_1_emb + r_emb, e_2_emb).squeeze(dim=-1)

        if stage == 'test':
            return dofm.flatten()
        if self.loss_type == 'c':
            return self._pos_neg_loss(dofm)
        # loss_type == 'r': the positive should outscore every negative.
        diff = dofm[:, 0].unsqueeze(dim=-1) - dofm[:, 1:]
        return - th.nn.functional.logsigmoid(diff).mean()

    def forward(self, axiom, x, e_emb, stage='train'):
        """Compute the loss of a batch of axioms that share the structure of ``axiom``.

        :param axiom: OWLAPI axiom acting as the structural template of the batch.
        :param x: 2-D integer tensor, one row per axiom instance. Columns hold the
            class/relation/individual indices in the order the axiom is traversed.
        :param e_emb: 2-D tensor of entity embeddings over which fuzzy sets are
            evaluated (unused for object property assertions).
        :param stage: ``'train'`` or ``'test'``. Only consulted for object property
            assertions, where ``'test'`` returns the flattened membership scores of the
            positives and sampled negatives instead of a loss.
        :returns: A scalar loss tensor (or a 1-D score tensor, see ``stage``). ``None``
            is returned for axiom types not handled here.
        :raises NotImplementedError: For an unknown ``stage`` or ``loss_type`` in the
            object property assertion case, or an unsupported class expression.
        """
        if isinstance(axiom, OWLSubClassOfAxiom):
            # C ⊑ D is violated where C holds and D does not, i.e. on C ⊓ ¬D.
            c_fs, d_fs = self._pair_fs(
                axiom.getSubClass(), axiom.getSuperClass(), x, e_emb)
            fs = self._logical_and(c_fs, self._logical_not(d_fs))
            return self.get_cc_loss(fs).mean()

        if isinstance(axiom, OWLEquivalentClassesAxiom):
            # C ≡ D: both C ⊓ ¬D and ¬C ⊓ D must be empty.
            cexprs = axiom.getClassExpressionsAsList()
            c_fs, d_fs = self._pair_fs(cexprs[0], cexprs[1], x, e_emb)
            fs1 = self._logical_and(c_fs, self._logical_not(d_fs))
            fs2 = self._logical_and(self._logical_not(c_fs), d_fs)
            return self.get_cc_loss(fs1).mean() + self.get_cc_loss(fs2).mean()

        if isinstance(axiom, OWLDisjointClassesAxiom):
            # C and D are disjoint iff their intersection C ⊓ D is unsatisfiable.
            cexprs = axiom.getClassExpressionsAsList()
            c_fs, d_fs = self._pair_fs(cexprs[0], cexprs[1], x, e_emb)
            return self.get_cc_loss(self._logical_and(c_fs, d_fs)).mean()

        if isinstance(axiom, OWLClassAssertionAxiom):
            return self._class_assertion_loss(axiom, x, e_emb)

        if isinstance(axiom, OWLObjectPropertyAssertionAxiom):
            return self._property_assertion_loss(x, stage)

        # Unsupported axiom types have always yielded ``None``; kept for compatibility.
        return None