"""
The :class:`DB` class for creating, saving, and loading the PK data collection.
"""
import json
import logging
import operator as op
import typing as t
from collections import abc
from io import StringIO
from itertools import chain
from pathlib import Path
from random import sample
import pandas as pd
from lXtractor.chain import (
Chain,
ChainIO,
ChainInitializer,
ChainList,
ChainSequence,
ChainStructure,
recover,
)
from lXtractor.core.config import DefaultConfig
from lXtractor.core.segment import resolve_overlaps
from lXtractor.ext import PDB, PyHMMer, SIFTS, fetch_uniprot, filter_by_method
from lXtractor.util import get_files, read_fasta, write_fasta
from more_itertools import ilen, consume, unzip
from toolz import curry, groupby, itemmap, keyfilter, keymap
from tqdm.auto import tqdm
from kinactive.base import TK_PROFILE_PATH, PK_PROFILE_PATH
from kinactive.config import DBConfig, DumpNames
# Generic type variable; lets `_stage_chain_init` return its `seq` argument
# with the same type it received.
T = t.TypeVar("T")
# Union of the lXtractor chain-like types accepted by `DB.discover_domains`.
CT_: t.TypeAlias = Chain | ChainSequence | ChainStructure
# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)
# Change primary polymer type to the expected protein.
# NOTE: this mutates the shared lXtractor `DefaultConfig` at import time.
DefaultConfig["structure"]["primary_pol_type"] = "p"
# TODO: some object IDs are duplicated (see the examples below).
# This stems from chimeric sequences. However, such sequences should not
# pass the filtering step where canonical and structure sequences are
# compared.
# ['ChainStructure(PK_1|10-260<-(5UFU:A|1-375))',
# 'ChainStructure(PK_1|136-354<-(7APJ:A|1-385))',
# 'ChainStructure(PK_1|38-323<-(3U87:A|1-334))',
# 'ChainStructure(PK_1|38-323<-(3U87:B|1-334))']
# TODO: include UniProt metadata
# TODO: an option to patch PDB sequences
def _get_remaining(names: abc.Iterable[str], dir_: Path) -> set[str]:
    """
    Find names that have no corresponding file in a directory.

    :param names: Names to check.
    :param dir_: Directory whose files are listed via ``get_files``.
    :return: The unique ``names`` that do not equal the stem of any listed
        file.
    """
    present_stems = {path.stem for path in get_files(dir_).values()}
    return {name for name in names if name not in present_stems}
def _is_sequence_of_chain_seqs(
    s: abc.Sequence[t.Any],
) -> t.TypeGuard[abc.Sequence[ChainSequence]]:
    """
    Type guard verifying that every element is a :class:`ChainSequence`.

    :param s: A sequence of arbitrary objects.
    :return: ``True`` if all elements are :class:`ChainSequence` instances
        (vacuously ``True`` for an empty sequence), ``False`` otherwise.
    """
    for item in s:
        if not isinstance(item, ChainSequence):
            return False
    return True
def _stage_chain_init(
    seq: T, pdb_chains: abc.Iterable[str], pdb_dir: Path, fmt: str
) -> tuple[T, list[tuple[Path, list[str]]]]:
    """
    Pair a sequence with the structure files and chain IDs mapped to it.

    :param seq: An object passed through unchanged as the first element of
        the result.
    :param pdb_chains: Identifiers in the "{PDB_ID}:{ChainID}" format.
    :param pdb_dir: Directory holding the structure files.
    :param fmt: Structure file extension (without the leading dot).
    :return: A tuple of ``seq`` and a list of ``(path, chain_ids)`` pairs, one
        per distinct PDB ID in order of first appearance.
    :raises AssertionError: If any of the composed paths does not exist.
    """
    chains_by_id: dict[str, list[str]] = {}
    for entry in pdb_chains:
        parts = entry.split(":")
        # parts[0] is the PDB ID, parts[1] is the chain ID.
        chains_by_id.setdefault(parts[0], []).append(parts[1])

    chains_by_path = {
        pdb_dir / f"{pdb_id}.{fmt}": chain_ids
        for pdb_id, chain_ids in chains_by_id.items()
    }
    # Every referenced structure file must already be on disk.
    assert all(path.exists() for path in chains_by_path)
    return seq, list(chains_by_path.items())
