This repository contains the implementation of the approach discussed in *AUTOCT: Automating Interpretable Clinical Trial Prediction with LLM Agents*.
- Clone the repository and navigate to the project directory.
- Create the conda environment: `conda env create -f environment.yml`
- Activate the environment: `conda activate autoct`
- Set up the required environment variables in a `.env` file or your shell.
  - The following environment variables are required for generating the retrieval dataset:
    - PubMed
      - `ENTREZ_API_KEY`
      - `ENTREZ_EMAIL`
      - A free account is available at https://account.ncbi.nlm.nih.gov/ to increase the API request limit.
  - The following environment variables are required for running the agent:
    - OpenAI
      - `OPENAI_API_KEY`
- The retrieval uses `pgvector` as the default backend. To run this locally, navigate to `db/` and run `docker-compose up`.
- Run the data generation notebooks in order to prepare the datasets.
- Train the MCTS model using `scripts/train_mcts.py`.
- Use `scripts/predict.py` to make predictions on new clinical trials.
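Before running the steps above, it can help to sanity-check that the required environment variables are set. The variable names come from the list above; the helper itself is a hypothetical sketch, not part of the repository:

```python
import os

# Variables required for retrieval dataset generation and for running the agent.
REQUIRED_VARS = ["ENTREZ_API_KEY", "ENTREZ_EMAIL", "OPENAI_API_KEY"]

def missing_env_vars(env=os.environ):
    """Return the names of required variables that are unset or empty."""
    return [name for name in REQUIRED_VARS if not env.get(name)]

# Usage: if missing_env_vars() returns a non-empty list, fix your .env file
# before launching the notebooks or scripts.
```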
Jupyter notebooks for generating the datasets:

- `0_generate_ctg.ipynb`: Generates clinical trial data from ClinicalTrials.gov.
- `1_generate_trialbench.ipynb`: Creates trial benchmark datasets.
- `2_clinical_trial_retrieval.ipynb`: Creates the RAG index for retrieval of clinical trial information.
- `3_generate_pubmed.ipynb`: Downloads and stores PubMed data.
  - A fixed list of PubMed document IDs that replicates what was used for the experiments is in `pmids.parquet`.
- `4_clean_pubmed.ipynb`: Cleans and prepares PubMed datasets.
- `5_pubmed_retrieval.ipynb`: Creates the RAG index for retrieval of PubMed data.
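To illustrate the kind of text normalization a cleaning step like `4_clean_pubmed.ipynb` typically applies before indexing (the function below is a hypothetical sketch, not the notebook's actual code):

```python
import re

def clean_abstract(text):
    """Normalize a raw PubMed abstract: collapse runs of whitespace
    (newlines, tabs, multiple spaces) into single spaces and strip the
    ends, so embeddings are computed on clean, compact text."""
    if text is None:
        return ""
    return re.sub(r"\s+", " ", text).strip()
```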
Python scripts for running the system:

- `run_agent.py`: Executes a single iteration of the end-to-end LLM agent.
- `train_mcts.py`: Runs the Monte Carlo Tree Search (MCTS) algorithm for feature exploration.
- `predict.py`: Extracts the best-performing model from a run of MCTS and runs prediction for a given task and trial ID.
Core implementation modules:

- `agent.py`: Main agent logic using the DSPy framework for LLM interactions and feature engineering.
- `treesearch.py`: Implements Monte Carlo Tree Search for exploring feature engineering plans.
- `pubmed.py`: Handles PubMed data retrieval and processing using BioPython's Entrez API.
- `tools.py`: Search tools for clinical trials and PubMed using embeddings.
- `utils.py`: General utility functions.
- `globals.py`: Global configurations and shared resources.
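For intuition on how `treesearch.py` balances exploring new feature engineering plans against exploiting ones that already score well, here is the UCB1 selection rule commonly used in MCTS. This is a generic sketch (the dict-based node shape is an assumption), not the repository's actual implementation:

```python
import math

def ucb1(total_reward, visits, parent_visits, c=1.41):
    """UCB1 score: mean reward (exploitation) plus a bonus that grows
    for rarely visited nodes (exploration)."""
    if visits == 0:
        return float("inf")  # unvisited children are always tried first
    return total_reward / visits + c * math.sqrt(math.log(parent_visits) / visits)

def select_child(children, parent_visits, c=1.41):
    """Pick the child plan with the highest UCB1 score."""
    return max(children, key=lambda ch: ucb1(ch["reward"], ch["visits"], parent_visits, c))
```

In a tree search over feature plans, each iteration selects a path with `select_child`, evaluates the resulting features, and backpropagates the reward into `reward`/`visits` along the path.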
These are subsets of the TrialBench dataset, as used in the paper, for different clinical trial prediction tasks:

- `adverse_event/`: Predict adverse events in clinical trials.
- `failure_reason/`: Predict reasons for trial failure.
- `mortality/`: Predict mortality outcomes.
- `patient_dropout/`: Predict patient dropout rates.
- `trial_approval/`: Predict trial approval likelihood.
Each task directory contains train/validation/test splits in Parquet format.
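Assuming one Parquet file per split under each task directory (the `data/` root and the exact file names are assumptions; adjust to the actual repository layout), the splits can be located like this:

```python
from pathlib import Path

TASKS = ["adverse_event", "failure_reason", "mortality", "patient_dropout", "trial_approval"]
SPLITS = ("train", "validation", "test")

def split_paths(task, root="data"):
    """Map split name -> Parquet path for one task,
    e.g. data/mortality/test.parquet (layout assumed, not verified)."""
    base = Path(root) / task
    return {split: base / f"{split}.parquet" for split in SPLITS}

# Each file can then be loaded with e.g. pandas.read_parquet(path).
```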