"""
dN/dS (Ka/Ks) selection analysis between two coding sequences for ToolUniverse.
Local-compute, deterministic Nei-Gojobori (1986) estimator with Jukes-Cantor
correction — the standard way to tell positive/diversifying selection (dN/dS > 1)
from purifying selection (dN/dS << 1) and near-neutral evolution (dN/dS ~ 1).
Pure Python (no dependencies), no network, no API key.
General by construction: it takes any two in-frame, codon-aligned coding
sequences (inline or single-record FASTA files) — any organisms, any genes. The
method (per-codon synonymous/non-synonymous site counting + pathway-averaged
differences + JC69 correction) is ported verbatim from the comparative-genomics
skill's validated implementation.
"""
import itertools
import math
import os
from typing import Any, Dict, Optional
from .base_tool import BaseTool
from .tool_registry import register_tool
_BASES = "TCAG"
_CODONS = [a + b + c for a in _BASES for b in _BASES for c in _BASES]
# Standard genetic code (NCBI table 1).
_AA = "FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG"
_CODON_TABLE = dict(zip(_CODONS, _AA))
def _err(msg: str) -> Dict[str, Any]:
return {"status": "error", "error": msg}
def _ok(data: Dict[str, Any], **metadata) -> Dict[str, Any]:
meta = {"engine": "nei_gojobori_1986", "correction": "jukes_cantor"}
meta.update(metadata)
return {"status": "success", "data": data, "metadata": meta}
def _syn_nonsyn_sites(codon: str):
    """Return ``(s, n)``, the synonymous / non-synonymous site counts of *codon*.

    Each of the three positions contributes (number of synonymous single-base
    substitutions at that position) / 3 to ``s``; ``n`` is the remainder, so
    ``s + n == 3``. Substitutions that create a stop codon are never
    synonymous and therefore land in ``n``. Stop codons and strings not in the
    codon table have no sites and give ``(0.0, 0.0)``.
    """
    residue = _CODON_TABLE.get(codon)
    if residue is None or residue == "*":
        return 0.0, 0.0
    synonymous = 0.0
    # codon is a table key here, so it is exactly three bases long.
    for pos, base_here in enumerate(codon):
        head, tail = codon[:pos], codon[pos + 1 :]
        hits = sum(
            1
            for alt in _BASES
            if alt != base_here and _CODON_TABLE.get(head + alt + tail) == residue
        )
        synonymous += hits / 3.0
    return synonymous, 3.0 - synonymous
def _path_diffs(c1: str, c2: str):
    """Average synonymous / non-synonymous differences between two codons.

    Every ordering of the differing positions is one shortest mutational
    pathway from *c1* to *c2* (Nei-Gojobori). Each step is synonymous when the
    amino acid is unchanged. Pathways that pass through a stop codon or a
    codon missing from the table are discarded. Returns ``(sd, nd)`` averaged
    over the surviving pathways. If none survive, every difference is counted
    as non-synonymous.
    """
    differing = [k for k in range(3) if c1[k] != c2[k]]
    if not differing:
        return 0.0, 0.0

    def walk(order):
        # (syn, nonsyn) step counts along one pathway, or None if it is invalid.
        syn = non = 0.0
        here = c1
        for k in order:
            there = here[:k] + c2[k] + here[k + 1 :]
            aa_here = _CODON_TABLE.get(here)
            aa_there = _CODON_TABLE.get(there)
            if aa_here in (None, "*") or aa_there in (None, "*"):
                return None
            if aa_here == aa_there:
                syn += 1
            else:
                non += 1
            here = there
        return syn, non

    tallies = [t for t in map(walk, itertools.permutations(differing)) if t is not None]
    if not tallies:
        return 0.0, float(len(differing))
    count = len(tallies)
    return sum(t[0] for t in tallies) / count, sum(t[1] for t in tallies) / count
