Source code for mowl.models.elembeddings.examples.model_gda

from mowl.base_models.elmodel import EmbeddingELModel

from tqdm import trange, tqdm
import torch as th

import numpy as np
from mowl.models import ELEmbeddings

[docs] class ELEmGDA(ELEmbeddings): """ Example of ELEmbeddings for gene-disease associations prediction. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)
[docs] def train(self): _, diseases = self.dataset.evaluation_classes optimizer = th.optim.Adam(self.module.parameters(), lr=self.learning_rate) best_loss = float('inf') for epoch in trange(self.epochs): self.module.train() train_loss = 0 loss = 0 for gci_name, gci_dataset in self.training_datasets.items(): if len(gci_dataset) == 0: continue loss += th.mean(self.module(gci_dataset[:], gci_name)) if gci_name == "gci2": idxs_for_negs = np.random.choice(len(self.class_index_dict), size=len(gci_dataset), replace=True) rand_index = th.tensor(idxs_for_negs).to(self.device) data = gci_dataset[:] neg_data = th.cat([data[:, :2], rand_index.unsqueeze(1)], dim=1) loss += th.mean(self.module(neg_data, gci_name, neg=True)) optimizer.zero_grad() loss.backward() optimizer.step() train_loss += loss.detach().item() loss = 0 with th.no_grad(): self.module.eval() valid_loss = 0 gci2_data = self.validation_datasets["gci2"][:] loss = th.mean(self.module(gci2_data, "gci2")) valid_loss += loss.detach().item() checkpoint = 10 if best_loss > valid_loss: best_loss = valid_loss th.save(self.module.state_dict(), self.model_filepath) if (epoch + 1) % (checkpoint * 10) == 0: print(f'Epoch {epoch}: Train loss: {train_loss} Valid loss: {valid_loss}') return 1