def _filter_by_size(
    structures: list[ChainStructure], cfg: DBConfig
) -> list[ChainStructure]:
    """
    Keep structures whose sequence is long enough.

    :param structures: Structures to filter.
    :param cfg: Config providing the ``pdb_str_min_size`` threshold.
    :return: A new list with structures whose ``seq`` length is at least
        ``cfg.pdb_str_min_size``; the input order is preserved.
    """
    kept = []
    for structure in structures:
        if len(structure.seq) >= cfg.pdb_str_min_size:
            kept.append(structure)
    return kept
def _rm_solvent(structures: list[ChainStructure]) -> list[ChainStructure]:
    """
    Call ``rm_solvent()`` on each structure.

    :param structures: Structures to process.
    :return: A new list holding the result of ``rm_solvent()`` for each input
        structure, in the same order.
    """
    desolvated = []
    for structure in structures:
        desolvated.append(structure.rm_solvent())
    return desolvated
def _drop_all_na(df: pd.DataFrame) -> pd.DataFrame:
    """
    Remove columns consisting entirely of missing values.

    :param df: A table to clean up. It is not modified.
    :return: A new table without the columns where every value is NA.
    """
    all_missing = df.isna().all(axis=0)
    empty_columns = [column for column, flag in all_missing.items() if flag]
    return df.drop(columns=empty_columns)
def _split_summary(df: pd.DataFrame) -> tuple[pd.DataFrame, ...]:
    """
    Split a summary table into four parts by parent presence and structure flag.

    The parts come in the following order:

    #. ``ParentID`` is NA and ``Structure`` is not ``True``.
    #. ``ParentID`` is NA and ``Structure`` is ``True``.
    #. ``ParentID`` is not NA and ``Structure`` is not ``True``.
    #. ``ParentID`` is not NA and ``Structure`` is ``True``.

    :param df: A summary table with the ``ParentID`` and ``Structure`` columns.
    :return: Four tables, each stripped of the columns holding only NA values.
    """
    no_parent = df.ParentID.isna()
    # Element-wise comparison is intended: NA and other non-True values
    # must evaluate to False.
    is_structure = df.Structure == True  # noqa: E712
    masks = (
        no_parent & ~is_structure,
        no_parent & is_structure,
        ~no_parent & ~is_structure,
        ~no_parent & is_structure,
    )
    return tuple(_drop_all_na(df[mask]) for mask in masks)