def _jukes_cantor(p: float) -> Optional[float]:
"""JC69 correction; returns None if uncorrectable (p too large)."""
if p < 0:
return 0.0
val = 1.0 - (4.0 / 3.0) * p
if val <= 0:
return None
return -0.75 * math.log(val)
def _compute_dnds(seq1: str, seq2: str) -> Any:
    """Nei-Gojobori (1986) dN/dS between two codon-aligned coding sequences.

    Both sequences have all whitespace removed, are upper-cased, and have RNA
    ``U`` mapped to ``T``. Only the shared prefix (``min`` of the two lengths,
    trimmed down to whole codons) is compared.

    A codon pair is skipped unless BOTH codons are sense codons. Gapped ('-')
    and ambiguous (e.g. 'N') codons are not in the codon table, and stop codons
    have zero sites. Without the skip, ``_path_diffs`` would book their
    differences as non-synonymous while they add no N sites, inflating dN.
    A common example is two CDSs ending in TAA vs TAG.

    Returns a result dict with keys ``dN_dS``, ``dN``, ``dS``, ``pN``, ``pS``,
    ``N_sites``, ``S_sites``, ``Nd``, ``Sd``, ``codons_compared`` and
    ``interpretation``, or an error dict built by ``_err``.
    """
    seq1 = "".join(seq1.split()).upper().replace("U", "T")
    seq2 = "".join(seq2.split()).upper().replace("U", "T")
    n = min(len(seq1), len(seq2))
    n -= n % 3
    if n == 0:
        return _err(
            "sequences too short or not codon-length (need >= 3 aligned bases)."
        )
    S = N = Sd = Nd = 0.0
    compared = 0
    for i in range(0, n, 3):
        c1, c2 = seq1[i : i + 3], seq2[i : i + 3]
        # Only sense-vs-sense pairs are comparable. get() returns None for
        # gapped/ambiguous codons and "*" for stops.
        if _CODON_TABLE.get(c1) in (None, "*") or _CODON_TABLE.get(c2) in (None, "*"):
            continue
        s1, n1 = _syn_nonsyn_sites(c1)
        s2, n2 = _syn_nonsyn_sites(c2)
        # Site counts are averaged over the two sequences.
        S += (s1 + s2) / 2.0
        N += (n1 + n2) / 2.0
        sd, nd = _path_diffs(c1, c2)
        Sd += sd
        Nd += nd
        compared += 1
    if compared == 0:
        return _err("no comparable codons (all gapped, ambiguous, or stop codons).")

    def _z(x):  # normalise -0.0 -> 0.0 (JC of p == 0 evaluates to -0.0)
        return 0.0 if (x is not None and x == 0) else x

    pS = _z(Sd / S if S else 0.0)
    pN = _z(Nd / N if N else 0.0)
    dS = _z(_jukes_cantor(pS))
    dN = _z(_jukes_cantor(pN))
    # omega is undefined when dS is 0 or either distance is saturated (None).
    omega = (dN / dS) if (dN is not None and dS not in (None, 0.0)) else None
    interp = "undetermined (dS is 0 or saturated)"
    if omega is not None:
        if omega > 1.25:
            interp = "positive (diversifying) selection (dN/dS > 1)"
        elif omega < 0.5:
            interp = "purifying selection / functional constraint (dN/dS << 1)"
        else:
            interp = "near-neutral / relaxed selection (dN/dS ~ 1)"
    return {
        "dN_dS": None if omega is None else round(omega, 4),
        "dN": None if dN is None else round(dN, 4),
        "dS": None if dS is None else round(dS, 4),
        "pN": round(pN, 4),
        "pS": round(pS, 4),
        "N_sites": round(N, 2),
        "S_sites": round(S, 2),
        "Nd": round(Nd, 2),
        "Sd": round(Sd, 2),
        "codons_compared": compared,
        "interpretation": interp,
    }
def _read_fasta(path: str) -> Any:
"""Read the single sequence from a FASTA file (concatenating record lines)."""
path = os.path.expanduser(str(path).strip())
if not os.path.isfile(path):
return _err(f"FASTA not found: {path}")
try:
seq = []
with open(path) as fh:
for line in fh:
line = line.strip()
if line and not line.startswith(">"):
seq.append(line)
if not seq:
return _err(f"no sequence found in {path}")
return "".join(seq)
except Exception as e: # pragma: no cover - defensive
return _err(f"failed to read FASTA {path}: {e}")