# Source code for kinactive.db

"""
A :class:`DB` class for the PK data collection creation and io.
"""
import logging
import operator as op
import typing as t
from collections import abc
from io import StringIO
from itertools import chain
from pathlib import Path
from random import sample

import lXtractor.core.chain.tree as chain_tree
import lXtractor.core.config as lx_cfg
import pandas as pd
from lXtractor.core.chain import (
    Chain,
    ChainIO,
    ChainInitializer,
    ChainList,
    ChainSequence,
    ChainStructure,
)
from lXtractor.ext import PDB, PyHMMer, SIFTS, fetch_uniprot, filter_by_method
from lXtractor.util import get_files, read_fasta, write_fasta
from more_itertools import ilen, consume, unzip
from toolz import curry, groupby, itemmap, keyfilter, keymap
from tqdm.auto import tqdm

from kinactive.config import DBConfig, DumpNames

T = t.TypeVar("T")
CT_: t.TypeAlias = Chain | ChainSequence | ChainStructure
LOGGER = logging.getLogger(__name__)


# TODO: some object IDs are duplicated:
# This stems from the issue of chimeric sequences.
# However, such sequences should not pass the filtering when canonical/structure
# seqs are compared.
# ['ChainStructure(PK_1|10-260<-(5UFU:A|1-375))',
#  'ChainStructure(PK_1|136-354<-(7APJ:A|1-385))',
#  'ChainStructure(PK_1|38-323<-(3U87:A|1-334))',
#  'ChainStructure(PK_1|38-323<-(3U87:B|1-334))']
# TODO: include UniProt metadata


def _get_remaining(names: abc.Iterable[str], dir_: Path) -> set[str]:
    """
    Find the names that have no corresponding file in a directory.

    :param names: Names to check.
    :param dir_: Directory whose files are listed via :func:`get_files`.
    :return: The subset of ``names`` not equal to the stem of any file
        found in ``dir_``.
    """
    present_stems = set()
    for path in get_files(dir_).values():
        present_stems.add(path.stem)
    return set(names).difference(present_stems)


def _is_sequence_of_chain_seqs(
    s: abc.Sequence[t.Any],
) -> t.TypeGuard[abc.Sequence[ChainSequence]]:
    """
    Type guard: check that every element is a :class:`ChainSequence`.

    :param s: A sequence of arbitrary objects.
    :return: ``True`` if all elements are :class:`ChainSequence` instances
        (vacuously ``True`` for an empty sequence), ``False`` otherwise.
    """
    for item in s:
        if not isinstance(item, ChainSequence):
            return False
    return True


def _stage_chain_init(
    seq: T, pdb_chains: abc.Iterable[str], pdb_dir: Path, fmt: str
) -> tuple[T, list[tuple[Path, list[str]]]]:
    """
    Pair a sequence with the structure files and chain IDs mapped to it.

    :param seq: An object passed through unchanged as the first element of
        the result.
    :param pdb_chains: Chain specifiers split on ``":"``; the first field is
        used as the structure ID and the second as the chain ID.
    :param pdb_dir: Directory expected to hold ``{ID}.{fmt}`` files.
    :param fmt: File extension (without the dot) of the structure files.
    :return: ``(seq, [(path, [chain_id, ...]), ...])`` with one entry per
        distinct structure ID, in order of first appearance.
    :raises AssertionError: If any of the composed paths does not exist.
    """
    # Group the split specifiers by their first field (structure ID).
    fields_by_id: dict[str, list[list[str]]] = {}
    for spec in pdb_chains:
        fields = spec.split(":")
        fields_by_id.setdefault(fields[0], []).append(fields)

    # Turn each structure ID into a file path and keep only the chain IDs.
    path2chains = {
        pdb_dir / f"{pdb_id}.{fmt}": [fields[1] for fields in group]
        for pdb_id, group in fields_by_id.items()
    }

    assert all(x.exists() for x in path2chains)
    return seq, list(path2chains.items())


