| | import numpy as np |
| | import torch |
| | from scipy.stats import pearsonr |
| | import dataloader_gosai |
| | import oracle |
| |
|
| |
|
def compare_kmer(kmer1, kmer2, n_sp1, n_sp2):
    """Return the Pearson correlation between two k-mer count tables.

    Every k-mer present in either table contributes one paired observation;
    a k-mer missing from one of the tables counts as 0 on that side.

    Args:
        kmer1: Mapping k-mer -> count for the first set. Its counts are
            rescaled by ``n_sp2 / n_sp1`` before the comparison.
        kmer2: Mapping k-mer -> count for the second set, used as-is.
        n_sp1: Normaliser for ``kmer1``; presumably the number of sequences
            ``kmer1`` was counted over (TODO confirm against callers). It
            should be non-zero whenever ``kmer1`` is non-empty.
        n_sp2: Normaliser for ``kmer2``; presumably its sequence count.

    Returns:
        The Pearson correlation coefficient, i.e. the first element of the
        ``scipy.stats.pearsonr`` result. scipy needs at least two k-mers in
        total, and yields nan (with a warning) if either column is constant.

    Note:
        Pearson r is unchanged when one variable is multiplied by a positive
        constant, so the rescaling of ``kmer1`` only affects the result when
        ``n_sp2 / n_sp1`` is zero or negative.
    """
    all_kmers = set(kmer1.keys()) | set(kmer2.keys())
    # A single pass over the set builds (kmer2, rescaled kmer1) pairs, so both
    # columns are guaranteed to line up on the same k-mer.
    pairs = [
        (
            kmer2[k] if k in kmer2 else 0.0,
            kmer1[k] * n_sp2 / n_sp1 if k in kmer1 else 0.0,
        )
        for k in all_kmers
    ]
    # reshape keeps a (0, 2) table when there are no k-mers at all.
    table = np.array(pairs, dtype=np.float64).reshape(-1, 2)
    return pearsonr(table[:, 0], table[:, 1])[0]
| |
|
| |
|
def get_eval_matrics(
    samples,
    ref_model,
    gosai_oracle,
    cal_atac_pred_new_mdl,
    highexp_kmers_999,
    n_highexp_kmers_999,
):
    """Compute summary evaluation metrics for a batch of sampled sequences.

    Args:
        samples: Token tensor, [B, 200].
        ref_model: Model exposing ``get_likelihood(samples, num_steps=...,
            n_samples=...)``; called here with 128 steps and 1 sample.
        gosai_oracle: Model handed to ``oracle.cal_gosai_pred_new``.
        cal_atac_pred_new_mdl: Model handed to ``oracle.cal_atac_pred_new``.
        highexp_kmers_999: Reference k-mer count table for ``compare_kmer``.
        n_highexp_kmers_999: Normaliser for that reference table, presumably
            the number of reference sequences (TODO confirm).

    Returns:
        dict with, in this order:
          '[log-lik-med]'       torch.median of the reference-model likelihoods
                                (torch returns the lower middle value for an
                                even count, unlike np.median);
          '[pred-activity-med]' np.median of column 0 of the GOSAI oracle output;
          '[atac-acc%]'         percentage of samples whose column-1 ATAC
                                prediction exceeds 0.5;
          '[3-mer-corr]'        Pearson r between ``oracle.count_kmers`` of the
                                batch and the reference table (k presumably 3,
                                per the key name).
    """
    # The oracle helpers work on DNA strings, not token ids.
    seqs = dataloader_gosai.batch_dna_detokenize(samples.detach().cpu().numpy())

    # The model and oracle calls below keep their original order, in case
    # they consume random or device state.
    log_lik_median = torch.median(
        ref_model.get_likelihood(samples, num_steps=128, n_samples=1)
    ).item()

    activity = oracle.cal_gosai_pred_new(seqs, gosai_oracle, mode='eval')[:, 0]
    activity_median = np.median(activity).item()

    atac_probs = oracle.cal_atac_pred_new(seqs, cal_atac_pred_new_mdl)[:, 1]
    n_open = (atac_probs > 0.5).sum().item()
    atac_pct = n_open / len(samples) * 100

    generated_kmers = oracle.count_kmers(seqs)
    kmer_corr = compare_kmer(
        highexp_kmers_999, generated_kmers, n_highexp_kmers_999, len(seqs)
    ).item()

    return {
        '[log-lik-med]': log_lik_median,
        '[pred-activity-med]': activity_median,
        '[atac-acc%]': atac_pct,
        '[3-mer-corr]': kmer_corr,
    }