Skip to content

Latest commit

 

History

History
645 lines (523 loc) · 19.8 KB

File metadata and controls

645 lines (523 loc) · 19.8 KB

Code-Ready Starting Point

# Step 1: Install core dependencies
"""
pip install langchain langchain-openai langgraph openai anthropic
pip install transformers torch accelerate
pip install neo4j redis fastapi
pip install scanpy anndata  # for omics data
"""

# Step 2: Basic Orchestrator (5 minutes to set up)
from langgraph.graph import StateGraph, END
from langchain_openai import ChatOpenAI
from typing import TypedDict, List, Dict
import json

class HealthQueryState(TypedDict):
    """State dictionary passed between the LangGraph agent nodes.

    Each node receives the whole state and returns it with its own key(s)
    filled in; the trailing comments name the writer of each key.
    """

    query: str  # natural-language question; set by analyze_health_query
    data_type: str  # primary modality; written by route_query
    raw_data: Dict  # result of load_data_files(); set by analyze_health_query
    # Written by validate_data. May hold a non-dict object, e.g. the AnnData
    # returned by validate_gene_expression_data.
    validated_data: Dict
    model_outputs: Dict  # written by run_foundation_models
    kg_results: Dict  # written by query_knowledge_graphs
    risk_assessment: Dict  # parsed LLM JSON; written by assess_risks
    # generate_recommendations stores the LLM's free-text reply (one string),
    # and generate_report interpolates it as text, so this is a str.
    recommendations: str
    final_report: str  # markdown report; written by generate_report

# Shared orchestrator LLM, used by route_query, assess_risks and
# generate_recommendations. temperature=0 reduces sampling variability,
# which matters for the nodes that parse the reply as JSON.
orchestrator_llm = ChatOpenAI(temperature=0, model="gpt-4-turbo-preview")

# Define agent nodes
def route_query(state: HealthQueryState) -> HealthQueryState:
    """Ask the orchestrator LLM which data modality the query involves.

    The LLM is prompted for ``{"data_types": [...], "priority": ...}``. The
    first listed type is stored in ``state["data_type"]`` and selects the
    validator and model used downstream. ``priority`` is currently unused.

    Raises:
        ValueError: if the reply contains no parseable JSON (``json``
            raises ``JSONDecodeError``, a ``ValueError`` subclass), or if it
            is not a JSON object naming at least one data type.
    """
    import re

    query = state["query"]
    
    # Simple routing logic (can be made more sophisticated)
    prompt = f"""
    Given this health query: "{query}"
    
    Determine which data types are involved:
    - gene_expression
    - microbiome
    - clinical_records
    - medical_imaging
    - radiation
    - wearables
    
    Return as JSON: {{"data_types": [...], "priority": "high/medium/low"}}
    """
    
    response = orchestrator_llm.invoke(prompt)
    content = response.content
    # Chat models often wrap JSON in ```json fences or add a lead-in sentence,
    # which json.loads rejects. Parse only the outermost {...} span if present.
    match = re.search(r"\{.*\}", content, re.DOTALL)
    routing = json.loads(match.group(0) if match else content)

    data_types = routing.get("data_types") if isinstance(routing, dict) else None
    if isinstance(data_types, str):
        data_types = [data_types]  # tolerate a bare string instead of a list
    if not data_types:
        # Include the raw reply so routing failures are debuggable.
        raise ValueError(f"Router reply did not name any data types: {content!r}")

    state["data_type"] = data_types[0]  # Primary type
    return state

def validate_data(state: HealthQueryState) -> HealthQueryState:
    """Validate and preprocess ``state["raw_data"]`` for the routed modality.

    Dispatches on ``state["data_type"]`` to the matching ``validate_*``
    helper and stores its result in ``state["validated_data"]``.

    Raises:
        ValueError: if no validator exists for ``state["data_type"]``.
    """
    data_type = state["data_type"]
    raw_data = state["raw_data"]

    if data_type == "gene_expression":
        # Example validation for gene expression
        validated = validate_gene_expression_data(raw_data)
    elif data_type == "microbiome":
        validated = validate_microbiome_data(raw_data)
    elif data_type == "radiation":
        validated = validate_radiation_data(raw_data)
    else:
        # Fail here with a clear message rather than passing an empty
        # validated_data on to the model step.
        raise ValueError(
            f"Unsupported data type {data_type!r}; expected one of "
            "'gene_expression', 'microbiome', 'radiation'"
        )

    state["validated_data"] = validated
    return state

def run_foundation_models(state: HealthQueryState) -> HealthQueryState:
    """Run the analysis backend for the routed data type.

    Dispatches on ``state["data_type"]`` and stores the backend's result in
    ``state["model_outputs"]``:

    - ``gene_expression``: ``run_gene_expression_models``
    - ``microbiome``: ``run_microbiome_analysis``
    - ``radiation``: ``calculate_radiation_risks`` (physics-based; no
      foundation model is used for this modality)

    Raises:
        ValueError: for any other data type.
    """
    data_type = state["data_type"]
    validated_data = state["validated_data"]

    if data_type == "gene_expression":
        # Run scGPT or similar
        results = run_gene_expression_models(validated_data)
    elif data_type == "microbiome":
        results = run_microbiome_analysis(validated_data)
    elif data_type == "radiation":
        # No foundation model, use physics-based calculation
        results = calculate_radiation_risks(validated_data)
    else:
        # Without this branch `results` would be unbound and the node would
        # die with an opaque UnboundLocalError.
        raise ValueError(f"No model pipeline for data type {data_type!r}")

    state["model_outputs"] = results
    return state

def query_knowledge_graphs(state: HealthQueryState) -> HealthQueryState:
    """Look up the model findings in external knowledge sources.

    Entities are pulled out of ``state["model_outputs"]`` by
    ``extract_entities`` and sent to three lookups, whose results are stored
    in ``state["kg_results"]``:

    - ``disease_associations``: ``query_primekg``
    - ``treatment_options``: ``query_drugbank``
    - ``literature_evidence``: ``query_pubmed_kg``
    """
    # NOTE(review): query_pubmed_kg is not defined in this snippet; confirm it
    # exists elsewhere before running the pipeline.
    found = extract_entities(state["model_outputs"])

    state["kg_results"] = dict(
        disease_associations=query_primekg(found),
        treatment_options=query_drugbank(found),
        literature_evidence=query_pubmed_kg(found),
    )
    return state

def assess_risks(state: HealthQueryState) -> HealthQueryState:
    """Ask the orchestrator LLM to combine model and knowledge-graph findings.

    Serializes ``model_outputs`` and ``kg_results`` into a prompt, requests a
    JSON object with ``risks``, ``confidence`` and ``reasoning`` fields, and
    stores the parsed object in ``state["risk_assessment"]``.

    Raises:
        ValueError: if the reply contains no parseable JSON (``json`` raises
            ``JSONDecodeError``, a ``ValueError`` subclass), or if the parsed
            value is not a JSON object.
    """
    import re

    model_outputs = state["model_outputs"]
    kg_results = state["kg_results"]
    
    # Use LLM for reasoning
    prompt = f"""
    Based on the following analysis:
    
    Model Findings: {json.dumps(model_outputs, indent=2)}
    Knowledge Graph Results: {json.dumps(kg_results, indent=2)}
    
    Provide:
    1. Health risk assessment (with confidence levels)
    2. Evidence-based reasoning
    3. Uncertainty quantification
    
    Format as JSON with 'risks', 'confidence', 'reasoning' fields.
    """
    
    response = orchestrator_llm.invoke(prompt)
    content = response.content
    # Chat models often wrap JSON in ```json fences or add a lead-in sentence,
    # which json.loads rejects. Parse only the outermost {...} span if present.
    match = re.search(r"\{.*\}", content, re.DOTALL)
    risk_assessment = json.loads(match.group(0) if match else content)
    if not isinstance(risk_assessment, dict):
        # Downstream nodes call .get() on this value.
        raise ValueError(f"Risk assessment reply is not a JSON object: {content!r}")

    state["risk_assessment"] = risk_assessment
    return state

def generate_recommendations(state: HealthQueryState) -> HealthQueryState:
    """Draft tiered mitigation advice from the risk assessment.

    The risk assessment and any ``treatment_options`` found by the
    knowledge-graph step are serialized into a prompt. The orchestrator
    LLM's raw text reply is stored in ``state["recommendations"]``.
    """
    assessment = state["risk_assessment"]
    graph_hits = state["kg_results"]

    risks_json = json.dumps(assessment, indent=2)
    treatments_json = json.dumps(graph_hits.get('treatment_options', {}), indent=2)

    prompt = f"""
    Given these health risks:
    {risks_json}
    
    And available treatments:
    {treatments_json}
    
    Generate evidence-based mitigation strategies:
    1. Immediate actions (0-1 week)
    2. Short-term actions (1-3 months)
    3. Long-term monitoring/prevention
    
    Include:
    - Specific, actionable recommendations
    - Evidence citations
    - Confidence levels
    - When to seek medical consultation
    
    CRITICAL: Always include disclaimer that this is not medical advice.
    """

    state["recommendations"] = orchestrator_llm.invoke(prompt).content
    return state

def generate_report(state: HealthQueryState) -> HealthQueryState:
    """Assemble the final markdown report into ``state["final_report"]``.

    The report echoes the query, the JSON risk assessment and the LLM
    recommendations. It adds a methodology summary and a fixed disclaimer,
    and is stamped with the current UTC time in ISO-8601 format.
    """
    # Imported locally: this snippet's top-level imports do not include datetime.
    from datetime import datetime, timezone

    query = state["query"]
    risk_assessment = state["risk_assessment"]
    recommendations = state["recommendations"]
    # assess_risks normally stores a dict. Guard anyway so that an unexpected
    # value degrades to "N/A" instead of raising AttributeError mid-report.
    confidence = (
        risk_assessment.get('confidence', 'N/A')
        if isinstance(risk_assessment, dict)
        else 'N/A'
    )

    # The knowledge-graph list matches what query_knowledge_graphs queries.
    report = f"""
# HEALTH ASSESSMENT REPORT

## Query
{query}

## Risk Assessment
{json.dumps(risk_assessment, indent=2)}

## Recommendations
{recommendations}

## Methodology
- Foundation Models Used: {state.get('models_used', 'N/A')}
- Knowledge Graphs Queried: PrimeKG, DrugBank, PubMed
- Confidence Level: {confidence}

## IMPORTANT DISCLAIMER
This analysis is generated by AI models and knowledge graphs for informational 
purposes only. It is NOT medical advice. Always consult qualified healthcare 
professionals for medical decisions.

Generated: {datetime.now(timezone.utc).isoformat()}
"""

    state["final_report"] = report
    return state

# Build the workflow graph: a strictly linear pipeline with one node per stage.
def _build_workflow() -> StateGraph:
    """Return an uncompiled StateGraph that chains the agent nodes in order."""
    stages = [
        ("route", route_query),
        ("validate", validate_data),
        ("models", run_foundation_models),
        ("knowledge_graphs", query_knowledge_graphs),
        ("assess_risks", assess_risks),
        ("recommend", generate_recommendations),
        ("report", generate_report),
    ]

    graph = StateGraph(HealthQueryState)
    for stage_name, node_fn in stages:
        graph.add_node(stage_name, node_fn)

    order = [stage_name for stage_name, _ in stages]
    graph.set_entry_point(order[0])
    # Connect each stage to the next, then close the chain at END.
    for upstream, downstream in zip(order, order[1:]):
        graph.add_edge(upstream, downstream)
    graph.add_edge(order[-1], END)
    return graph


# Compiled once at import time; analyze_health_query invokes `app`.
workflow = _build_workflow()
app = workflow.compile()

# Usage example
def analyze_health_query(query: str, data_files: Dict):
    """Run the full agent pipeline for one health question.

    Args:
        query: Natural-language health question.
        data_files: Mapping of data type to its input, passed through
            ``load_data_files``. The examples below use file paths or inline
            dicts.

    Returns:
        The markdown report written by the final ``report`` node.
    """
    # Only the query and the loaded data are known up front; every other
    # state key starts empty and is filled in by the graph nodes.
    initial_state = dict(
        query=query,
        raw_data=load_data_files(data_files),
        data_type="",
        validated_data={},
        model_outputs={},
        kg_results={},
        risk_assessment={},
        recommendations=[],
        final_report="",
    )

    return app.invoke(initial_state)["final_report"]

# Example usage: run one query per supported modality and print each report.
if __name__ == "__main__":
    examples = [
        # Example 1: Gene expression analysis
        (
            "Given gene expression data from blood sample, assess health risks",
            {"gene_expression": "patient_blood_expression.csv"},
        ),
        # Example 2: Microbiome analysis
        (
            "Given shotgun metagenomic data from stool, assess gut health",
            {"microbiome": "patient_stool_metagenome.fastq"},
        ),
        # Example 3: Radiation exposure (inline measurement, no file)
        (
            "20mGy radiation detected, what are health risks?",
            {"radiation": {"dose_mGy": 20, "type": "gamma", "duration_hours": 2}},
        ),
    ]

    for example_query, example_files in examples:
        report = analyze_health_query(query=example_query, data_files=example_files)
        print(report)

Helper Functions

# Gene expression validation
def validate_gene_expression_data(raw_data):
    """Build a QC-filtered, normalized AnnData object from *raw_data*.

    A ``str`` is treated as a CSV path and read with ``scanpy.read_csv``.
    Anything else is handed straight to ``anndata.AnnData``. The object is
    filtered and normalized in place and then returned.
    """
    import scanpy as sc
    import anndata as ad

    # NOTE(review): filter_cells/filter_genes treat rows as cells and columns
    # as genes; confirm the CSV orientation matches.
    expr = sc.read_csv(raw_data) if isinstance(raw_data, str) else ad.AnnData(raw_data)

    # QC: drop cells with fewer than 200 detected genes, then genes seen in
    # fewer than 3 cells.
    sc.pp.filter_cells(expr, min_genes=200)
    sc.pp.filter_genes(expr, min_cells=3)

    # Scale each cell to 10,000 total counts, then apply log(1 + x).
    sc.pp.normalize_total(expr, target_sum=1e4)
    sc.pp.log1p(expr)

    return expr

# Run scGPT (requires model setup)
def run_gene_expression_models(validated_data):
    """Return placeholder gene-expression model results.

    *validated_data* is ignored for now. A real implementation would load a
    pretrained scGPT model, tokenize the expression data, run inference and
    post-process the outputs. Until then, a fixed illustrative result lets the
    rest of the pipeline run. A fresh dict is built on every call.
    """
    pathways = [
        {"name": "Inflammatory response", "score": 0.85},
        {"name": "DNA repair", "score": -0.65},
    ]
    signatures = [{"disease": "Chronic inflammation", "similarity": 0.78}]

    return {
        "cell_types": ["T cells", "B cells", "Monocytes"],
        "dysregulated_pathways": pathways,
        "disease_signatures": signatures,
        "confidence": 0.82,
    }

# Microbiome analysis
def run_microbiome_analysis(validated_data):
    """Return placeholder microbiome composition and function results.

    *validated_data* is ignored for now. The taxonomy values look like
    relative abundances (fractions); confirm when a real classifier is wired
    in. A fresh dict is built on every call.
    """
    taxonomy = {
        "phylum": {"Firmicutes": 0.45, "Bacteroidetes": 0.35, "Proteobacteria": 0.15},
        "genus": {"Bacteroides": 0.25, "Faecalibacterium": 0.18},
    }
    diversity = {"shannon": 3.2, "simpson": 0.85}
    functions = {"SCFA_production": "reduced", "inflammatory_markers": "elevated"}

    return {
        "taxonomy": taxonomy,
        "diversity": diversity,
        "dysbiosis_score": 0.35,  # 0-1 scale
        "functional_pathways": functions,
        "confidence": 0.75,
    }

# Radiation risk calculation
def calculate_radiation_risks(data):
    """Physics-based radiation risk assessment from an absorbed dose.

    Args:
        data: Mapping with ``dose_mGy`` (absorbed dose in mGy, default 0) and
            ``type`` (radiation type, default ``"gamma"``). Other keys such as
            ``duration_hours`` are accepted but ignored, because dose-rate
            effects are not modeled.

    Returns:
        Dict with the absorbed dose, the radiation-weighted dose in mSv, the
        estimated lifetime cancer-risk increase as a formatted percentage, a
        coarse risk category (from ``get_risk_category``), whether acute
        effects are expected, and a fixed confidence value.

    Raises:
        ValueError: if ``dose_mGy`` is negative.
    """
    dose_mGy = data.get("dose_mGy", 0)
    radiation_type = data.get("type", "gamma")

    if dose_mGy < 0:
        raise ValueError(f"dose_mGy must be non-negative, got {dose_mGy!r}")

    # Radiation weighting factors w_R (ICRP Publication 103). They convert
    # absorbed dose (Gy) to equivalent dose (Sv); they are NOT tissue
    # weighting factors. Unknown types fall back to 1.0 (photon-like).
    # NOTE(review): neutron w_R is energy-dependent (roughly 2.5-20), so
    # neutrons falling back to 1.0 underestimate risk.
    radiation_weighting = {"gamma": 1.0, "beta": 1.0, "proton": 2.0, "alpha": 20.0}
    # Simplified: treats the weighted whole-body dose as the effective dose,
    # i.e. assumes uniform whole-body exposure.
    effective_dose_mSv = dose_mGy * radiation_weighting.get(radiation_type, 1.0)

    # Linear no-threshold estimate using the ICRP nominal cancer risk
    # coefficient (~5.5e-2 per Sv). Real implementation would use proper
    # ICRP/BEIR models.
    lifetime_cancer_risk = effective_dose_mSv * 5.5e-5  # per mSv

    results = {
        "absorbed_dose_mGy": dose_mGy,
        "effective_dose_mSv": effective_dose_mSv,
        "lifetime_cancer_risk_increase": f"{lifetime_cancer_risk:.4%}",
        "risk_category": get_risk_category(effective_dose_mSv),
        # Deterministic (acute) effects are judged on absorbed dose, not on
        # effective dose; ~1 Gy whole-body is the usual threshold for acute
        # radiation syndrome.
        "acute_effects_expected": dose_mGy > 1000,
        "confidence": 0.90  # High confidence in physics-based models
    }

    return results

def get_risk_category(dose_mSv):
    """Map an effective dose in mSv to a coarse risk label.

    Bands use exclusive upper bounds: <1 Negligible, <50 Low, <250 Moderate,
    <1000 High. Any value that passes none of these tests, including NaN,
    is labelled "Very High".
    """
    bands = (
        (1, "Negligible"),
        (50, "Low"),
        (250, "Moderate"),
        (1000, "High"),
    )
    for upper_bound, label in bands:
        if dose_mSv < upper_bound:
            return label
    return "Very High - Immediate medical attention required"

# Knowledge graph queries
def query_primekg(entities):
    """Return placeholder PrimeKG disease and gene associations.

    *entities* is ignored for now. A real implementation needs a Neo4j
    connection to PrimeKG and would run Cypher queries against it.
    """
    diseases = [
        {"name": "Chronic inflammation", "confidence": 0.78, "evidence_count": 150},
        {"name": "Autoimmune disorder risk", "confidence": 0.65, "evidence_count": 89},
    ]
    gene_links = [{"gene": "IL6", "disease": "Inflammation", "confidence": 0.92}]
    return {"diseases": diseases, "gene_associations": gene_links}

def query_drugbank(entities):
    """Return placeholder DrugBank treatment options.

    *entities* is ignored until a DrugBank API client is wired in.
    """
    treatment = dict(
        drug="Anti-inflammatory agents",
        mechanism="COX-2 inhibition",
        evidence_level="A",
        side_effects=["GI upset", "Cardiovascular risk"],
    )
    return {"treatments": [treatment]}

def extract_entities(model_outputs):
    """Collect entity names from model outputs for knowledge-graph lookups.

    Pathway names come from ``dysregulated_pathways[*]["name"]`` and
    phenotype names from ``disease_signatures[*]["disease"]``. Missing
    sections yield empty lists. ``genes`` is always returned empty because
    nothing here populates it yet.
    """
    if "dysregulated_pathways" in model_outputs:
        pathway_items = model_outputs["dysregulated_pathways"]
    else:
        pathway_items = []

    if "disease_signatures" in model_outputs:
        signature_items = model_outputs["disease_signatures"]
    else:
        signature_items = []

    return {
        "genes": [],
        "pathways": [item["name"] for item in pathway_items],
        "phenotypes": [sig["disease"] for sig in signature_items],
    }

Implementation Phases

1. Scale Gradually

Week 1: Minimum Viable Agent

# Just 3 components:
# 1. Orchestrator LLM (GPT-4)
# 2. One foundation model (e.g., scGPT for gene expression)
# 3. One knowledge graph (UMLS via API)

# Test with single data type first

Week 2-3: Add More Models

# Add microbiome analysis
# Add Clinical-T5 for clinical notes
# Connect to PrimeKG

Week 4: Polish and Expand

# Add remaining modalities
# Implement fusion layer
# Build evaluation framework

2. Handle the "No Foundation Model" Problem

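Plan for questions for which no suitable foundation model exists (e.g., radiation exposure or habitat telemetry questions). For these, combine a traditional physics-based or statistical model with knowledge-graph context and LLM interpretation, as sketched below.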

class HybridAgent:
    """Combines traditional methods with LLM reasoning.

    Intended for questions with no suitable foundation model: a traditional
    (physics-based or statistical) model does the core calculation, a
    knowledge graph adds context, and an LLM turns both into readable
    insights.

    ``PhysicsBasedModel``, ``LanguageModel`` and ``KnowledgeGraph`` are not
    defined in this snippet; they stand in for project-specific classes.
    """

    def __init__(self, traditional_model=None, llm=None, kg=None):
        """Create the agent, building a default for any component not supplied.

        Args:
            traditional_model: Object with ``compute(data)``. Defaults to
                ``PhysicsBasedModel()``.
            llm: Object with ``interpret(result, context, prompt=...)``.
                Defaults to ``LanguageModel()``.
            kg: Object with ``get_relevant_knowledge(result)``. Defaults to
                ``KnowledgeGraph()``.
        """
        # Injectable components let callers swap in e.g. a statistical model
        # (or test doubles); the defaults reproduce the original wiring.
        self.traditional_model = (
            traditional_model if traditional_model is not None else PhysicsBasedModel()
        )
        self.llm = llm if llm is not None else LanguageModel()  # For interpretation
        self.kg = kg if kg is not None else KnowledgeGraph()

    def analyze(self, data):
        """Run calculation, knowledge lookup and LLM interpretation on *data*.

        Returns:
            Dict with ``quantitative`` (traditional model output),
            ``contextual`` (knowledge-graph context) and ``interpretation``
            (LLM output).
        """
        # 1. Use traditional method for core calculation
        calculation_result = self.traditional_model.compute(data)

        # 2. Query KG for context
        context = self.kg.get_relevant_knowledge(calculation_result)

        # 3. Use LLM to synthesize human-readable insights
        interpretation = self.llm.interpret(
            calculation_result,
            context,
            prompt="Explain risks and recommendations"
        )

        return {
            "quantitative": calculation_result,
            "contextual": context,
            "interpretation": interpretation
        }

3. Manage Uncertainty

class UncertaintyTracker:
    """Track and propagate uncertainty through the pipeline.

    Each log entry is a dict with ``source``, ``confidence`` (in [0, 1]),
    ``reason``, a UTC ``timestamp`` and a ``weight`` used for aggregation.
    """

    def __init__(self):
        self.uncertainty_log = []

    def add_uncertainty(self, source, confidence, reason, weight=1.0):
        """Append one confidence observation to the log.

        Args:
            source: Label of the pipeline stage or model reporting it.
            confidence: Score in [0, 1].
            reason: Free-text explanation of the uncertainty.
            weight: Positive aggregation weight. The default 1.0 gives equal
                weighting.

        Raises:
            ValueError: if confidence is outside [0, 1] (or NaN), or if
                weight is not positive.
        """
        # Imported locally: this snippet has no top-level datetime import.
        from datetime import datetime, timezone

        # Out-of-range values make the geometric mean meaningless; a negative
        # base with a fractional exponent even produces a complex number.
        if not 0.0 <= confidence <= 1.0:
            raise ValueError(f"confidence must be in [0, 1], got {confidence!r}")
        if not weight > 0:
            raise ValueError(f"weight must be positive, got {weight!r}")

        self.uncertainty_log.append({
            "source": source,
            "confidence": confidence,
            "reason": reason,
            "timestamp": datetime.now(timezone.utc),
            "weight": weight,
        })

    def get_overall_confidence(self):
        """Aggregate confidence across all sources.

        Uses the weighted geometric mean, which is more conservative than the
        arithmetic mean: one low-confidence source pulls the result down
        sharply, and any zero makes it zero. Returns 0.0 if nothing has been
        logged.
        """
        import math

        if not self.uncertainty_log:
            return 0.0

        # Entries appended directly to the log without a weight count as 1.0.
        weights = [entry.get("weight", 1.0) for entry in self.uncertainty_log]
        confidences = [entry["confidence"] for entry in self.uncertainty_log]

        weighted_product = math.prod(c ** w for c, w in zip(confidences, weights))
        return weighted_product ** (1 / sum(weights))

    def get_confidence_breakdown(self):
        """Return the log as a pandas DataFrame, one row per entry."""
        import pandas as pd

        return pd.DataFrame(self.uncertainty_log)

4. Make It Production-Ready

# Add comprehensive error handling
class RobustOrchestrator:
    """Run async agent callables with retry-on-timeout and per-agent fallbacks.

    ``ModelTimeoutError`` and ``ModelError`` are not defined in this snippet;
    they are assumed to come from the surrounding project. TODO: confirm.
    """

    def __init__(self, max_retries=3, fallback_responses=None):
        """Configure retry behaviour.

        Args:
            max_retries: Total number of attempts per call (not extra retries).
            fallback_responses: Mapping from agent name to the value returned
                when every attempt times out. Defaults to an empty mapping.
        """
        self.max_retries = max_retries
        self.fallback_responses = (
            {} if fallback_responses is None else dict(fallback_responses)
        )

    async def execute_with_fallback(self, agent_func, *args, **kwargs):
        """Execute agent with retry and fallback.

        Timeouts are retried with exponential backoff (1 s, 2 s, 4 s, ...).
        After the last attempt, the fallback registered under the agent's
        name (or None) is returned. Any other ``ModelError`` is logged and
        yields None. Other exceptions propagate.
        """
        import asyncio
        import logging

        # functools.partial and some callable objects have no __name__.
        name = getattr(agent_func, "__name__", repr(agent_func))

        for attempt in range(self.max_retries):
            try:
                return await agent_func(*args, **kwargs)
            except ModelTimeoutError:
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(2 ** attempt)  # Exponential backoff
                    continue
                return self.fallback_responses.get(name)
            except ModelError:
                logging.getLogger(__name__).exception("Model error in %s", name)
                return None

        # Only reached when max_retries <= 0, i.e. no attempt was made.
        return None

Example Queries and Expected Workflow Approach

Query 1: Gene Expression → Health Risks

Input: Blood gene expression CSV file

Workflow:

  1. validate_data: Load as AnnData, run QC, normalize
  2. run_foundation_models:
    • scGPT for cell type annotation
    • Geneformer for pathway analysis
  3. query_knowledge_graphs:
    • PrimeKG: gene-disease associations
    • UMLS: symptom mappings
  4. assess_risks: LLM integrates findings → inflammatory markers elevated, DNA repair pathways down
  5. generate_recommendations: Anti-inflammatory diet, stress reduction, follow-up testing

Output: Report with risks, confidence, evidence, and recommendations

Query 2: Radiation Exposure → Health Risks

Input: "20 mGy gamma radiation detected"

Workflow:

  1. validate_data: Parse dose, validate units
  2. run_foundation_models: NO FOUNDATION MODEL → use physics equations
    • Calculate effective dose (20 mSv)
    • BEIR model for cancer risk (+0.1%)
  3. query_knowledge_graphs:
    • Medical KG: radiation countermeasures
    • Literature: latest ICRP guidelines
  4. assess_risks: LLM interprets → "Low but non-negligible risk"
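  5. generate_recommendations: Document the exposure (no acute effects are expected at 20 mGy), consider potassium iodide only if radioactive iodine was inhaled or ingested (it offers no protection against external gamma exposure), follow up in 3 months

Output: Physics-based calculation + evidence-based recommendations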

Query 3: Microbiome → Health Status

Input: Shotgun metagenomic FASTQ file

Workflow:

  1. validate_data: Quality filter, taxonomic classification
  2. run_foundation_models:
    • gNOMO pipeline for multi-omics
    • MintTea for dysbiosis detection
  3. query_knowledge_graphs:
    • Microbiome KG: species-disease links
    • PrimeKG: metabolite-health associations
  4. assess_risks: LLM integrates → reduced microbial diversity, elevated inflammatory species
  5. generate_recommendations: Probiotic strains, prebiotic fiber, dietary changes

Output: Microbiome composition + health implications + actionable diet/supplement recommendations

Expected Performance

  • Accuracy: 70-85% agreement with expert assessments (where ground truth exists)
  • Speed: 30-120 seconds per query (depending on complexity)
  • Cost: $0.05-0.20 per query
  • Calibration: stated confidence levels should match observed accuracy (i.e., be well calibrated) in 90%+ of evaluated cases

Final Critical Recommendations

  1. Start with LangGraph - it's the most mature framework for this use case
  2. Don't try to build everything at once - start with one data type
  3. Focus on prompt engineering - 80% of your success will be good prompts for the orchestrator
  4. Build evaluation from day 1 - you need to know if it's working
  5. Add medical disclaimers everywhere - this is critical for liability
  6. Plan for the "no foundation model" problem - use hybrid approaches
  7. Make uncertainty visible - users need to know confidence levels