[docs]
class DB:
    """
    An object encapsulating methods for building/saving/loading an lXtractor
    "database" -- a collection of :class:`Chain`'s.
    """

    def __init__(self, cfg: DBConfig | None = None):
        """
        :param cfg: The database config. If ``None``, a fresh default
            :class:`DBConfig` is created for this instance.
        """
        # Creating the default inside the body avoids sharing a single config
        # object (evaluated once at definition time) between all instances.
        self.cfg = DBConfig() if cfg is None else cfg
        # Lazily initialized helpers, see the `_load_*` methods.
        self._sifts: SIFTS | None = None
        self._pdb: PDB | None = None
        self._pk_hmm: PyHMMer | None = None
        self._chains: ChainList[Chain] = ChainList([])

    @property
    def chains(self) -> ChainList[Chain]:
        """
        :return: Currently stored chains.
        """
        return self._chains

    def _load_sifts(self, overwrite: bool = False) -> SIFTS:
        """
        Get the SIFTS interface, initializing and caching it when needed.

        :param overwrite: Re-create the object even if a cached one exists.
        :return: A SIFTS object with the ID mapping loaded.
        """
        if self._sifts is None or overwrite:
            sifts = SIFTS(load_segments=False, load_id_mapping=True)
            if sifts.id_mapping is None:
                LOGGER.info("Initializing SIFTS for the first time.")
                sifts.fetch()
                sifts.parse()
            # Cache the object so that later calls reuse it.
            self._sifts = sifts
        return self._sifts

    def _load_pdb(self, overwrite: bool = False) -> PDB:
        """
        Get the PDB interface, creating and caching it when needed.

        :param overwrite: Re-create the object even if a cached one exists.
        :return: A PDB interface configured from :attr:`cfg`.
        """
        if self._pdb is None or overwrite:
            self._pdb = PDB(
                self.cfg.max_fetch_trials,
                self.cfg.pdb_num_fetch_threads,
                self.cfg.verbose,
            )
        return self._pdb

    def _load_pk_hmm(self, overwrite: bool = False) -> PyHMMer:
        """
        Get the HMM interface for ``cfg.profile``, creating and caching it
        when needed.

        :param overwrite: Re-create the object even if a cached one exists.
        :return: A :class:`PyHMMer` object.
        """
        if self._pk_hmm is None or overwrite:
            self._pk_hmm = PyHMMer(self.cfg.profile)
        return self._pk_hmm

    def _load_tk2pk(self) -> dict[int, int]:
        """
        Load the JSON mapping stored under ``cfg.tk2pk``.

        :return: The loaded mapping with keys converted to ``int``.
        """
        with self.cfg.tk2pk.open() as f:
            return keymap(int, json.load(f))

    def _fetch_seqs(self, ids: abc.Iterable[str]):
        """
        Fetch UniProt sequences and save each one into a separate fasta file
        under ``cfg.seq_dir``.

        :param ids: UniProt IDs to fetch.
        """
        raw_seqs = fetch_uniprot(
            ids,
            num_threads=self.cfg.uniprot_num_fetch_threads,
            chunk_size=self.cfg.uniprot_chunk_size,
            verbose=self.cfg.verbose,
        )
        parsed_seqs: abc.Iterable[tuple[str, str]] = read_fasta(StringIO(raw_seqs))
        if self.cfg.verbose:
            parsed_seqs = tqdm(parsed_seqs, desc="Saving fetched sequences")
        for header, seq in parsed_seqs:
            # The file name is the second "|"-separated field of the header.
            id_ = header.split("|")[1]
            write_fasta([(header, seq)], self.cfg.seq_dir / f"{id_}.fasta")

    def _read_seqs(self, ids: abc.Iterable[str]) -> ChainList[Chain]:
        """
        Read previously saved sequences from ``cfg.seq_dir``.

        :param ids: IDs to read. IDs lacking a "{id}.fasta" file are skipped.
        :return: A chain list with one :class:`Chain` per sequence read.
        """
        files = keymap(lambda x: x.removesuffix(".fasta"), get_files(self.cfg.seq_dir))
        matching = set(files) & set(ids)
        paths = keyfilter(lambda x: x in matching, files).values()
        init = ChainInitializer(verbose=self.cfg.verbose)
        chains = list(init.from_iterable(paths))
        assert _is_sequence_of_chain_seqs(chains), "correct types are returned"
        return ChainList(map(Chain, chains))

    def _get_sifts_xray(self) -> list[str]:
        """
        :return: The SIFTS PDB IDs passing ``filter_by_method`` with
            ``method="X-ray"``.
        """
        sifts = self._load_sifts()
        pdb = self._load_pdb()
        return filter_by_method(sifts.pdb_ids, pdb=pdb, method="X-ray")

    def obtain_sifts_seqs(
        self, uniprot_ids: abc.Sequence[str] | None = None
    ) -> ChainList[Chain]:
        """
        Fetch (if needed), read, and size-filter UniProt sequences from SIFTS.

        :param uniprot_ids: Optional reference IDs to restrict the SIFTS
            UniProt IDs to. If empty or ``None``, all SIFTS IDs are used.
        :return: Chains whose sequence length lies within
            ``[cfg.min_seq_size, cfg.max_seq_size]``.
        """
        sifts = self._load_sifts()
        if uniprot_ids:
            # A set gives O(1) membership checks while filtering.
            reference_ids = set(uniprot_ids)
            ids = [x for x in sifts.uniprot_ids if x in reference_ids]
            LOGGER.info(
                f"Filtered to {len(ids)} out of {len(sifts.uniprot_ids)} initial IDs "
                f"contained in SIFTS using {len(uniprot_ids)} reference IDs."
            )
            missing = reference_ids - set(ids)
            if missing:
                LOGGER.warning(f"{len(missing)} IDs were missing in SIFTS: {missing}")
        else:
            ids = sifts.uniprot_ids

        fetch_ids = _get_remaining(ids, self.cfg.seq_dir)
        LOGGER.info(f"{len(fetch_ids)} remaining sequences to fetch.")
        if fetch_ids:
            # Skip the fetching call entirely when there is nothing to fetch.
            self._fetch_seqs(fetch_ids)

        # Read
        seqs = self._read_seqs(ids)
        LOGGER.info(f"Got {len(seqs)} seqs from {self.cfg.seq_dir}")

        # Filter sequences by size
        min_size, max_size = self.cfg.min_seq_size, self.cfg.max_seq_size
        seqs = seqs.filter(lambda s: min_size <= len(s.seq) <= max_size)
        LOGGER.info(f"Filtered to {len(seqs)} seqs in [{min_size}, {max_size}]")
        return seqs

    def discover_domains(self, seqs: ChainList[CT_]) -> ChainList[CT_]:
        """
        Annotate sequences with the PK and TK profiles and filter the hits.

        Child domains are created by both profiles. For the "TK" children,
        a "PK" map is derived from the "TK" map via the ``cfg.tk2pk`` mapping.
        Children are then filtered by size and coverage thresholds from
        :attr:`cfg`, and overlapping ones are resolved using the "score" meta
        field.

        :param seqs: Chain-like objects to annotate. Their children are
            modified in place.
        :return: The objects retaining at least one child domain.
        """

        def transfer_pk_map(cs: CT_) -> CT_:
            # Derive the "PK" map of each TK child from its "TK" map.
            tk_children = cs.children.filter(lambda x: "TK" in x.name)
            for c in tk_children:
                tk_df = c.seq.as_df()
                tk_df["PK"] = tk_df["TK"].map(tk2pk)
                c.seq["PK"] = tk_df["PK"].tolist()
            return cs

        @curry
        def get_field(seq: ChainSequence, contains: str) -> t.Any:
            # Value of the first meta field whose name contains the substring.
            key = next(filter(lambda x: contains in x, seq.meta))
            return seq.meta[key]

        def filter_domains(cs: CT_) -> CT_:
            children = cs.children.filter(
                lambda x: len(x.seq) >= self.cfg.pk_min_seq_domain_size
                and float(get_field(x.seq, "cov_hmm")) >= self.cfg.pk_min_cov_hmm
                and float(get_field(x.seq, "cov_seq")) >= self.cfg.pk_min_cov_seq
            )
            if len(children) == 0:
                cs.children = ChainList([])
                return cs
            non_overlapping = resolve_overlaps(
                children.sequences, value_fn=get_field(contains="score")
            )
            non_overlapping_ids = {s.id for s in non_overlapping}
            cs.children = children.filter(lambda x: x.seq.id in non_overlapping_ids)
            return cs

        tk2pk = self._load_tk2pk()
        tk_prof = PyHMMer(TK_PROFILE_PATH)
        pk_prof = PyHMMer(PK_PROFILE_PATH)

        LOGGER.info("Annotating domains")
        # `annotate` results are consumed only to trigger the annotation.
        consume(
            pk_prof.annotate(
                seqs,
                min_size=50,
                min_score=self.cfg.pk_min_score,
                new_map_name="PK",
            )
        )
        consume(
            tk_prof.annotate(
                seqs,
                min_size=50,
                min_score=self.cfg.pk_min_score,
                new_map_name="TK",
            )
        )
        seqs = seqs.filter(lambda x: len(x.children) > 0)
        LOGGER.info(f"Discovered {len(seqs)} sequences with domain hits")

        tk_hits = seqs.collapse_children().filter(lambda x: "TK" in x.name)
        pk_hits = seqs.collapse_children().filter(lambda x: "PK" in x.name)
        LOGGER.info(f"Initial TK hits: {len(tk_hits)}")
        LOGGER.info(f"Initial PK hits: {len(pk_hits)}")

        LOGGER.info("Transferring PK profile maps to TK hits")
        seqs = (
            seqs.apply(transfer_pk_map)
            .apply(filter_domains)
            .filter(lambda x: len(x.children) > 0)
        )
        LOGGER.info(
            f"Filtered to {len(seqs)} sequences with at least one valid "
            f"domain with conforming to config criteria."
        )

        tk_hits = seqs.collapse_children().filter(lambda x: "TK" in x.name)
        pk_hits = seqs.collapse_children().filter(lambda x: "PK" in x.name)
        LOGGER.info(f"Final TK hits: {len(tk_hits)}")
        LOGGER.info(f"Final PK hits: {len(pk_hits)}")
        return seqs

    def build(
        self,
        uniprot_ids: abc.Collection[str] | None = None,
        pdb_chain_ids: abc.Collection[str] | None = None,
        n_domains: int = 0,
    ) -> ChainList[Chain]:
        """
        Build a new lXt-PK data collection.

        :param uniprot_ids: An optional list of UniProt IDs to restrict
            the db to.
        :param pdb_chain_ids: An optional collection of PDB chains to restrict
            the db to. Format: "{PDB_ID}:{ChainID}".
        :param n_domains: Use n random sequence domains. It is helpful for
            testing the pipeline. Values exceeding the number of available
            sequences are clamped to that number.
        :return: A :class:`ChainList` of :class:`Chain` objects having at least
            one child PK domain with at least one PK domain structure passing
            the filtering thresholds. An empty list is returned if no sequence
            is mapped to a PDB chain; :attr:`chains` is left untouched in
            this case.
        """

        def match_seq(s: ChainStructure) -> ChainStructure:
            s.seq.match("seq1", "seq1_canonical", as_fraction=True, save=True)
            return s

        def accept_domain_structure(c: Chain) -> Chain:
            c.transfer_seq_mapping(
                DefaultConfig["mapnames"]["seq1"], map_name_in_other="seq1_canonical"
            )
            match_name = "Match_seq1_seq1_canonical"
            c = c.apply_structures(match_seq).filter_structures(
                lambda s: (
                    len(s.seq) >= self.cfg.pk_min_str_domain_size
                    and s.seq.meta[match_name] >= self.cfg.pk_min_str_seq_match
                )
            )
            return c

        def filter_structures(c: Chain) -> Chain:
            # Keep only the structures some child structure descends from.
            parent_ids = {x.parent.id for x in c.children.structures}
            c.structures = c.structures.filter(lambda x: x.id in parent_ids)
            return c

        # 0. Init directories
        for dir_ in [
            self.cfg.target_dir,
            self.cfg.pdb_dir,
            self.cfg.seq_dir,
            self.cfg.pdb_dir_info,
        ]:
            if dir_ is not None:
                dir_.mkdir(exist_ok=True, parents=True)

        sifts = self._load_sifts()
        pdb = self._load_pdb()

        # Fetch SIFTS UniProt seqs
        seqs = self.obtain_sifts_seqs(uniprot_ids)

        # Annotate PK domains and filter seqs to the annotated ones
        seqs = self.discover_domains(seqs)

        if n_domains:
            # Clamp the sample size: `sample` raises if it exceeds population.
            seqs = ChainList(sample(seqs, min(n_domains, len(seqs))))
            LOGGER.info(f"Sampled to {len(seqs)} random initial domains.")

        # Get UniProt IDs and corresponding PDB Chains
        uni2seq = {s.id.split("|")[1]: s for s in seqs}
        # Collect IDs and their chains together so the two lists stay aligned
        # even when some ID has no SIFTS mapping.
        uni_ids: list[str] = []
        pdb_chains = []
        for uni_id in uni2seq:
            mapped = sifts.map_id(uni_id)
            if mapped is not None:
                uni_ids.append(uni_id)
                pdb_chains.append(mapped)

        # Filter PDB IDs to provided list
        if pdb_chain_ids:
            allowed_chains = set(pdb_chain_ids)
            num_init = ilen(chain.from_iterable(pdb_chains))
            pdb_chains = [
                [x for x in chain_group if x in allowed_chains]
                for chain_group in pdb_chains
            ]
            LOGGER.info(
                f"Filtered to {ilen(chain.from_iterable(pdb_chains))} out "
                f"of {num_init} initially mapped PDB chains "
                f"({len(pdb_chain_ids)} reference IDs were provided for filtering)."
            )

        filtered_pairs = list(
            filter(lambda x: len(x[1]) > 0, zip(uni_ids, pdb_chains, strict=True))
        )
        LOGGER.info(
            f"Filtered to {len(filtered_pairs)} sequences mapped to at least one "
            f"PDB chain."
        )
        if not filtered_pairs:
            LOGGER.warning("All sequences were filtered out. Terminating...")
            return ChainList([])

        uni_ids, pdb_chains = map(list, unzip(filtered_pairs))

        # Filter PDB IDs to X-ray structures
        pdb_ids = {x.split(":")[0] for x in chain.from_iterable(pdb_chains)}
        LOGGER.info(f"Fetching info for {len(pdb_ids)} PDB IDs.")
        xray_pdb_ids = set(
            filter_by_method(
                pdb_ids, pdb=pdb, dir_=self.cfg.pdb_dir_info, method="X-ray"
            )
        )
        LOGGER.info(
            f"Filtered to {len(xray_pdb_ids)} X-ray PDB IDs out of {len(pdb_ids)}."
        )
        pdb_chains = [
            [c for c in cs if c.split(":")[0] in xray_pdb_ids] for cs in pdb_chains
        ]

        # Fetch X-ray structures
        LOGGER.info(f"Fetching {len(xray_pdb_ids)} X-ray structures")
        pdb.fetch_structures(xray_pdb_ids, dir_=self.cfg.pdb_dir, fmt=self.cfg.pdb_fmt)

        # Init Chain objects
        seq2pdb = dict(
            _stage_chain_init(
                uni2seq[seq_id], str_ids, self.cfg.pdb_dir, self.cfg.pdb_fmt
            )
            for seq_id, str_ids in zip(uni_ids, pdb_chains)
            if len(str_ids) > 0
        )
        init = ChainInitializer(
            tolerate_failures=self.cfg.init_tolerate_failures, verbose=self.cfg.verbose
        )
        chains: ChainList[Chain] = ChainList(
            init.from_mapping(
                seq2pdb,
                val_callbacks=[_rm_solvent, curry(_filter_by_size)(cfg=self.cfg)],
                num_proc_read_str=self.cfg.init_cpus,
                num_proc_map_numbering=self.cfg.init_map_numbering_cpus,
                num_proc_add_structure=self.cfg.init_add_structure_cpus,
                add_to_children=True,
            )
        ).filter(lambda c: len(c.structures) > 0)
        LOGGER.info(f"Initialized {len(chains)} `Chain` objects.")

        # Filter domain structures by size and canonical seq match
        num_init = len(chains.collapse_children().structures)
        chains = chains.apply(lambda c: c.apply_children(accept_domain_structure))
        chains = chains.apply(filter_structures)
        num_curr = len(chains.collapse_children().structures)
        LOGGER.info(
            f"Filtered to {num_curr} out of {num_init} domain structures "
            f"having >={self.cfg.pk_min_str_domain_size} extracted domain size "
            f"and >={self.cfg.pk_min_str_seq_match} canonical seq match fraction."
        )

        # Filter domains to those with at least one structure
        num_init = len(chains.collapse_children())
        chains = chains.apply(
            lambda c: c.filter_children(lambda x: len(x.structures) > 0)
        )
        num_curr = len(chains.collapse_children())
        LOGGER.info(
            f"Filtered to {num_curr} out of {num_init} domains with "
            "at least one valid structure."
        )

        # Filter chains to those with at least one domain
        num_init = len(chains)
        chains = chains.filter(lambda c: len(c.children) > 0)
        LOGGER.info(
            f"Filtered to {len(chains)} chains out of {num_init} "
            "with at least one extracted domains."
        )

        for c in chains.collapse_children():
            c.transfer_seq_mapping(self.cfg.pk_map_name)

        self._chains = chains
        return chains

    def save(
        self,
        dest: Path | None = None,
        chains: abc.Iterable[Chain] | None = None,
        *,
        overwrite: bool = False,
        summary: bool = True,
    ) -> None:
        """
        Save DB sequence to file system.

        :param dest: Destination path to write seqs into. If ``None``, will
            use ``cfg.target_dir``.
        :param chains: Manual chains input to save. If ``None`` (or empty),
            will use :attr:`chains`.
        :param overwrite: Overwrite existing data in ``dest``.
        :param summary: Compose and save summaries of the saved chains to
            ``dest``.
        :return: Nothing.
        :raises AssertionError: If ``dest`` exists and is not a directory, or
            if it is a non-empty directory and ``overwrite`` is ``False``.
        """
        chains = chains or self.chains
        if not isinstance(chains, ChainList):
            # Materialize an arbitrary iterable: it is used for both writing
            # and summarizing.
            chains = ChainList(chains)
        dest = dest or self.cfg.target_dir

        if dest.exists():
            assert dest.is_dir(), "Path to directory"
            if not overwrite:
                files = get_files(dest)
                assert len(files) == 0, "Existing dir is not empty"
        # NOTE(review): `dest` is not created here; presumably `ChainIO.write`
        # creates it -- confirm.

        io = ChainIO(
            num_proc=self.cfg.io_cpus, verbose=self.cfg.verbose, tolerate_failures=False
        )
        consume(io.write(chains, base=dest, str_fmt=self.cfg.pdb_fmt))

        if summary:
            # Summarize what was actually saved, not the stored chains.
            summary_df = chains.summary(children=True, structures=True)
            for df, name in zip(
                _split_summary(summary_df), DumpNames.summary_file_names
            ):
                df.to_csv(dest / name, index=False)
                LOGGER.info(f"Saved summary file {name} to {dest}")

    @staticmethod
    def _construct_paths(
        paths: abc.Iterable[Path],
        domains: bool,
        structures: bool,
    ) -> abc.Iterable[Path]:
        """
        Expand dump paths to paths of nested objects.

        :param paths: Initial paths.
        :param domains: Replace each path with the entries of its "segments"
            subdirectory.
        :param structures: Replace each path with the entries of its
            "structures" subdirectory. Applied after ``domains``.
        :return: An iterable over the resulting paths. It is lazy if any
            expansion was requested; otherwise, it is ``paths`` itself.
        """
        if domains:
            paths = chain.from_iterable(p.glob("segments/*") for p in paths)
        if structures:
            paths = chain.from_iterable(p.glob("structures/*") for p in paths)
        return paths

    def load(
        self,
        dump: Path | abc.Iterable[Path],
        domains: bool = True,
        sequences: bool = False,
        structures: bool = False,
        structures_sequences: bool = False,
    ) -> ChainList[Chain] | ChainList[ChainStructure] | ChainList[ChainSequence]:
        """
        Load prepared db.

        The loaded chains replace the currently stored :attr:`chains`, even
        if nothing was found.

        :param dump: Path with dumped :class:`Chain`s or an iterable over
            paths to such dumps.
        :param domains: Load domains without loading parent chains.
        :param sequences: Load only canonical sequences.
        :param structures: Load structures without loading canonical sequences.
            Takes precedence over ``sequences`` and ``structures_sequences``
            when choosing the loader.
        :param structures_sequences: Load structure sequences without loading
            structures.
        :return: A chain list with initialized :class:`Chain`s.
        """
        if isinstance(dump, Path):
            dump = dump.glob("*")
        dump = list(
            self._construct_paths(dump, domains, structures or structures_sequences)
        )
        LOGGER.info(f"Got {len(dump)} initial paths to read")

        io = ChainIO(self.cfg.io_cpus, self.cfg.verbose)
        if structures:
            loader = io.read_chain_str
        elif sequences or structures_sequences:
            loader = io.read_chain_seq
        else:
            loader = io.read_chain

        chains = ChainList(
            loader(dump, callbacks=[recover], search_children=not domains)
        )

        if len(chains) > 0:
            LOGGER.info(f"Parsed {len(chains)} `Chain`s")
        else:
            LOGGER.warning(f"Found no `Chain`s in {dump}")
        self._chains = chains
        return chains
# This module is meant to be imported: running it as a script raises
# immediately.
if __name__ == "__main__":
    raise RuntimeError