Skip to content

Commit d31fc79

Browse files
authored
Merge pull request #53 from grace-sng7/creating_augmented_suggester
2 parents 5928968 + a474e5d commit d31fc79

8 files changed

+6161
-1884
lines changed

README.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,22 @@ domain_expertises = modeler.suggest_domain_expertises(all_factors)
4949
# Suggest a set of potential confounders
5050
suggested_confounders = modeler.suggest_confounders(treatment, outcome, all_factors, domain_expertises)
5151

52-
# Suggest pair-wise relationship between variables
52+
# Suggest pair-wise relationships between variables
5353
suggested_dag = modeler.suggest_relationships(treatment, outcome, all_factors, domain_expertises, RelationshipStrategy.Pairwise)
5454
```
5555

56+
### Retrieval Augmented Generation (RAG)-based Modeler
5657

58+
```python
59+
# Create instance of Modeler
60+
modeler = AugmentedModelSuggester('gpt-4')
61+
62+
treatment = "smoking"
63+
outcome = "lung cancer"
64+
65+
# Suggest pair-wise relationship between two given variables, utilizing CauseNet for RAG
66+
suggested_relationship = modeler.suggest_pairwise_relationship(treatment, outcome)
67+
```
5768

5869
### Identifier
5970

poetry.lock

Lines changed: 5846 additions & 1879 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ networkx = "<=3.2.1"
5555
guidance = ">=0.2"
5656
openai = ">=1.70"
5757
pydantic = ">=2.11"
58+
langchain = ">=0.3.25"
59+
langchain-chroma = ">=0.2.4"
60+
langchain-community = ">=0.3.24"
61+
langchain-core = ">=0.3.60"
62+
langchain-huggingface = ">=0.2.0"
63+
langchain-openai = ">=0.3.17"
64+
rank-bm25 = ">=0.2.2"
65+
sentence-transformers = ">=4.1.0"
5866

5967
[tool.poetry.group.dev.dependencies]
6068
poethepoet = "^0.33.0"
@@ -110,7 +118,7 @@ _isort_check = 'isort --check .'
110118

111119
# testing tasks
112120
test = "pytest -v -m 'not advanced' --durations=0 --durations-min=60.0"
113-
test_no_notebooks= "pytest -v -m 'not advanced and not notebook' --durations=0 --durations-min=60.0"
121+
test_no_notebooks = "pytest -v -m 'not advanced and not notebook' --durations=0 --durations-min=60.0"
114122
test_durations = "poetry run poe test --store-durations"
115123
test_advanced = "pytest -v"
116124
test_focused = "pytest -v -m 'focused'"
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import logging
2+
import re
3+
4+
from .simple_model_suggester import SimpleModelSuggester
5+
from pywhyllm.utils.data_loader import *
6+
from pywhyllm.utils.augmented_model_suggester_utils import *
7+
8+
9+
class AugmentedModelSuggester(SimpleModelSuggester):
    """Pairwise causal-relationship suggester that augments the LLM with CauseNet.

    On construction the CauseNet-Precision dump is downloaded to ``file_path``
    and indexed into ``self.causenet_dict``. Each pairwise query first looks
    for a matching CauseNet entry; when one with sentence sources is found,
    those sentences are supplied to the LLM through a retriever, otherwise the
    LLM is queried without additional context.
    """

    def __init__(self, llm, file_path: str = 'data/causenet-precision.jsonl.bz2'):
        """Download and index CauseNet.

        Args:
            llm: Passed unchanged to ``SimpleModelSuggester.__init__``.
            file_path: Local path where the CauseNet dump is stored.
        """
        super().__init__(llm)
        self.file_path = file_path
        # Always define the attribute: if the download fails the suggester
        # still works, it just answers without retrieval context.
        self.causenet_dict = {}

        logging.basicConfig(level=logging.INFO)
        url = "https://groups.uni-paderborn.de/wdqa/causenet/causality-graphs/causenet-precision.jsonl.bz2"
        success = download_causenet(url, file_path)

        if success:
            print(f"File downloaded to {file_path}")
            json_data = load_causenet_json(file_path)
            self.causenet_dict = create_causenet_dict(json_data)
        else:
            print("Download failed")

    def suggest_pairwise_relationship(self, variable1: str, variable2: str):
        """Ask the LLM which causal direction, if any, holds between two variables.

        Args:
            variable1: Name of the first variable.
            variable2: Name of the second variable.

        Returns:
            A three-element list ``[cause, effect, response]`` where
            ``response`` is the raw LLM answer text. ``cause`` and ``effect``
            are both ``None`` when the LLM answers that neither variable
            causes the other.

        Raises:
            AssertionError: If the LLM reply does not contain exactly one of
                the answers ``A``, ``B`` or ``C`` inside ``<answer>`` tags.
        """
        # Skip the CauseNet lookup entirely when nothing was indexed.
        result = None
        if self.causenet_dict:
            result = find_top_match_in_causenet(self.causenet_dict, variable1, variable2)

        # A matched entry may carry no sentence-type sources; in that case
        # there is nothing to retrieve from, so use the plain prompt.
        source_text = get_source_text(result) if result else ""
        if source_text:
            retriever = split_data_and_create_vectorstore_retriever(source_text)
            response = query_llm(variable1, variable2, source_text, retriever)
        else:
            response = query_llm(variable1, variable2)

        # DOTALL lets the letter sit on its own line between the tags.
        answer = re.findall(r'<answer>(.*?)</answer>', response, re.DOTALL)
        answer_str = "".join(ans.strip() for ans in answer)

        if answer_str == "A":
            return [variable1, variable2, response]
        if answer_str == "B":
            return [variable2, variable1, response]
        if answer_str == "C":
            return [None, None, response]
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``
        # while callers still see the same exception type.
        raise AssertionError("Invalid answer from LLM: " + answer_str)

pywhyllm/suggesters/simple_model_suggester.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ def suggest_pairwise_relationship(self, variable1: str, variable2: str):
5555
answer = [ans.strip() for ans in answer]
5656
answer_str = "".join(answer)
5757

58-
if (answer_str == "A"):
58+
if answer_str == "A":
5959
return [variable1, variable2, description]
60-
elif (answer_str == "B"):
60+
elif answer_str == "B":
6161
return [variable2, variable1, description]
62-
elif (answer_str == "C"):
62+
elif answer_str == "C":
6363
return [None, None, description] # maybe we want to save the description in this case too
6464
else:
6565
assert False, "Invalid answer from LLM: " + answer_str

pywhyllm/utils/__init__.py

Whitespace-only changes.
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import os
2+
from langchain_text_splitters import RecursiveCharacterTextSplitter
3+
from langchain_core.documents import Document
4+
from langchain_chroma import Chroma
5+
from langchain_huggingface import HuggingFaceEmbeddings
6+
from langchain_openai import ChatOpenAI
7+
from langchain_core.prompts import ChatPromptTemplate
8+
from langchain.chains import create_retrieval_chain
9+
from langchain.chains.combine_documents import create_stuff_documents_chain
10+
import numpy as np
11+
from rank_bm25 import BM25Okapi
12+
from sentence_transformers import SentenceTransformer, util
13+
14+
15+
# Loaded SBERT models keyed by model name, so repeated lookups do not reload
# the weights on every call.
_SBERT_MODEL_CACHE = {}


def _get_sbert_model(model_name):
    """Return the SentenceTransformer for ``model_name``, loading it at most once."""
    if model_name not in _SBERT_MODEL_CACHE:
        _SBERT_MODEL_CACHE[model_name] = SentenceTransformer(model_name)
    return _SBERT_MODEL_CACHE[model_name]


def find_top_match_in_causenet(causenet_dict, variable1, variable2, threshold=0.7):
    """Find the CauseNet entry that best matches a pair of variables.

    Candidates are pre-selected with BM25 over the cause/effect words and then
    re-ranked with SBERT cosine similarity against both ``variable1-variable2``
    and ``variable2-variable1``.

    Args:
        causenet_dict: Mapping whose values each hold
            ``['causal_relation']['cause']`` and ``['causal_relation']['effect']``
            strings.
        variable1: First variable name.
        variable2: Second variable name.
        threshold: Minimum cosine similarity for a candidate to be accepted.

    Returns:
        The matching entry from ``causenet_dict``, or ``None`` when the
        dictionary is empty or no candidate reaches ``threshold``.
    """
    # BM25Okapi divides by the corpus size, so an empty corpus must be
    # handled before building the index.
    if not causenet_dict:
        return None

    # NOTE(review): the BM25 index is rebuilt over every entry on each call;
    # callers issuing many queries may want to cache it.
    entries = list(causenet_dict.values())
    pair_strings = [
        f"{entry['causal_relation']['cause']}-{entry['causal_relation']['effect']}"
        for entry in entries
    ]

    # Tokenize cause and effect separately. Splitting the joined
    # "cause-effect" string on whitespace would fuse the last word of the
    # cause with the first word of the effect into a single token.
    tokenized_pairs = [
        entry['causal_relation']['cause'].split() + entry['causal_relation']['effect'].split()
        for entry in entries
    ]
    bm25 = BM25Okapi(tokenized_pairs)

    # Original and reverse queries (used for the SBERT comparison)
    query = variable1 + "-" + variable2
    reverse_query = variable2 + "-" + variable1

    # The forward and reverse queries contain the same words, so one
    # de-duplicated token list covers both directions.
    combined_query = list(set(variable1.split() + variable2.split()))

    # Get top-k candidates using BM25 with combined query
    k = 5
    scores = bm25.get_scores(combined_query)
    top_k_indices = np.argsort(scores)[::-1][:k]
    candidate_pairs = [pair_strings[i] for i in top_k_indices]
    candidate_entries = [entries[i] for i in top_k_indices]

    # Apply SBERT to candidates
    model = _get_sbert_model('all-MiniLM-L6-v2')
    query_embedding = model.encode(query, convert_to_tensor=True)
    reverse_query_embedding = model.encode(reverse_query, convert_to_tensor=True)
    candidate_embeddings = model.encode(candidate_pairs, convert_to_tensor=True)

    # Move the similarity tensors to host memory before handing them to NumPy;
    # NumPy cannot consume tensors that live on an accelerator.
    similarities = util.cos_sim(query_embedding, candidate_embeddings).flatten().cpu().numpy()
    reverse_similarities = util.cos_sim(reverse_query_embedding, candidate_embeddings).flatten().cpu().numpy()

    # Take the maximum similarity for each candidate (original or reverse)
    max_similarities = np.maximum(similarities, reverse_similarities)

    # Get the top match and its similarity score
    top_idx = int(np.argmax(max_similarities))
    top_similarity = float(max_similarities[top_idx])
    top_pair = candidate_pairs[top_idx]

    # Check if the top similarity meets the threshold
    if top_similarity >= threshold:
        print(f"Best match: {top_pair} (Similarity: {top_similarity:.4f})")
        # Return the entry the candidate came from, rather than looking it up
        # again by a key reconstructed from its cause/effect strings.
        return candidate_entries[top_idx]

    print(f"No match found with similarity above {threshold} (Best similarity: {top_similarity:.4f})")
    return None
65+
66+
67+
def get_source_text(causenet_query_result):
    """Concatenate the sentences of all sentence-type sources of a CauseNet entry.

    Args:
        causenet_query_result: A CauseNet entry with a ``"sources"`` list, or a
            falsy value when no entry was found.

    Returns:
        The sentences of every ``wikipedia_sentence`` and ``clueweb12_sentence``
        source, each followed by a single space; ``""`` when there are none or
        the entry is falsy.
    """
    if not causenet_query_result:
        return ""

    sentence_types = ('wikipedia_sentence', 'clueweb12_sentence')
    pieces = [
        source["payload"]["sentence"] + " "
        for source in causenet_query_result["sources"]
        if source["type"] in sentence_types
    ]
    return "".join(pieces)
75+
76+
77+
def split_data_and_create_vectorstore_retriever(source_text):
    """Chunk ``source_text``, embed the chunks and return a similarity retriever.

    Each call builds its own, uniquely named, non-persisted Chroma collection.
    Writing every call into one shared persisted collection would let chunks
    from earlier, unrelated queries show up in later retrievals.

    Args:
        source_text: The text to index.

    Returns:
        A retriever over the chunks of ``source_text`` that returns the 5 most
        similar chunks per query.
    """
    import uuid  # local import: only used to name the per-call collection

    document = Document(page_content=source_text)

    # Initialize the text splitter
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=100,  # Adjust chunk size as needed
        chunk_overlap=20,  # Overlap for context
    )
    # Split the documents
    splits = text_splitter.split_documents([document])

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Create a vector store from the document splits, isolated from the
    # stores created by any other call.
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        collection_name=f"causenet-{uuid.uuid4().hex}",
    )

    # Create a retriever from the vector store
    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 5},  # Retrieve top 5 relevant chunks
    )

    return retriever
104+
105+
106+
def query_llm(variable1, variable2, source_text=None, retriever=None, model="gpt-4"):
    """Ask a chat model which causal direction holds between two variables.

    When both ``source_text`` and ``retriever`` are supplied, the question is
    answered through a retrieval chain that injects retrieved chunks as
    context. Otherwise the model is queried with the bare question.

    Args:
        variable1: First variable name.
        variable2: Second variable name.
        source_text: Supporting text; only its truthiness is used here to
            decide whether retrieval is enabled.
        retriever: Retriever over the supporting text.
        model: Name of the OpenAI chat model to use.

    Returns:
        The model's answer text, which is asked to contain the final choice
        (A, B or C) inside ``<answer>`` tags.
    """
    # Initialize the language model
    llm = ChatOpenAI(model=model)

    # Retrieval needs both the text and a retriever; with only one of them
    # fall back to the plain prompt instead of building a chain around None.
    use_rag = bool(source_text) and retriever is not None

    system_prompt = "You are a helpful assistant for causal reasoning.\n"
    if use_rag:
        # ``{context}`` is filled with the retrieved chunks by the document chain.
        system_prompt += "\nContext: {context}\n"

    # prompt template
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", "{input}"),
    ])

    query = f"""Which cause-and-effect-relationship is more likely? Provide reasoning and you must give your final answer (A, B, or C) in <answer> </answer> tags with the letter only.
A. {variable1} causes {variable2} B. {variable2} causes {variable1} C. neither {variable1} nor {variable2} cause each other."""

    if use_rag:
        # Create a document chain to combine retrieved documents
        question_answer_chain = create_stuff_documents_chain(llm, prompt)

        # Create the RAG chain
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)

        response = rag_chain.invoke({"input": query})
        return response['answer']

    default_chain = prompt | llm
    response = default_chain.invoke({"input": query})
    return response.content

pywhyllm/utils/data_loader.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import bz2
2+
import json
3+
import os
4+
import requests
5+
from tqdm import tqdm
6+
import logging
7+
8+
9+
def download_causenet(url: str, file_path: str) -> bool:
    """
    Download the CauseNet-Precision dataset from causenet.org and save it to the specified path.

    The CauseNet-Precision dataset is a subset of the CauseNet
    knowledge base containing high-precision causal relations extracted from web sources.
    The dataset is described in the following paper:

    Citation:
    Stefan Heindorf, Yan Scholten, Henning Wachsmuth, Axel-Cyrille Ngonga Ngomo, and Martin Potthast.
    2020. CauseNet: Towards a Causality Graph Extracted from the Web. In Proceedings of the 29th ACM
    International Conference on Information & Knowledge Management (CIKM '20). Association for
    Computing Machinery, New York, NY, USA, 3023–3030. https://doi.org/10.1145/3340531.3412763

    TODO: Add license

    The data is first written to ``file_path + ".part"`` and moved into place
    only after the transfer completes, so a failed download never leaves a
    truncated archive at ``file_path``.

    Args:
        url (str): The URL of the file to download.
        file_path (str): The local path where the file will be saved.

    Returns:
        bool: True if the download was successful, False otherwise.
    """
    partial_path = file_path + ".part"
    try:
        # Ensure the output directory exists. A bare filename has an empty
        # dirname, and os.makedirs("") raises, so only create real directories.
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Stream the response; the timeout keeps a stalled server from
        # blocking forever, and the context manager releases the connection.
        with requests.get(url, stream=True, timeout=30) as response:
            # Check if the request was successful
            if response.status_code != 200:
                logging.error(f"Failed to download file from {url}. Status code: {response.status_code}")
                return False

            # Get the total file size for progress bar (if available)
            total_size = int(response.headers.get("content-length", 0))

            # Download and save the file with a progress bar
            with open(partial_path, "wb") as file, tqdm(
                desc="Downloading",
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        file.write(chunk)
                        progress_bar.update(len(chunk))

        # Publish the finished file under its final name.
        os.replace(partial_path, file_path)
        logging.info(f"File downloaded successfully to {file_path}")
        return True

    except requests.exceptions.RequestException as e:
        logging.error(f"Error downloading file from {url}: {e}")
        return False
    except OSError as e:
        logging.error(f"Error saving file to {file_path}: {e}")
        return False
    finally:
        # Best-effort cleanup of an incomplete transfer; after a successful
        # os.replace the partial file no longer exists and this is a no-op.
        if os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError:
                logging.warning(f"Could not remove partial download {partial_path}")
69+
70+
71+
def load_causenet_json(file_path):
    """Read a bz2-compressed JSON-lines file into a list of parsed objects.

    Args:
        file_path: Path to the ``.jsonl.bz2`` file.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.
    """
    print("Loading CauseNet using json")
    with bz2.open(file_path, 'rt', encoding='utf-8') as handle:
        # Strip surrounding whitespace first so blank lines can be dropped.
        stripped_lines = (raw.strip() for raw in handle)
        records = [json.loads(text) for text in stripped_lines if text]
    print("Done loading CauseNet using json")
    return records
84+
85+
86+
def create_causenet_dict(json_data):
    """Index CauseNet records by their ``"<cause>-<effect>"`` concept pair.

    Records that share the same cause and effect are merged into one entry
    whose ``sources`` list holds the sources of all of them, in input order.
    The input records are left unmodified.

    Args:
        json_data: Iterable of CauseNet records, each providing
            ``['causal_relation']['cause']['concept']``,
            ``['causal_relation']['effect']['concept']`` and ``['sources']``.

    Returns:
        A dict mapping ``"<cause>-<effect>"`` to
        ``{'causal_relation': {'cause': ..., 'effect': ...}, 'sources': [...]}``.
    """
    causenet_dict = {}
    print("Creating dictionary from CauseNet json data")
    for item in json_data:
        cause = item['causal_relation']['cause']['concept']
        effect = item['causal_relation']['effect']['concept']
        # NOTE(review): concepts that themselves contain "-" can make distinct
        # pairs share a key; the key format is kept for compatibility.
        key = cause + "-" + effect

        entry = causenet_dict.get(key)
        if entry is None:
            causenet_dict[key] = {
                'causal_relation': {'cause': cause, 'effect': effect},
                # Copy the list so that merging later records into this entry
                # does not grow the caller's record in place.
                'sources': list(item['sources']),
            }
        else:
            # Append sources to existing list
            entry['sources'].extend(item['sources'])
    print("Done creating dictionary from CauseNet json data")
    return causenet_dict

0 commit comments

Comments
 (0)