"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

Definition of TSV class: random-access readers for tab-separated files that
are indexed by a sidecar ``.lineidx`` file holding one row-start position per
line.
"""
import logging
import os
import os.path as op


def generate_lineidx(filein, idxout):
    """Write a row-start index for ``filein`` to ``idxout``.

    Each line of ``idxout`` is the position, as reported by ``tell()`` on a
    text-mode handle, at which one row of ``filein`` starts. :class:`TSVFile`
    later passes these values to ``seek()`` on its own text-mode handle.

    The index is first written to ``idxout + '.tmp'`` and then moved into
    place, so ``idxout`` never contains a partially written index.
    """
    idxout_tmp = idxout + '.tmp'
    with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
        fpos = 0
        # Stop on readline()'s EOF sentinel instead of comparing tell() with
        # the size taken when the file was opened. A text-mode tell() value is
        # an opaque cookie, and a file that grows while it is being indexed
        # would otherwise never reach the stored size.
        while tsvin.readline():
            tsvout.write('{}\n'.format(fpos))
            fpos = tsvin.tell()
    # os.replace overwrites an existing destination on every platform.
    # os.rename fails on Windows when the destination already exists.
    os.replace(idxout_tmp, idxout)


# Inside TSVFile.__init__, the parameter ``generate_lineidx`` shadows the
# function above. This alias keeps the function reachable from there.
_generate_lineidx_file = generate_lineidx


def read_to_character(fp, c):
    """Return text from ``fp``'s current position up to, but excluding, ``c``.

    The file is read in 32-character chunks, so only the start of a possibly
    very long row is read. ``c`` is expected to be a single character. After
    the call, the file position lies somewhere past ``c``, because the rest of
    the final chunk has also been consumed.

    Raises:
        EOFError: if end of file is reached before ``c`` is found.
    """
    pieces = []
    while True:
        chunk = fp.read(32)
        if not chunk:
            # An explicit raise instead of ``assert``: asserts are stripped
            # under ``python -O``, and this loop would then never end.
            raise EOFError('{!r} not found before end of file'.format(c))
        pos = chunk.find(c)
        if pos >= 0:
            pieces.append(chunk[:pos])
            return ''.join(pieces)
        pieces.append(chunk)


def _read_first_field(fp):
    """Return text from ``fp``'s position up to the first tab, newline or EOF.

    This is the same chunked read as ``read_to_character``, but it stops at
    the end of the current row. A row with a single column therefore yields
    that column instead of running on into the next row.
    """
    pieces = []
    while True:
        chunk = fp.read(32)
        if not chunk:
            break
        ends = [p for p in (chunk.find('\t'), chunk.find('\n')) if p >= 0]
        if ends:
            pieces.append(chunk[:min(ends)])
            break
        pieces.append(chunk)
    return ''.join(pieces)


class TSVFile(object):
    """Random-access reader for one TSV file and its ``.lineidx`` index.

    The index path is ``tsv_file`` with its extension replaced by
    ``.lineidx``. Both the index and the file handle are opened lazily, on
    first use.
    """

    def __init__(self, tsv_file, generate_lineidx=False):
        """
        Args:
            tsv_file: path of the TSV file.
            generate_lineidx: if true and the index file does not exist yet,
                build the index now.
        """
        self.tsv_file = tsv_file
        self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
        self._fp = None
        self._lineidx = None
        # pid of the process that opened self._fp. If it differs from the
        # current pid (e.g. in a forked worker), the file is re-opened,
        # because a handle inherited across fork() shares its file offset
        # with the parent process.
        self.pid = None
        if generate_lineidx and not op.isfile(self.lineidx):
            # ``generate_lineidx`` here is the bool parameter, not the
            # function, so call the function through its module-level alias.
            _generate_lineidx_file(self.tsv_file, self.lineidx)

    def __del__(self):
        self.close()

    def close(self):
        """Close the underlying file handle if it is open. Safe to call twice."""
        # getattr: __init__ may have raised before _fp was assigned.
        fp = getattr(self, '_fp', None)
        if fp is not None:
            fp.close()
            self._fp = None

    def __str__(self):
        return "TSVFile(tsv_file='{}')".format(self.tsv_file)

    def __repr__(self):
        return str(self)

    def num_rows(self):
        """Return the number of rows, i.e. the number of index entries."""
        self._ensure_lineidx_loaded()
        return len(self._lineidx)

    def seek(self, idx):
        """Return row ``idx`` split on tabs, with each field whitespace-stripped."""
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        try:
            pos = self._lineidx[idx]
        except Exception:
            # Record which file and row failed, then propagate the error.
            logging.info('%s-%s', self.tsv_file, idx)
            raise
        self._fp.seek(pos)
        return [s.strip() for s in self._fp.readline().split('\t')]

    def seek_first_column(self, idx):
        """Return the first field of row ``idx`` without reading the whole row.

        The field is returned as-is, without whitespace stripping.
        """
        self._ensure_tsv_opened()
        self._ensure_lineidx_loaded()
        self._fp.seek(self._lineidx[idx])
        return _read_first_field(self._fp)

    def get_key(self, idx):
        """Return the key of row ``idx`` (its first column)."""
        return self.seek_first_column(idx)

    def __getitem__(self, index):
        return self.seek(index)

    def __len__(self):
        return self.num_rows()

    def _ensure_lineidx_loaded(self):
        """Load the row-start positions from ``self.lineidx`` on first use."""
        if self._lineidx is None:
            logging.info('loading lineidx: %s', self.lineidx)
            with open(self.lineidx, 'r') as fp:
                # Skip blank lines (e.g. a stray trailing newline) instead of
                # failing on int(''). int() itself ignores surrounding spaces.
                self._lineidx = [int(line) for line in fp if line.strip()]

    def _ensure_tsv_opened(self):
        """Open the TSV file, re-opening it if the process id has changed."""
        if self._fp is not None and self.pid != os.getpid():
            logging.info('re-open %s because the process id changed',
                         self.tsv_file)
            # Close the handle inherited from the parent. After fork each
            # process has its own descriptor table, so the parent's copy
            # stays open.
            self._fp.close()
            self._fp = None
        if self._fp is None:
            self._fp = open(self.tsv_file, 'r')
            self.pid = os.getpid()


class CompositeTSVFile(object):
    """Concatenated view over several TSV files, ordered by a sequence file.

    Each line of ``seq_file`` holds two tab-separated integers, a source
    index and a row index. Row ``i`` of this view is that row of
    ``file_list[source index]``.
    """

    def __init__(self, file_list, seq_file, root='.'):
        """
        Args:
            file_list: list of TSV paths relative to ``root``, or the path of
                a text file that lists them one per line.
            seq_file: path of the sequence file described above.
            root: directory against which ``file_list`` entries are resolved.

        Raises:
            TypeError: if ``file_list`` is neither a str nor a list.
        """
        if isinstance(file_list, str):
            self.file_list = load_list_file(file_list)
        elif isinstance(file_list, list):
            self.file_list = file_list
        else:
            raise TypeError('file_list must be a str or a list, got {}'
                            .format(type(file_list).__name__))
        self.seq_file = seq_file
        self.root = root
        self.initialized = False
        self.initialize()

    def get_key(self, index):
        """Return ``'<source file>_<row key>'`` for composite row ``index``."""
        idx_source, idx_row = self.seq[index]
        k = self.tsvs[idx_source].get_key(idx_row)
        return '_'.join([self.file_list[idx_source], k])

    def num_rows(self):
        return len(self.seq)

    def __getitem__(self, index):
        idx_source, idx_row = self.seq[index]
        return self.tsvs[idx_source].seek(idx_row)

    def __len__(self):
        return len(self.seq)

    def initialize(self):
        '''
        This function has to be called in the init function if cache_policy
        is enabled. Thus, let's always call it in the init function to keep
        it simple.
        '''
        if self.initialized:
            return
        self.seq = []
        with open(self.seq_file, 'r') as fp:
            for line in fp:
                parts = line.strip().split('\t')
                if parts == ['']:
                    # Blank line, e.g. a trailing newline: nothing to map.
                    continue
                self.seq.append([int(parts[0]), int(parts[1])])
        self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
        self.initialized = True


def load_list_file(fname):
    """Return the whitespace-stripped lines of ``fname``.

    A single blank final line is dropped. Blank lines elsewhere are kept.
    """
    with open(fname, 'r') as fp:
        result = [line.strip() for line in fp]
    if result and result[-1] == '':
        result.pop()
    return result