|
|
|
|
|
import gradio as gr |
|
|
import joblib |
|
|
import numpy as np |
|
|
from collections import Counter |
|
|
from typing import List |
|
|
import os |
|
|
|
|
|
|
|
|
# The four nucleotide letters of the DNA alphabet, in the order A, T, C, G.
BASES = list("ATCG")
|
|
|
|
|
def kmer_counts(seq: str, k=3):
    """Count every overlapping substring of length ``k`` in a sequence.

    Surrounding whitespace is stripped and the text is upper-cased before
    counting, so ``"atg"`` and ``" ATG "`` produce the same counts.

    Args:
        seq: Sequence text.
        k: Window length; must be a positive integer.

    Returns:
        A ``Counter`` mapping each k-mer to its number of occurrences. It is
        empty when the cleaned sequence is shorter than ``k``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        # k == 0 would count the empty string len(seq) + 1 times and a
        # negative k would slice arbitrary fragments; reject both up front.
        raise ValueError(f"k must be a positive integer, got {k!r}")

    cleaned = seq.strip().upper()
    # range() is empty when the sequence is shorter than k, so no separate
    # length guard is needed to get an empty Counter in that case.
    return Counter(
        cleaned[start:start + k] for start in range(len(cleaned) - k + 1)
    )
|
|
|
|
|
def vectorize_single(seq: str, vocab: List[str], k=3):
    """Turn one sequence into a single-row k-mer count matrix.

    Args:
        seq: Sequence text, handed to ``kmer_counts``.
        vocab: k-mers that define the columns, in column order.
        k: k-mer length forwarded to ``kmer_counts``.

    Returns:
        A float array of shape ``(1, len(vocab))`` whose column ``j`` holds
        the number of times ``vocab[j]`` occurs in ``seq`` (0 when absent).
    """
    tallies = kmer_counts(seq, k)
    # One count per vocabulary entry; k-mers outside the vocabulary are ignored.
    row = [tallies.get(token, 0) for token in vocab]
    return np.array([row], dtype=float)
|
|
|
|
|
|
|
|
# Location of the serialized model bundle, relative to the working directory.
MODEL_PATH = "mutation_model.joblib"

if os.path.exists(MODEL_PATH):
    # The bundle unpacks into the model and the vocabulary used for prediction.
    # joblib.load unpickles the file, so only ship a model file you trust.
    model, vocab = joblib.load(MODEL_PATH)
else:
    # Stop at import time with an actionable message when the bundle is absent.
    raise FileNotFoundError(
        f"⚠️ Model file '{MODEL_PATH}' not found. "
        "Please upload 'mutation_model.joblib' along with this app."
    )
|
|
|
|
|
|
|
|
def predict_sequence(sequence: str):
    """Classify a DNA sequence with the loaded model and report the result.

    Args:
        sequence: Raw sequence text as entered by the user.

    Returns:
        A dict with a single ``"error"`` message when the input is empty or
        shorter than three characters after stripping. Otherwise a dict with
        the echoed ``"sequence"``, ``"mutation_detected"`` (``bool`` of the
        model's prediction) and ``"confidence"`` (the highest class
        probability rounded to three decimals, or ``"N/A"`` when the model
        has no ``predict_proba``).
    """
    if not sequence or len(sequence.strip()) < 3:
        return {"error": "Please enter a valid DNA sequence (≥3 bases)."}

    X = vectorize_single(sequence, vocab=vocab, k=3)
    pred = model.predict(X)[0]

    # Not every estimator exposes probabilities; None marks "unavailable".
    prob = None
    if hasattr(model, "predict_proba"):
        prob = float(model.predict_proba(X).max())

    return {
        "sequence": sequence,
        "mutation_detected": bool(pred),
        # Compare against None explicitly so that a real probability of 0.0
        # is reported as a number instead of being mistaken for "N/A".
        "confidence": round(prob, 3) if prob is not None else "N/A",
    }
|
|
|
|
|
|
|
|
# Build the web UI: a header, a sequence text box, an analyze button and a
# JSON panel that displays the prediction.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header, written as inline HTML inside the Markdown component.
    gr.Markdown(
        """
        <h1 style="text-align:center;">🧬 DNA Mutation Analyzer</h1>
        <p style="text-align:center;">
        Enter a DNA sequence to check for mutations using the ML model.
        </p>
        """
    )

    # NOTE(review): the nesting under gr.Row() could not be recovered from
    # this copy of the file; it is assumed that only the text box sits in the
    # row and the button and result panel follow below it — confirm.
    with gr.Row():
        seq_input = gr.Textbox(
            label="DNA Sequence",
            placeholder="Enter sequence like ATGCGTACGTTAGC...",
            lines=2,
        )
    analyze_btn = gr.Button("🔍 Analyze Sequence")
    result = gr.JSON(label="Analysis Result")

    # A click passes the text box content to predict_sequence and shows the
    # returned dict in the JSON panel.
    analyze_btn.click(fn=predict_sequence, inputs=seq_input, outputs=result)
|
|
|
|
|
|
|
|
def api_predict(payload: dict):
    """Run a prediction for a dict-shaped request.

    Args:
        payload: Mapping that may carry the sequence under the
            ``"sequence"`` key; a missing key is treated as an empty
            sequence.

    Returns:
        Whatever ``predict_sequence`` returns for that sequence.
    """
    return predict_sequence(payload.get("sequence", ""))
|
|
|
|
|
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link and ssr_mode=False turns
    # off server-side rendering (see the Gradio launch() documentation).
    demo.launch(share=True, ssr_mode=False)