def _filter_by_size(
    structures: list[ChainStructure], cfg: DBConfig
) -> list[ChainStructure]:
    """
    Keep structures whose sequence is long enough.

    :param structures: Structures to filter.
    :param cfg: Config supplying the ``pdb_str_min_size`` threshold.
    :return: A new list with the structures whose ``seq`` length is at least
        ``cfg.pdb_str_min_size``, in the original order.
    """
    kept = []
    for structure in structures:
        if len(structure.seq) >= cfg.pdb_str_min_size:
            kept.append(structure)
    return kept


def _rm_solvent(structures: list[ChainStructure]) -> list[ChainStructure]:
    """
    Call ``rm_solvent()`` on each structure and collect the results.

    :param structures: Structures to process.
    :return: A new list holding the return value of ``rm_solvent()`` for
        each input structure, in the original order.
    """
    desolvated = []
    for structure in structures:
        desolvated.append(structure.rm_solvent())
    return desolvated


def _drop_all_na(df: pd.DataFrame) -> pd.DataFrame:
    """
    Remove the columns consisting entirely of missing values.

    :param df: Input table; it is not modified.
    :return: A new data frame without the columns whose number of missing
        values equals the number of rows.
    """
    n_rows = len(df)
    empty_columns = []
    for column in df.columns:
        n_missing = df[column].isna().sum()
        if n_missing == n_rows:
            empty_columns.append(column)
    return df.drop(columns=empty_columns)


def _split_summary(df: pd.DataFrame) -> tuple[pd.DataFrame, ...]:
    """
    Split a summary table into four parts by parent presence and structure flag.

    The parts are returned in this order:

    #. ``ParentID`` missing, ``Structure`` not ``True``;
    #. ``ParentID`` missing, ``Structure`` is ``True``;
    #. ``ParentID`` present, ``Structure`` not ``True``;
    #. ``ParentID`` present, ``Structure`` is ``True``.

    :param df: A table with ``ParentID`` and ``Structure`` columns.
    :return: A tuple of four data frames, each with all-NA columns removed.
    """
    has_no_parent = df.ParentID.isna()
    # Element-wise comparison on purpose: missing values compare as False.
    is_structure = df.Structure == True  # noqa: E712
    masks = (
        has_no_parent & ~is_structure,
        has_no_parent & is_structure,
        ~has_no_parent & ~is_structure,
        ~has_no_parent & is_structure,
    )
    return tuple(_drop_all_na(df[mask]) for mask in masks)


