File size: 658 Bytes
52007f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#!/usr/bin/python
# -*- coding:utf-8 -*-
import numpy as np


class ClusterResampler:
    """Sample entry indices with probability inversely proportional to cluster size.

    The cluster file is read line by line; the last tab-separated field of
    each non-blank line is parsed as an integer member count. Entry ``i``
    (the i-th non-blank line, 0-based) receives weight ``1 / count``, and the
    weights are normalized to sum to one.
    """

    def __init__(self, cluster_path: str) -> None:
        """Load member counts from *cluster_path* and build the distribution.

        Args:
            cluster_path: Path to a tab-separated text file whose last column
                holds a positive integer member count.

        Raises:
            ValueError: If a non-blank line has a last field that is not an
                integer, or a member count that is not positive.
        """
        weights = []
        with open(cluster_path, 'r') as fin:
            for line_no, line in enumerate(fin, start=1):
                stripped = line.strip()
                if not stripped:
                    # Blank lines (typically a trailing newline at the end of
                    # the file) carry no entry; int('') would crash otherwise.
                    continue
                last_field = stripped.split('\t')[-1]
                try:
                    cluster_n_member = int(last_field)
                except ValueError as err:
                    raise ValueError(
                        f'{cluster_path}:{line_no}: cannot parse member count '
                        f'{last_field!r}'
                    ) from err
                if cluster_n_member <= 0:
                    # Zero would divide by zero; negatives would produce
                    # invalid (negative) probabilities.
                    raise ValueError(
                        f'{cluster_path}:{line_no}: member count must be '
                        f'positive, got {cluster_n_member}'
                    )
                weights.append(1 / cluster_n_member)
        total = sum(weights)
        # Normalized sampling probability for every entry, in file order.
        self.idx2prob = np.array([w / total for w in weights])

    def __call__(self, n_sample: int, replace: bool = False) -> np.ndarray:
        """Draw *n_sample* entry indices according to ``self.idx2prob``.

        Args:
            n_sample: Number of indices to draw.
            replace: Whether the same index may be drawn more than once.

        Returns:
            Array of ``n_sample`` integer indices into ``self.idx2prob``.

        Raises:
            ValueError: Propagated from ``np.random.choice``, e.g. when no
                entries were loaded or when ``replace`` is False and
                ``n_sample`` exceeds the number of entries.
        """
        return np.random.choice(
            len(self.idx2prob), size=n_sample, replace=replace, p=self.idx2prob
        )