Commit 2bb8403

wip: trace generated data
1 parent 03366bb commit 2bb8403

12 files changed

Lines changed: 130 additions & 144 deletions

graphgen/bases/base_generator.py

Lines changed: 40 additions & 45 deletions
@@ -21,72 +21,67 @@ def build_prompt(
 
     @staticmethod
     @abstractmethod
-    def parse_response(response: str) -> Any:
+    def parse_response(response: str) -> list[dict]:
         """Parse the LLM response and return the generated QAs"""
 
     async def generate(
         self,
         batch: tuple[
             list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
         ],
-    ) -> dict[str, Any]:
+    ) -> list[dict]:
         """
         Generate QAs based on a given batch.
         :param batch
         :return: QA pairs
         """
-        result = {}
         prompt = self.build_prompt(batch)
         response = await self.llm_client.generate_answer(prompt)
         qa_pairs = self.parse_response(response)  # generate one or more QA pairs
-        result.update(qa_pairs)
-        return result
+        return qa_pairs
 
     @staticmethod
     def format_generation_results(
         results: list[dict], output_data_format: str
     ) -> list[dict[str, Any]]:
 
         flat_results = []
-        for item in results:
-            for _, qa_data in item.items():
-                question = qa_data.get("question", "")
-                answer = qa_data.get("answer", "")
-                if "options" in qa_data and qa_data["options"]:
-                    options = qa_data["options"]
-                    options_str = "\n".join(
-                        [f"{key}. {options[key]}" for key in sorted(options.keys())]
-                    )
-                    question += f"\nOptions:\n{options_str}"
+        for qa_data in results:
+            question = qa_data.get("question", "")
+            answer = qa_data.get("answer", "")
+            if "options" in qa_data and qa_data["options"]:
+                options = qa_data["options"]
+                options_str = "\n".join(
+                    [f"{key}. {options[key]}" for key in sorted(options.keys())]
+                )
+                question += f"\nOptions:\n{options_str}"
 
-                if output_data_format == "Alpaca":
-                    flat_results.append(
-                        {
-                            "instruction": question,
-                            "input": "",
-                            "output": answer,
-                        }
-                    )
-                elif output_data_format == "Sharegpt":
-                    flat_results.append(
-                        {
-                            "conversations": [
-                                {"from": "human", "value": question},
-                                {"from": "gpt", "value": answer},
-                            ]
-                        }
-                    )
-                elif output_data_format == "ChatML":
-                    flat_results.append(
-                        {
-                            "messages": [
-                                {"role": "user", "content": question},
-                                {"role": "assistant", "content": answer},
-                            ]
-                        }
-                    )
-                else:
-                    raise ValueError(
-                        f"Unknown output data format: {output_data_format}"
-                    )
+            if output_data_format == "Alpaca":
+                flat_results.append(
+                    {
+                        "instruction": question,
+                        "input": "",
+                        "output": answer,
+                    }
+                )
+            elif output_data_format == "Sharegpt":
+                flat_results.append(
+                    {
+                        "conversations": [
+                            {"from": "human", "value": question},
+                            {"from": "gpt", "value": answer},
+                        ]
+                    }
+                )
+            elif output_data_format == "ChatML":
+                flat_results.append(
+                    {
+                        "messages": [
+                            {"role": "user", "content": question},
+                            {"role": "assistant", "content": answer},
+                        ]
+                    }
+                )
+            else:
+                raise ValueError(f"Unknown output data format: {output_data_format}")
         return flat_results
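
Note: a minimal standalone sketch of what the reworked flattening does with the new list[dict] shape, using the ChatML branch only (the sample QA pair is invented):

from typing import Any

def to_chatml(results: list[dict]) -> list[dict[str, Any]]:
    # Mirrors format_generation_results(..., "ChatML") on the flat list shape.
    flat_results = []
    for qa_data in results:
        question = qa_data.get("question", "")
        answer = qa_data.get("answer", "")
        if "options" in qa_data and qa_data["options"]:
            options = qa_data["options"]
            options_str = "\n".join(
                f"{key}. {options[key]}" for key in sorted(options.keys())
            )
            question += f"\nOptions:\n{options_str}"
        flat_results.append(
            {
                "messages": [
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer},
                ]
            }
        )
    return flat_results

print(to_chatml([{"question": "2 + 2 = ?", "options": {"A": "3", "B": "4"}, "answer": "B"}]))
# [{'messages': [{'role': 'user', 'content': '2 + 2 = ?\nOptions:\nA. 3\nB. 4'},
#                {'role': 'assistant', 'content': 'B'}]}]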

graphgen/models/generator/aggregated_generator.py

Lines changed: 7 additions & 11 deletions
@@ -3,7 +3,7 @@
 
 from graphgen.bases import BaseGenerator
 from graphgen.templates import AGGREGATED_GENERATION_PROMPT
-from graphgen.utils import compute_content_hash, detect_main_language, logger
+from graphgen.utils import detect_main_language, logger
 
 
 class AggregatedGenerator(BaseGenerator):
@@ -101,30 +101,26 @@ async def generate(
         batch: tuple[
             list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
         ],
-    ) -> dict[str, Any]:
+    ) -> list[dict]:
         """
         Generate QAs based on a given batch.
         :param batch
         :return: QA pairs
         """
-        result = {}
         rephrasing_prompt = self.build_prompt(batch)
         response = await self.llm_client.generate_answer(rephrasing_prompt)
         context = self.parse_rephrased_text(response)
         if not context:
-            return result
+            return []
         question_generation_prompt = self._build_prompt_for_question_generation(context)
         response = await self.llm_client.generate_answer(question_generation_prompt)
         question = self.parse_response(response)["question"]
         if not question:
-            return result
+            return []
         logger.debug("Question: %s", question)
         logger.debug("Answer: %s", context)
         qa_pairs = {
-            compute_content_hash(question): {
-                "question": question,
-                "answer": context,
-            }
+            "question": question,
+            "answer": context,
         }
-        result.update(qa_pairs)
-        return result
+        return [qa_pairs]
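
Note: the return-shape change, illustrated with invented data. The old generate() keyed each QA pair by compute_content_hash(question); the new one returns a plain list:

old_result = {  # old shape; hash key invented for illustration
    "a3f9c0d1...": {"question": "What is X?", "answer": "Rephrased context about X."},
}
new_result = [  # new shape, consumed directly by format_generation_results
    {"question": "What is X?", "answer": "Rephrased context about X."},
]
for qa in new_result:  # no .items() unwrapping needed downstream
    print(qa["question"], "->", qa["answer"])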

graphgen/models/generator/atomic_generator.py

Lines changed: 4 additions & 9 deletions
@@ -3,7 +3,7 @@
 
 from graphgen.bases import BaseGenerator
 from graphgen.templates import ATOMIC_GENERATION_PROMPT
-from graphgen.utils import compute_content_hash, detect_main_language, logger
+from graphgen.utils import detect_main_language, logger
 
 
 class AtomicGenerator(BaseGenerator):
@@ -23,7 +23,7 @@ def build_prompt(
         return prompt
 
     @staticmethod
-    def parse_response(response: str) -> dict:
+    def parse_response(response: str) -> list[dict]:
         """
         AtomicGenerator normally generates one QA pair per response.
         So we just need to parse one QA pair from the response.
@@ -38,15 +38,10 @@ def parse_response(response: str) -> dict:
             answer = answer_match.group(1).strip()
         else:
             logger.warning("Failed to parse response: %s", response)
-            return {}
+            return []
 
         question = question.strip('"').strip("'")
         answer = answer.strip('"').strip("'")
         logger.debug("Question: %s", question)
         logger.debug("Answer: %s", answer)
-        return {
-            compute_content_hash(question): {
-                "question": question,
-                "answer": answer,
-            }
-        }
+        return [{"question": question, "answer": answer}]
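
Note: a rough standalone sketch of the parse flow around this hunk. The <question>/<answer> tag names are an assumption for the demo, since the hunks don't show the regex definitions:

import re

def parse_one_qa(response: str) -> list[dict]:
    # Tag format assumed for illustration; the real patterns live earlier in the file.
    question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
    answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    if not (question_match and answer_match):
        return []  # mirrors the new empty-list failure value
    question = question_match.group(1).strip().strip('"').strip("'")
    answer = answer_match.group(1).strip().strip('"').strip("'")
    return [{"question": question, "answer": answer}]

print(parse_one_qa("<question>Who wrote SICP?</question><answer>Abelson and Sussman</answer>"))
# [{'question': 'Who wrote SICP?', 'answer': 'Abelson and Sussman'}]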

graphgen/models/generator/cot_generator.py

Lines changed: 5 additions & 8 deletions
@@ -100,28 +100,25 @@ async def generate(
         batch: tuple[
             list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
         ],
-    ) -> dict[str, Any]:
+    ) -> list[dict]:
         """
         Generate QAs based on a given batch.
         :param batch
         :return: QA pairs
         """
-        result = {}
         prompt = self.build_prompt(batch)
         response = await self.llm_client.generate_answer(prompt)
         response = self.parse_response(response)
         if not response:
-            return result
+            return []
         question, reasoning_path = response["question"], response["reasoning_path"]
         prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
         cot_answer = await self.llm_client.generate_answer(prompt)
         logger.debug("CoT Answer: %s", cot_answer)
-        qa_pairs = {
-            compute_content_hash(question): {
+        return [
+            {
                 "question": question,
                 "answer": cot_answer,
                 "reasoning_path": reasoning_path,
             }
-        }
-        result.update(qa_pairs)
-        return result
+        ]
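
Note: the two-stage flow (question + reasoning path first, then the chain-of-thought answer) now returns a single-element list. A toy async driver showing that contract; the stub client and data are invented, not the project's API:

import asyncio

class StubClient:
    # Toy stand-in for llm_client.
    async def generate_answer(self, prompt: str) -> str:
        return f"stub CoT answer for: {prompt[:40]}"

async def demo() -> list[dict]:
    client = StubClient()
    # Stage 1 output, invented here; generate() gets it from parse_response().
    question, reasoning_path = "Why is the sky blue?", "sunlight -> Rayleigh scattering"
    # Stage 2: a second LLM call produces the chain-of-thought answer.
    cot_answer = await client.generate_answer(f"Answer step by step: {question}")
    return [{"question": question, "answer": cot_answer, "reasoning_path": reasoning_path}]

print(asyncio.run(demo()))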

graphgen/models/generator/fill_in_blank_generator.py

Lines changed: 10 additions & 10 deletions
@@ -12,7 +12,7 @@ def __init__(self, llm_client, num_of_questions) -> None:
         self.num_of_questions = num_of_questions
 
     @staticmethod
-    def parse_response(response: str) -> Any:
+    def parse_response(response: str) -> list[dict]:
         """
         Parse fill-in-the-blank QA pairs from the LLM response.
         Each QA pair contains question text with placeholders and the correct answer(s).
@@ -21,14 +21,14 @@ def parse_response(response: str) -> Any:
         :return: Dictionary mapping question hash to question data, where each
             value is a dict with "question", "answer", and "answers" keys
         """
-        qa_pairs = {}
+        qa_pairs = []
 
         # Extract all QA pair blocks
         qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
 
         if not qa_blocks:
             logger.warning("No QA pairs found in response: %s", response)
-            return {}
+            return qa_pairs
 
         for block in qa_blocks:
             # Extract and clean question text
@@ -55,13 +55,13 @@ def parse_response(response: str) -> Any:
                 logger.warning("No valid answers found in: %s", answer_text)
                 continue
 
-            # Build result entry with question hash as key
-            question_hash = compute_content_hash(question)
-            qa_pairs[question_hash] = {
-                "question": question,
-                "answer": answer_text,  # Original answer text with commas
-                "answers": answers,  # List of individual answers: ["A8X"] or ["A8X", "八百万"]
-            }
+            qa_pairs.append(
+                {
+                    "question": question,
+                    "answer": answer_text,  # Original answer text with commas
+                    "answers": answers,  # List of individual answers: ["A8X"] or ["A8X", "八百万"]
+                }
+            )
 
         logger.debug(
             "Successfully parsed fill-in-the-blank question: %s", question[:50]

graphgen/models/generator/multi_answer_generator.py

Lines changed: 13 additions & 11 deletions
@@ -12,7 +12,7 @@ def __init__(self, llm_client, num_of_questions) -> None:
         self.num_of_questions = num_of_questions
 
     @staticmethod
-    def parse_response(response: str) -> Any:
+    def parse_response(response: str) -> list[dict]:
         """
         Parse multiple-answer QA pairs from the LLM response.
         Each QA pair contains question text, four options, and the correct answers (one or more).
@@ -21,14 +21,14 @@ def parse_response(response: str) -> Any:
         :return: Dictionary mapping question hash to question data, where each
             value is a dict with "question", "options", and "answer" keys
         """
-        qa_pairs = {}
+        qa_pairs = []
 
         # Extract all QA pair blocks
        qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
 
         if not qa_blocks:
             logger.warning("No QA pairs found in response: %s", response)
-            return {}
+            return qa_pairs
 
         for block in qa_blocks:
             # Extract and clean question text
@@ -61,7 +61,9 @@ def parse_response(response: str) -> Any:
                 logger.warning("Failed to parse answer from block: %s", block)
                 continue
             answer_text = ans_match.group(1).strip().strip('"').strip("'")
-            answers = [ans.strip().upper() for ans in answer_text.split(",") if ans.strip()]
+            answers = [
+                ans.strip().upper() for ans in answer_text.split(",") if ans.strip()
+            ]
             invalid_answers = [ans for ans in answers if ans not in options]
             if invalid_answers:
                 logger.warning(
@@ -76,13 +78,13 @@ def parse_response(response: str) -> Any:
                 logger.warning("No valid answers found in: %s", answer_text)
                 continue
 
-            # Build result entry with question hash as key
-            question_hash = compute_content_hash(question)
-            qa_pairs[question_hash] = {
-                "question": question,
-                "options": options,  # Dict like {"A": "text", "B": "text", ...}
-                "answer": ", ".join(answers),
-            }
+            qa_pairs.append(
+                {
+                    "question": question,
+                    "options": options,  # Dict like {"A": "text", "B": "text", ...}
+                    "answers": answers,  # List of correct answers: ["A", "C"]
+                }
+            )
 
         logger.debug("Successfully parsed MAQ: %s", question[:50])
 
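
Note: the answer normalization and validation around the reformatted comprehension, runnable on its own (options dict invented); strip-and-uppercase tolerates sloppy model output:

options = {"A": "red", "B": "green", "C": "blue", "D": "cyan"}
answer_text = " a, c ,"
answers = [ans.strip().upper() for ans in answer_text.split(",") if ans.strip()]
print(answers)  # ['A', 'C']
invalid_answers = [ans for ans in answers if ans not in options]
print(invalid_answers)  # [] -> the pair is kept, stored under the new "answers" key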

graphgen/models/generator/multi_choice_generator.py

Lines changed: 10 additions & 10 deletions
@@ -12,7 +12,7 @@ def __init__(self, llm_client, num_of_questions) -> None:
         self.num_of_questions = num_of_questions
 
     @staticmethod
-    def parse_response(response: str) -> Any:
+    def parse_response(response: str) -> list[dict]:
         """
         Parse multiple choice QA pairs from the LLM response.
         Each QA pair contains question text, four options, and the correct answer.
@@ -21,14 +21,14 @@ def parse_response(response: str) -> Any:
         :return: Dictionary mapping question hash to question data, where each
             value is a dict with "question", "options", and "answer" keys
         """
-        qa_pairs = {}
+        qa_pairs = []
 
         # Extract all QA pair blocks
         qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
 
         if not qa_blocks:
             logger.warning("No QA pairs found in response: %s", response)
-            return {}
+            return qa_pairs
 
         for block in qa_blocks:
             # Extract and clean question text
@@ -76,13 +76,13 @@ def parse_response(response: str) -> Any:
                 )
                 continue
 
-            # Build result entry with question hash as key
-            question_hash = compute_content_hash(question)
-            qa_pairs[question_hash] = {
-                "question": question,
-                "options": options,  # Dict like {"A": "text", "B": "text", ...}
-                "answer": answer,  # Single letter: "A", "B", "C", or "D"
-            }
+            qa_pairs.append(
+                {
+                    "question": question,
+                    "options": options,  # Dict like {"A": "text", "B": "text", ...}
+                    "answer": answer,  # Single letter: "A", "B", "C", or "D"
+                }
+            )
 
         logger.debug("Successfully parsed MCQ: %s", question[:50])
 
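
Note: putting the pieces together, an MCQ record in the new shape and the Alpaca-style flattening it receives in format_generation_results (sample data invented):

qa = {
    "question": "Which planet is largest?",
    "options": {"A": "Earth", "B": "Jupiter", "C": "Mars", "D": "Venus"},
    "answer": "B",  # single letter, as parse_response stores it
}
options_str = "\n".join(f"{k}. {qa['options'][k]}" for k in sorted(qa["options"]))
record = {
    "instruction": qa["question"] + "\nOptions:\n" + options_str,
    "input": "",
    "output": qa["answer"],
}
print(record["instruction"])
# Which planet is largest?
# Options:
# A. Earth
# B. Jupiter
# C. Mars
# D. Venus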