class DB:
    """
    An object encapsulating methods for building/saving/loading an lXtractor
    "database" -- a collection of :class:`Chain`'s.
    """

    def __init__(self, cfg: DBConfig | None = None):
        """
        :param cfg: Configuration to use. If ``None``, a fresh
            :class:`DBConfig` is created for this instance, so instances
            never share a single default config object.
        """
        self.cfg = DBConfig() if cfg is None else cfg
        # Lazily initialized helpers; see the ``_load_*`` methods.
        self._sifts: SIFTS | None = None
        self._pdb: PDB | None = None
        self._pk_hmm: PyHMMer | None = None
        self._chains: ChainList[Chain] = ChainList([])

    @property
    def chains(self) -> ChainList[Chain]:
        """
        :return: Currently stored chains.
        """
        return self._chains

    def _load_sifts(self, overwrite: bool = False) -> SIFTS:
        """
        Get the cached :class:`SIFTS` object, creating it when needed.

        :param overwrite: Re-create the object even if one is cached.
        :return: The (now cached) :class:`SIFTS` instance.
        """
        if self._sifts is None or overwrite:
            sifts = SIFTS(load_segments=False, load_id_mapping=True)
            if sifts.id_mapping is None:
                LOGGER.info("Initializing SIFTS for the first time.")
                sifts.fetch()
                sifts.parse()
            # Cache the instance so later calls reuse it, consistently with
            # ``_load_pdb`` and ``_load_pk_hmm``.
            self._sifts = sifts
        return self._sifts

    def _load_pdb(self, overwrite: bool = False) -> PDB:
        """
        Get the cached :class:`PDB` interface, creating it when needed.

        :param overwrite: Re-create the object even if one is cached.
        :return: The cached :class:`PDB` instance.
        """
        if self._pdb is None or overwrite:
            self._pdb = PDB(
                self.cfg.max_fetch_trials,
                self.cfg.pdb_num_fetch_threads,
                self.cfg.verbose,
            )
        return self._pdb

    def _load_pk_hmm(self, overwrite: bool = False) -> PyHMMer:
        """
        Get the cached :class:`PyHMMer` annotator built from ``cfg.profile``.

        :param overwrite: Re-create the object even if one is cached.
        :return: The cached :class:`PyHMMer` instance.
        """
        if self._pk_hmm is None or overwrite:
            self._pk_hmm = PyHMMer(self.cfg.profile, bit_cutoffs="trusted")
        return self._pk_hmm

    def _fetch_seqs(self, ids: abc.Iterable[str]):
        """
        Fetch UniProt sequences and save each one to ``cfg.seq_dir``.

        Each record is written to ``{id}.fasta``, where ``id`` is the second
        ``"|"``-separated field of the fasta header.

        :param ids: Identifiers passed to :func:`fetch_uniprot`.
        """
        raw_seqs = fetch_uniprot(
            ids,
            num_threads=self.cfg.uniprot_num_fetch_threads,
            chunk_size=self.cfg.uniprot_chunk_size,
            verbose=self.cfg.verbose,
        )
        parsed_seqs: abc.Iterable[tuple[str, str]] = read_fasta(StringIO(raw_seqs))
        if self.cfg.verbose:
            parsed_seqs = tqdm(parsed_seqs, desc="Saving fetched sequences")
        for header, seq in parsed_seqs:
            id_ = header.split("|")[1]
            write_fasta([(header, seq)], self.cfg.seq_dir / f"{id_}.fasta")

    def _read_seqs(self, ids: abc.Iterable[str]) -> ChainList[Chain]:
        """
        Read the sequences stored in ``cfg.seq_dir`` for the requested IDs.

        :param ids: Identifiers to read; those without a matching
            ``{id}.fasta`` entry in ``cfg.seq_dir`` are skipped.
        :return: A chain list with one :class:`Chain` per sequence read.
        """
        files = keymap(lambda x: x.removesuffix(".fasta"), get_files(self.cfg.seq_dir))
        matching = set(files) & set(ids)
        paths = keyfilter(lambda x: x in matching, files).values()
        init = ChainInitializer(verbose=self.cfg.verbose)
        chains = list(init.from_iterable(paths))
        assert _is_sequence_of_chain_seqs(chains), "correct types are returned"
        return ChainList(map(Chain, chains))

    def _get_sifts_xray(self) -> list[str]:
        """
        :return: The SIFTS PDB IDs that pass the ``"X-ray"`` method filter.
        """
        sifts = self._load_sifts()
        pdb = self._load_pdb()
        return filter_by_method(sifts.pdb_ids, pdb=pdb, method="X-ray")

    def build(
        self,
        uniprot_ids: abc.Collection[str] | None = None,
        pdb_chain_ids: abc.Collection[str] | None = None,
        n_domains: int = 0,
    ) -> ChainList[Chain]:
        """
        Build a new lXt-PK data collection.

        :param uniprot_ids: An optional list of UniProt IDs to restrict
            the db to.
        :param pdb_chain_ids: An optional collection of PDB chains to
            restrict the db to. Format: "{PDB_ID}:{ChainID}".
        :param n_domains: Use n random sequence domains. It is helpful for
            testing the pipeline. If it exceeds the number of available
            sequences, all of them are used.
        :return: A :class:`ChainList` of :class:`Chain` objects having at
            least one child PK domain with at least one PK domain structure
            passing the filtering thresholds. An empty list is returned if
            no sequence has a mapped PDB chain left after filtering.
        """

        def match_seq(s: ChainStructure) -> ChainStructure:
            s.seq.match("seq1", "seq1_canonical", as_fraction=True, save=True)
            return s

        def filter_domain_str_by_canon_seq_match(c: Chain) -> Chain:
            c.transfer_seq_mapping(
                lx_cfg.SeqNames.seq1, map_name_in_other="seq1_canonical"
            )
            match_name = "Match_seq1_seq1_canonical"
            c = c.apply_structures(match_seq).filter_structures(
                lambda s: (
                    len(s.seq) >= self.cfg.pk_min_str_domain_size
                    and s.seq.meta[match_name] >= self.cfg.pk_min_str_seq_match
                )
            )
            return c

        # 0. Init directories
        for dir_ in [
            self.cfg.target_dir,
            self.cfg.pdb_dir,
            self.cfg.seq_dir,
            self.cfg.pdb_dir_info,
        ]:
            if dir_ is not None:
                dir_.mkdir(exist_ok=True, parents=True)

        sifts = self._load_sifts()
        pdb = self._load_pdb()

        # Fetch SIFTS UniProt seqs
        if uniprot_ids:
            # Build the lookup set once instead of scanning the caller's
            # collection for every SIFTS ID.
            wanted_ids = set(uniprot_ids)
            ids = [x for x in sifts.uniprot_ids if x in wanted_ids]
            LOGGER.info(
                f"Filtered to {len(ids)} out of {len(sifts.uniprot_ids)} initial IDs "
                f"contained in SIFTS using {len(uniprot_ids)} reference IDs."
            )
            missing = wanted_ids - set(ids)
            if missing:
                LOGGER.warning(f"{len(missing)} IDs were missing in SIFTS: {missing}")
        else:
            ids = sifts.uniprot_ids

        fetch_ids = _get_remaining(ids, self.cfg.seq_dir)
        LOGGER.info(f"{len(fetch_ids)} remaining sequences to fetch.")
        if fetch_ids:
            # Nothing to request when every sequence is already on disk.
            self._fetch_seqs(fetch_ids)

        # Read
        seqs = self._read_seqs(ids)
        LOGGER.info(f"Got {len(seqs)} seqs from {self.cfg.seq_dir}")

        # Filter sequences by size
        min_size, max_size = self.cfg.min_seq_size, self.cfg.max_seq_size
        seqs = seqs.filter(lambda s: min_size <= len(s.seq) <= max_size)
        LOGGER.info(f"Filtered to {len(seqs)} seqs in [{min_size}, {max_size}]")

        # Annotate PK domains and filter seqs to the annotated ones
        pk_hmm = self._load_pk_hmm()
        ann: abc.Iterable[Chain] = pk_hmm.annotate(
            seqs,
            new_map_name=self.cfg.pk_map_name,
            min_score=self.cfg.pk_min_score,
            min_size=self.cfg.pk_min_seq_domain_size,
            min_cov_hmm=self.cfg.pk_min_cov_hmm,
            min_cov_seq=self.cfg.pk_min_cov_seq,
        )
        if self.cfg.verbose:
            ann = tqdm(ann, desc="Annotating sequence domains")
        # Consuming the iterator is what performs the annotation.
        annotated_num = sum(1 for _ in ann)
        seqs = seqs.filter(lambda x: len(x.children) > 0)
        LOGGER.info(f"Found {annotated_num} PK domains within {len(seqs)} seqs.")

        if n_domains:
            # ``random.sample`` raises if asked for more items than exist.
            seqs = ChainList(sample(seqs, min(n_domains, len(seqs))))
            LOGGER.info(f"Sampled to {len(seqs)} random initial domains.")

        # Get UniProt IDs and corresponding PDB Chains.
        # An unmapped ID becomes an empty group (instead of being dropped)
        # so that ``pdb_chains`` stays index-aligned with ``seqs``.
        uni_ids = [x.id.split("|")[1] for x in seqs]
        pdb_chains = [sifts.map_id(x) or [] for x in uni_ids]

        # Filter PDB IDs to provided list
        if pdb_chain_ids:
            wanted_chains = set(pdb_chain_ids)
            num_init = ilen(chain.from_iterable(pdb_chains))
            pdb_chains = [
                [x for x in chain_group if x in wanted_chains]
                for chain_group in pdb_chains
            ]
            LOGGER.info(
                f"Filtered to {ilen(chain.from_iterable(pdb_chains))} out "
                f"of {num_init} initially mapped PDB chains "
                f"({len(pdb_chain_ids)} reference IDs were provided for filtering)."
            )

        # Keep sequences with at least one mapped chain. The sequences are
        # filtered together with their chain groups so the pairing used for
        # the initialization below cannot go out of sync.
        filtered_pairs = [
            (seq, chain_group)
            for seq, chain_group in zip(seqs, pdb_chains, strict=True)
            if len(chain_group) > 0
        ]
        LOGGER.info(
            f"Filtered to {len(filtered_pairs)} sequences mapped to at least one "
            f"PDB chain."
        )
        if not filtered_pairs:
            LOGGER.warning("All sequences were filtered out. Terminating...")
            return ChainList([])
        mapped_seqs = [seq for seq, _ in filtered_pairs]
        pdb_chains = [chain_group for _, chain_group in filtered_pairs]

        # Filter PDB IDs to X-ray structures
        pdb_ids = {x.split(":")[0] for x in chain.from_iterable(pdb_chains)}
        LOGGER.info(f"Fetching info for {len(pdb_ids)} PDB IDs.")
        xray_pdb_ids = set(
            filter_by_method(
                pdb_ids, pdb=pdb, dir_=self.cfg.pdb_dir_info, method="X-ray"
            )
        )
        LOGGER.info(
            f"Filtered to {len(xray_pdb_ids)} X-ray PDB IDs out of {len(pdb_ids)}."
        )
        pdb_chains = [
            [c for c in cs if c.split(":")[0] in xray_pdb_ids] for cs in pdb_chains
        ]

        # Fetch X-ray structures
        LOGGER.info(f"Fetching {len(xray_pdb_ids)} X-ray structures")
        pdb.fetch_structures(xray_pdb_ids, dir_=self.cfg.pdb_dir, fmt=self.cfg.pdb_fmt)

        # Init Chain objects
        seq2pdb = dict(
            _stage_chain_init(seq, c, self.cfg.pdb_dir, self.cfg.pdb_fmt)
            for seq, c in zip(mapped_seqs, pdb_chains, strict=True)
            if len(c) > 0
        )
        init = ChainInitializer(tolerate_failures=True, verbose=self.cfg.verbose)
        chains: ChainList[Chain] = ChainList(
            init.from_mapping(
                seq2pdb,
                val_callbacks=[_rm_solvent, curry(_filter_by_size)(cfg=self.cfg)],
                num_proc_read_str=self.cfg.init_cpus,
                num_proc_map_numbering=self.cfg.init_map_numbering_cpus,
                add_to_children=True,
            )
        ).filter(lambda c: len(c.structures) > 0)
        LOGGER.info(f"Initialized {len(chains)} `Chain` objects.")

        # Filter domain structures by size and canonical sequence match
        num_init = ilen(chains.collapse_children().iter_structures())
        chains = chains.apply(
            lambda c: c.apply_children(filter_domain_str_by_canon_seq_match)
        )
        num_curr = ilen(chains.collapse_children().iter_structures())
        LOGGER.info(
            f"Filtered to {num_curr} out of {num_init} domain structures "
            f"having >={self.cfg.pk_min_str_domain_size} extracted domain size "
            f"and >={self.cfg.pk_min_str_seq_match} canonical seq match fraction."
        )

        # Drop domains left without structures
        num_init = len(chains.collapse_children())
        chains = chains.apply(
            lambda c: c.filter_children(lambda x: len(x.structures) > 0)
        )
        num_curr = len(chains.collapse_children())
        LOGGER.info(
            f"Filtered to {num_curr} out of {num_init} domains with "
            "at least one valid structures."
        )

        # Drop chains left without domains
        num_init = len(chains)
        chains = chains.filter(lambda c: len(c.children) > 0)
        LOGGER.info(
            f"Filtered to {len(chains)} chains out of {num_init} "
            "with at least one extracted domains."
        )

        for c in chains.collapse_children():
            c.transfer_seq_mapping(self.cfg.pk_map_name)

        self._chains = chains

        return chains

    def save(
        self,
        dest: Path | None = None,
        chains: abc.Iterable[Chain] | None = None,
        *,
        overwrite: bool = False,
        summary: bool = True,
    ) -> None:
        """
        Save DB sequence to file system.

        :param dest: Destination path to write seqs into. If ``None``,
            ``cfg.target_dir`` is used.
        :param chains: Manual chains input to save. If ``None`` (or empty),
            will use :attr:`chains`.
        :param overwrite: Overwrite existing data in ``dest``.
        :param summary: Compose and save summaries of the saved chains
            to ``dest``.
        :return: Nothing. The chains are written before the method returns.
        :raises AssertionError: If ``dest`` exists and is not a directory, or
            if it holds files while ``overwrite`` is ``False``.
        """
        chains = chains or self.chains
        if not isinstance(chains, ChainList):
            # Materialize so a one-shot iterable survives both the writing
            # and the summary steps.
            chains = ChainList(chains)
        dest = dest or self.cfg.target_dir

        if dest.exists():
            assert dest.is_dir(), "Path to directory"
            if not overwrite:
                files = get_files(dest)
                assert len(files) == 0, "Existing dir is not empty"

        io = ChainIO(
            num_proc=self.cfg.io_cpus, verbose=self.cfg.verbose, tolerate_failures=False
        )
        # ``io.write`` is lazy: consuming it triggers the saving.
        consume(io.write(chains, base=dest))

        if summary:
            # Summarize the chains that were actually written, which may
            # differ from ``self.chains`` when ``chains`` was passed in.
            summary_df = chains.summary(children=True, structures=True)
            # Make sure the target exists before writing the tables.
            dest.mkdir(exist_ok=True, parents=True)
            for df, name in zip(
                _split_summary(summary_df), DumpNames.summary_file_names
            ):
                df.to_csv(dest / name, index=False)
                LOGGER.info(f"Saved summary file {name} to {dest}")

    def load(self, dump: Path | abc.Iterable[Path]) -> ChainList[Chain]:
        """
        Load prepared db.

        :param dump: Path with dumped :class:`Chain`s or an iterable over
            such paths. A single path is expanded to its direct entries.
        :return: A chain list with initialized :class:`Chain`s. The same list
            is stored in :attr:`chains`, even when it is empty.
        """
        if isinstance(dump, Path):
            dump = dump.glob("*")
        dump = list(dump)

        LOGGER.info(f"Got {len(dump)} initial paths to read")

        io = ChainIO(
            num_proc=self.cfg.io_cpus,
            verbose=self.cfg.verbose,
            tolerate_failures=True,
        )
        chain_read_it = io.read_chain(
            dump, callbacks=[chain_tree.recover], search_children=True
        )
        chains = ChainList(chain_read_it)
        chains = chains.apply(
            chain_tree.recover,
            verbose=self.cfg.verbose,
            desc="Recovering ancestry for sequences and structures",
        )

        if len(chains) > 0:
            LOGGER.info(f"Parsed {len(chains)} `Chain`s")
        else:
            LOGGER.warning(f"Found no `Chain`s in {dump}")
        self._chains = chains

        return chains

    def fetch(self):
        """
        Fetch an already prepared data collection.

        :raises NotImplementedError: Always; fetching is not implemented yet.
        """
        # TODO: implement fetching and unpacking
        raise NotImplementedError
# This module is import-only: running it as a script is an error.
if __name__ == "__main__":
    raise RuntimeError