├── .python-version ├── src ├── tobyawesomeailibrary │ ├── __init__.py │ ├── eval_response.py │ ├── inference.py │ └── get_dataset_question.py └── textbooks_to_rl │ ├── __main__.py │ ├── textbook_manager │ ├── __init__.py │ ├── models.py │ └── textbook_manager.py │ ├── question_generator │ ├── __init__.py │ ├── llm.py │ ├── models.py │ ├── parsers.py │ ├── prompt_templates.py │ └── generator.py │ ├── __init__.py │ ├── pdf_to_text.py │ ├── example_usage.py │ ├── generate_question_for_whole_textbook.py │ ├── filter_question.py │ └── cli.py ├── tests ├── __init__.py ├── test_cli.py └── test_models.py ├── textbooks ├── pdfs │ └── .gitkeep └── txt │ └── .gitkeep ├── generated_questions └── .gitkeep ├── main.py ├── .pre-commit-config.yaml ├── LICENSE ├── .gitignore ├── scripts ├── textbook_interface.py ├── filter.py └── process_pdfs.py ├── README.md └── pyproject.toml /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 -------------------------------------------------------------------------------- /src/tobyawesomeailibrary/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for textbooks_to_rl package.""" -------------------------------------------------------------------------------- /textbooks/pdfs/.gitkeep: -------------------------------------------------------------------------------- 1 | # This directory will contain source PDF textbook files 2 | # Add your PDF textbooks here for processing -------------------------------------------------------------------------------- /textbooks/txt/.gitkeep: -------------------------------------------------------------------------------- 1 | # This directory will contain converted textbook txt files 2 | # Add your processed textbook files here -------------------------------------------------------------------------------- /generated_questions/.gitkeep: -------------------------------------------------------------------------------- 1 | # This directory will contain generated question files 2 | # Generated questions are saved as JSON files here -------------------------------------------------------------------------------- /src/textbooks_to_rl/__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Allow running the package as a module with python -m textbooks_to_rl. 
3 | """ 4 | 5 | from .cli import cli_main 6 | 7 | if __name__ == "__main__": 8 | cli_main() 9 | -------------------------------------------------------------------------------- /src/textbooks_to_rl/textbook_manager/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import Textbook, TextbookPage 2 | from .textbook_manager import TextbookManager 3 | 4 | __all__ = ["TextbookManager", "Textbook", "TextbookPage"] 5 | -------------------------------------------------------------------------------- /src/textbooks_to_rl/question_generator/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import QuestionGenerator 2 | from .models import QuestionAnswer, QuestionDomain, ValidationResult 3 | 4 | __all__ = ["QuestionGenerator", "QuestionAnswer", "ValidationResult", "QuestionDomain"] 5 | -------------------------------------------------------------------------------- /src/textbooks_to_rl/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Textbooks to RL - A tool for generating reinforcement learning training questions from textbooks. 3 | """ 4 | 5 | __version__ = "0.1.0" 6 | __author__ = "Toby Simonds" 7 | __email__ = "toby@example.com" 8 | 9 | from .question_generator import QuestionGenerator 10 | from .textbook_manager import TextbookManager 11 | 12 | __all__ = [ 13 | "QuestionGenerator", 14 | "TextbookManager", 15 | "__version__", 16 | ] 17 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Legacy entry point for textbooks-to-rl. 4 | This file is deprecated. Use `textbooks-to-rl` command or `python -m textbooks_to_rl.cli` instead. 5 | """ 6 | 7 | import warnings 8 | import sys 9 | 10 | warnings.warn( 11 | "main.py is deprecated. 
10 |     "main.py is deprecated. Use 'textbooks-to-rl' command or 'python -m textbooks_to_rl' instead.",
11 |     DeprecationWarning,
12 |     stacklevel=2,
13 | )
14 |
15 | # Import and run the new CLI
16 | from src.textbooks_to_rl.cli import cli_main
17 |
18 | if __name__ == "__main__":
19 |     cli_main()
20 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/pre-commit/pre-commit-hooks
3 |     rev: v5.0.0
4 |     hooks:
5 |       - id: trailing-whitespace
6 |       - id: end-of-file-fixer
7 |       - id: check-yaml
8 |       - id: check-added-large-files
9 |       - id: check-merge-conflict
10 |       - id: debug-statements
11 |
12 |   - repo: https://github.com/astral-sh/ruff-pre-commit
13 |     rev: v0.9.10
14 |     hooks:
15 |       - id: ruff
16 |         args: [--fix, --exit-non-zero-on-fix]
17 |       - id: ruff-format
18 |
19 |   - repo: https://github.com/pre-commit/mirrors-mypy
20 |     rev: v1.13.0
21 |     hooks:
22 |       - id: mypy
23 |         additional_dependencies:
24 |           - types-requests
25 |           - types-tqdm
26 |         args: [--strict, --ignore-missing-imports]
--------------------------------------------------------------------------------
/src/textbooks_to_rl/textbook_manager/models.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 | from typing import List, Optional
4 |
5 |
6 | @dataclass
7 | class TextbookPage:
8 |     """Represents a single page from a textbook."""
9 |
10 |     content: str
11 |     page_number: int
12 |     textbook_name: str
13 |     chapter: Optional[str] = None
14 |     section: Optional[str] = None
15 |
16 |
17 | @dataclass
18 | class Textbook:
19 |     """Represents a textbook with its content."""
20 |
21 |     name: str
22 |     path: Path
23 |     txt_path: Path
24 |     pages: List[TextbookPage]
25 |     total_pages: int
26 |
27 |     @property
28 |     def is_parsed(self) -> bool:
29 |         """Check if textbook has already been parsed to txt."""
30 |         return self.txt_path.exists()
--------------------------------------------------------------------------------
/src/textbooks_to_rl/pdf_to_text.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import PyPDF2
4 |
5 |
6 | def convert_pdf_to_text(pdf_path: Path) -> str:
7 |     """Convert a PDF file to text, preserving page markers."""
8 |     text = []
9 |
10 |     try:
11 |         with open(pdf_path, "rb") as file:
12 |             # Create PDF reader object
13 |             pdf_reader = PyPDF2.PdfReader(file)
14 |
15 |             # Get number of pages
16 |             num_pages = len(pdf_reader.pages)
17 |
18 |             # Extract text from each page
19 |             for page_num in range(num_pages):
20 |                 # Add page marker so page boundaries survive in the text dump
21 |                 text.append(f"\n<<PAGE {page_num + 1}>>\n")
22 |
23 |                 # Get page
24 |                 page = pdf_reader.pages[page_num]
25 |
26 |                 # Extract text from page
27 |                 text.append(page.extract_text())
28 |
29 |         return "\n".join(text)
30 |
31 |     except Exception as e:
32 |         raise Exception(f"Error converting PDF to text: {str(e)}") from e
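For orientation, a minimal sketch of how `convert_pdf_to_text` might be called (the file paths are hypothetical):

```python
from pathlib import Path

from textbooks_to_rl.pdf_to_text import convert_pdf_to_text

# Hypothetical paths; any PDF under textbooks/pdfs/ works the same way.
pdf_path = Path("textbooks/pdfs/calculus.pdf")
text = convert_pdf_to_text(pdf_path)

# Each page is preceded by a "<<PAGE n>>" marker, so downstream tools can
# recover page boundaries from the plain-text dump.
Path("textbooks/txt/calculus.txt").write_text(text, encoding="utf-8")
```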
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Toby Simonds
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/src/textbooks_to_rl/question_generator/llm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | from src.tobyawesomeailibrary.eval_response import evaluate_text
4 | from src.tobyawesomeailibrary.inference import generate_text
5 |
6 |
7 | class LLMInterface:
8 |     """Interface for interacting with the language model."""
9 |
10 |     def __init__(self, model_name: str = "gpt-4"):
11 |         self.model_name = model_name
12 |         self.max_completion_tokens = 1000  # Limit completion length
13 |
14 |     async def generate(self, prompt: str) -> str:
15 |         """Generate text from a prompt."""
16 |         return await generate_text(
17 |             model=self.model_name, prompt=prompt, max_tokens=self.max_completion_tokens
18 |         )
19 |
20 |     async def evaluate(
21 |         self, student_solution: str, correct_solution: str
22 |     ) -> Tuple[bool, Optional[int]]:
23 |         """Evaluate whether a solution is correct.
24 |
25 |         Returns:
26 |             Tuple[bool, Optional[int]]: (is_correct, bad_response_flag)
27 |             - is_correct: True if correct, False if incorrect
28 |             - bad_response_flag: 0 normally, 1 if the judge output was unparseable
29 |         """
30 |         result = await evaluate_text(
31 |             self.model_name, student_solution, correct_solution
32 |         )
33 |
34 |         # Convert to binary outcome
35 |         is_correct = bool(int(result[0]))  # 1 -> True, 0 -> False
36 |         feedback = result[1] if len(result) > 1 else None
37 |
38 |         return is_correct, feedback
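To make the contract between the two layers explicit, here is a hedged sketch of a call through `LLMInterface.evaluate`, which wraps the `evaluate_text` judge defined in the next file (the model name and answers are illustrative):

```python
import asyncio

from textbooks_to_rl.question_generator.llm import LLMInterface


async def demo() -> None:
    llm = LLMInterface(model_name="gpt-4o-mini")  # hypothetical model choice
    # evaluate_text() returns (1, 0) for ACCURATE, (0, 0) for INACCURATE and
    # (0, 1) when the judge's output cannot be parsed; evaluate() maps the
    # first element to a bool and passes the second through.
    is_correct, bad_response = await llm.evaluate("x = 4", "4")
    print(is_correct, bad_response)


asyncio.run(demo())
```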
--------------------------------------------------------------------------------
/src/tobyawesomeailibrary/eval_response.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from .inference import generate_text
4 |
5 |
6 | async def evaluate_text(
7 |     eval_model: str, modelAnswer: str, groundTruthAnswer: str, temperature: float = 0
8 | ) -> Tuple[int, int]:
9 |     """Judge modelAnswer against groundTruthAnswer.
10 |
11 |     Returns (accuracy, bad_response): (1, 0) for accurate, (0, 0) for
12 |     inaccurate, (0, 1) when the judge output could not be parsed.
13 |     """
14 |     prompt = f"""
15 |     You will be given a ground truth answer and a model answer.
16 |     Please output ACCURATE if the model answer matches the ground truth answer or INACCURATE otherwise. Please only return ACCURATE or INACCURATE.
17 |     It is very important for my job that you do this.
18 |     Be flexible with different formats of the model answer e.g. decimal, fraction, integer, etc.
19 |     If the answer is within rounding error of the ground truth answer, return ACCURATE.
20 |
21 |     <ground_truth_answer>
22 |     {groundTruthAnswer}
23 |     </ground_truth_answer>
24 |
25 |     <model_answer>
26 |     {modelAnswer}
27 |     </model_answer>
28 |     """
29 |
30 |     print("Model Answer length:", len(str(modelAnswer)))
31 |     print("Ground Truth Answer length:", len(str(groundTruthAnswer)))
32 |
33 |     isAccurate = await generate_text(
34 |         model=eval_model,
35 |         prompt=prompt,
36 |         max_tokens=10,  # enough for a full "ACCURATE" or "INACCURATE" response
37 |         temperature=temperature,
38 |     )
39 |
40 |     # Check INACCURATE first: "ACCURATE" is a substring of "INACCURATE", so
41 |     # testing for "ACCURATE" first would mark every negative verdict as correct.
42 |     response = isAccurate.strip().upper()
43 |     if "INACCURATE" in response:
44 |         return 0, 0
45 |     elif "ACCURATE" in response:
46 |         return 1, 0
47 |     else:
48 |         return 0, 1  # Flag unparseable judge responses
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | share/python-wheels/
20 | *.egg-info/
21 | .installed.cfg
22 | *.egg
23 | MANIFEST
24 |
25 | # PyInstaller
26 | *.manifest
27 | *.spec
28 |
29 | # Installer logs
30 | pip-log.txt
31 | pip-delete-this-directory.txt
32 |
33 | # Unit test / coverage reports
34 | htmlcov/
35 | .tox/
36 | .nox/
37 | .coverage
38 | .coverage.*
39 | .cache
40 | nosetests.xml
41 | coverage.xml
42 | *.cover
43 | *.py,cover
44 | .hypothesis/
45 | .pytest_cache/
46 | cover/
47 |
48 | # Virtual environments
49 | .env
50 | .venv
51 | env/
52 | venv/
53 | ENV/
54 | env.bak/
55 | venv.bak/
56 | .venv/
57 |
58 | # UV
59 | .uv/
60 |
61 | # IDEs
62 | .vscode/
63 | .idea/
64 | *.swp
65 | *.swo
66 | *~
67 |
68 | # OS
69 | .DS_Store
70 | .DS_Store?
71 | ._*
72 | .Spotlight-V100
73 | .Trashes
74 | ehthumbs.db
75 | Thumbs.db
76 |
77 | # Project specific - API keys and secrets
78 | set_api_keys.sh
79 | *.key
80 | *.secret
81 | .env.local
82 | .env.development
83 | .env.test
84 | .env.production
85 |
86 | # Generated content (keep structure but ignore files)
87 | generated_questions/*.json
88 | generated_questions/old/
89 | textbooks/txt/*.txt
90 | textbooks/pdfs/*.pdf
91 | !textbooks/txt/.gitkeep
92 | !textbooks/pdfs/.gitkeep
93 | filtered_results/
94 | lib/datasets/
95 |
96 | # Logs
97 | *.log
98 | logs/
99 |
100 | # Temporary files
101 | *.tmp
102 | *.temp
103 | .tmp/
104 | temp/
105 |
106 |
107 | .ruff_cache
108 | .github
109 | .claude
--------------------------------------------------------------------------------
/src/textbooks_to_rl/question_generator/models.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 | from typing import List, Optional
4 |
5 |
6 | class QuestionDomain(Enum):
7 |     CALCULUS = "calculus"
8 |     LINEAR_ALGEBRA = "linear_algebra"
9 |     PROBABILITY = "probability"
10 |     STATISTICS = "statistics"
11 |     PHYSICS_MECHANICS = "physics_mechanics"
12 |     PHYSICS_ELECTRICITY = "physics_electricity"
13 |     PHYSICS_THERMODYNAMICS = "physics_thermodynamics"
14 |     CHEMISTRY = "chemistry"
15 |     BIOLOGY = "biology"
16 |     COMPUTER_SCIENCE = "computer_science"
17 |     OTHER = "other"
18 |
19 |
20 | class QuestionDifficulty(Enum):
21 |     HIGH_SCHOOL = "high_school"
22 |     UNDERGRAD = "undergrad"
23 |     GRAD = "grad"
24 |     PHD = "phd"
25 |     EXPERT = "expert"
26 |
27 |     def get_description(self) -> str:
28 |         descriptions = {
29 |             "high_school": "Basic concepts suitable for high school students",
30 |             "undergrad": "College undergraduate level complexity",
31 |             "grad": "Graduate school level depth and complexity",
32 |             "phd": "Advanced theoretical concepts and research-level problems",
33 |             "expert": "Industry expert level requiring deep domain knowledge",
34 |         }
35 |         return descriptions[self.value]
36 |
37 |
38 | @dataclass
39 | class QuestionAnswer:
40 |     """Data class representing a question-answer pair with optional hints and source."""
41 |
42 |     question: str
43 |     solution: str
44 |     source: Optional[str] = None
45 |     hints: Optional[List[str]] = None
46 |     domain: Optional[QuestionDomain] = None
47 |
48 |
49 | @dataclass
50 | class ValidationResult:
51 |     """Data class representing the result of a solution validation."""
52 |
53 |     is_correct: bool
54 |     score: float
55 |     feedback: Optional[str] = None
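A quick, illustrative sketch of these data models in use (all values are made up):

```python
from textbooks_to_rl.question_generator.models import (
    QuestionAnswer,
    QuestionDifficulty,
    QuestionDomain,
)

# A hypothetical generated pair; `source`, `hints`, and `domain` are optional.
qa = QuestionAnswer(
    question="What is the limit of 1/n as n approaches infinity?",
    solution="0",
    source="page 12",
    domain=QuestionDomain.CALCULUS,
)

# Each difficulty level carries a human-readable description.
print(QuestionDifficulty.UNDERGRAD.get_description())
print(qa.question, "->", qa.solution)
```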
--------------------------------------------------------------------------------
/src/textbooks_to_rl/example_usage.py:
--------------------------------------------------------------------------------
1 | from textbooks_to_rl.question_generator import QuestionGenerator
2 |
3 |
4 | async def main():
5 |     # Example passage
6 |     passage = """
7 |     In calculus, a sequence is a list of numbers in a definite order.
8 |     The limit of a sequence {an} is the value that the terms of the sequence approach as n approaches infinity.
9 |     For example, the sequence an = 1/n approaches 0 as n approaches infinity.
10 |     """
11 |
12 |     # Create question generator
13 |     generator = QuestionGenerator()
14 |
15 |     # Generate questions with verification
16 |     print("Generating questions with verification...")
17 |     verified_questions = await generator.generate_questions(
18 |         passage, num_questions=3, verify=True, verification_threshold=0.8
19 |     )
20 |
21 |     print(f"\nGenerated {len(verified_questions)} verified questions:")
22 |     for i, qa in enumerate(verified_questions, 1):
23 |         print(f"\nQuestion {i}:")
24 |         print(qa.question)
25 |         print("\nSolution:")
26 |         print(qa.solution)
27 |         print("\nHints:")
28 |         for j, hint in enumerate(qa.hints or [], 1):
29 |             print(f"{j}. {hint}")
30 |         print("-" * 80)
31 |
32 |     # Generate questions without verification
33 |     print("\nGenerating questions without verification...")
34 |     unverified_questions = await generator.generate_questions(
35 |         passage, num_questions=3, verify=False
36 |     )
37 |
38 |     # Example of validating a student solution
39 |     student_solution = """
40 |     The limit of 1/n as n approaches infinity is 0 because as n gets larger,
41 |     1/n becomes increasingly smaller, approaching but never reaching 0.
42 |     """
43 |
44 |     validation_result = await generator.validate_solution(
45 |         verified_questions[0].question, student_solution, verified_questions[0].solution
46 |     )
47 |
48 |     print(
49 |         f"\nStudent solution is {'correct' if validation_result.is_correct else 'incorrect'}"
50 |     )
51 |     print(f"Score: {validation_result.score}")
52 |
53 |
54 | if __name__ == "__main__":
55 |     import asyncio
56 |
57 |     asyncio.run(main())
--------------------------------------------------------------------------------
/scripts/textbook_interface.py:
--------------------------------------------------------------------------------
1 | from textbooks_to_rl.textbook_manager import TextbookManager
2 | from typing import Optional, List
3 |
4 | class TextbookInterface:
5 |     """Interface for interacting with processed textbooks."""
6 |
7 |     def __init__(self, parsed_dir: str = "parsed_textbooks"):
8 |         self.manager = TextbookManager(textbooks_dir=parsed_dir)
9 |
10 |     def list_available_textbooks(self) -> List[str]:
11 |         """Get list of available textbooks."""
12 |         return self.manager.get_all_textbook_names()
13 |
14 |     def get_page(self, textbook_name: str, page_number: int) -> Optional[str]:
15 |         """Get content of a specific page."""
16 |         page = self.manager.get_page(textbook_name, page_number)
17 |         return page.content if page else None
18 |
19 |     def get_random_page(self, textbook_name: str) -> Optional[str]:
20 |         """Get content of a random page."""
21 |         page = self.manager.get_random_page(textbook_name)
22 |         return page.content if page else None
23 |
24 |     def get_page_range(self, textbook_name: str, start: int, end: int) -> List[str]:
25 |         """Get content from a range of pages."""
26 |         pages = self.manager.get_page_range(textbook_name, start, end)
27 |         return [page.content for page in pages]
28 |
29 |     def search_textbooks(self, query: str) -> List[tuple[str, int, str]]:
30 |         """Search all textbooks for content."""
31 |         results = self.manager.search_content(query)
32 |         return [(page.textbook_name, page.page_number, page.content) for page in results]
33 |
34 | # Example usage
35 | if __name__ == "__main__":
36 |     interface = TextbookInterface()
37 |
38 |     # List available textbooks
39 |     print("Available textbooks:")
40 |     for book in interface.list_available_textbooks():
41 |         print(f"- {book}")
42 |
43 |     # Example: Get a specific page
44 |     textbook = interface.list_available_textbooks()[0]  # First textbook
45 |     content = interface.get_page(textbook, 1)
46 |     if content:
47 |         print(f"\nFirst page of {textbook}:")
48 |         print(content)
49 |
50 |     # Example: Search
51 |     results = interface.search_textbooks("calculus")
52 |     print("\nSearch results for 'calculus':")
53 |     for book, page, content in results:
54 |         print(f"\nFound in {book}, page {page}:")
55 |         print(content[:200] + "...")  # First 200 chars
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TextbooksToRL
2 |
3 | A tool that automatically processes textbooks to generate structured question-answer datasets for machine learning and reinforcement learning training.
4 |
5 | ## How It Works
6 |
7 | 1. **Extract Text**: Converts PDF textbooks to text format
8 | 2. **Process Pages**: Splits content into manageable chunks (3-5 pages)
9 | 3. **Generate Questions**: Uses AI models to create questions from each chunk
10 | 4. **Verify Solutions**: Checks each answer by feeding the source page back in with the question and confirming the model reproduces the same answer
11 | 5. 
**Output Dataset**: Saves structured JSON files with questions, solutions, and metadata 12 | 13 | The tool supports multiple difficulty levels (high school → PhD) and various AI models (OpenAI, Anthropic, DeepSeek, etc.). 14 | 15 | ## Setup 16 | 17 | 1. Install dependencies: 18 | ```bash 19 | uv sync 20 | ``` 21 | 22 | 2. Set up API keys: 23 | ```bash 24 | export OPENAI_API_KEY='your-key' 25 | export ANTHROPIC_API_KEY='your-key' 26 | export DEEPSEEK_API_KEY='your-key' 27 | export DEEPINFRA_API_KEY='your-key' 28 | ``` 29 | 30 | ## Usage 31 | 32 | ### Basic Usage 33 | ```bash 34 | # Generate questions from textbooks 35 | uv run textbooks-to-rl --help 36 | 37 | # Run with custom settings 38 | uv run textbooks-to-rl \ 39 | --model "Qwen/QwQ-32B" \ 40 | --output-dir "my_questions" \ 41 | --difficulty undergrad \ 42 | --verbose 43 | ``` 44 | 45 | ### Step 1: Add Textbooks 46 | Place PDF files in `textbooks/pdfs/` or process them: 47 | ```bash 48 | uv run python scripts/process_pdfs.py 49 | ``` 50 | 51 | ### Step 2: Generate Questions 52 | ```bash 53 | # Basic generation 54 | uv run textbooks-to-rl 55 | 56 | # With custom options 57 | uv run textbooks-to-rl \ 58 | --pages-per-group 5 \ 59 | --batch-size 50 \ 60 | --questions-per-chunk 8 \ 61 | --difficulty grad \ 62 | --no-verify 63 | ``` 64 | 65 | ### Step 3: Filter Questions (Optional) 66 | ```bash 67 | uv run python scripts/filter.py \ 68 | --folders generated_questions \ 69 | --output-dir filtered_results \ 70 | --model gpt-4o-mini 71 | ``` 72 | 73 | ## Options 74 | 75 | | Option | Description | Default | 76 | |--------|-------------|---------| 77 | | `--model` | AI model for generation | Qwen/QwQ-32B | 78 | | `--output-dir` | Output directory | generated_questions | 79 | | `--textbooks-dir` | Textbooks directory | textbooks/txt | 80 | | `--pages-per-group` | Pages per processing group | 3 | 81 | | `--batch-size` | Parallel batch size | 100 | 82 | | `--questions-per-chunk` | Questions per chunk | 10 | 83 | | `--difficulty` | Question difficulty level | undergrad | 84 | | `--no-verify` | Skip solution verification | False | 85 | | `--verbose` | Enable debug logging | False | 86 | 87 | ## Development 88 | 89 | ```bash 90 | # Install with dev dependencies 91 | uv sync --dev 92 | 93 | # Run code quality checks 94 | uv run ruff check src/ --fix 95 | uv run pytest 96 | ``` -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | """Test the CLI interface.""" 2 | 3 | import pytest 4 | from unittest.mock import patch, MagicMock 5 | import argparse 6 | 7 | from textbooks_to_rl.cli import create_parser, setup_logging 8 | from textbooks_to_rl.question_generator.models import QuestionDifficulty 9 | 10 | 11 | def test_create_parser(): 12 | """Test parser creation.""" 13 | parser = create_parser() 14 | 15 | assert isinstance(parser, argparse.ArgumentParser) 16 | assert parser.prog == "textbooks-to-rl" 17 | 18 | # Test default values 19 | args = parser.parse_args([]) 20 | assert args.model == "Qwen/QwQ-32B" 21 | assert args.pages_per_group == 3 22 | assert args.batch_size == 100 23 | assert args.questions_per_chunk == 10 24 | assert args.difficulty == QuestionDifficulty.UNDERGRAD.value 25 | assert args.no_verify is False 26 | assert args.verbose is False 27 | 28 | 29 | def test_parser_arguments(): 30 | """Test parser with arguments.""" 31 | parser = create_parser() 32 | 33 | args = parser.parse_args([ 34 | "--model", 
"gpt-4o-mini", 35 | "--output-dir", "custom_output", 36 | "--pages-per-group", "5", 37 | "--batch-size", "50", 38 | "--difficulty", "grad", 39 | "--no-verify", 40 | "--verbose" 41 | ]) 42 | 43 | assert args.model == "gpt-4o-mini" 44 | assert str(args.output_dir) == "custom_output" 45 | assert args.pages_per_group == 5 46 | assert args.batch_size == 50 47 | assert args.difficulty == "grad" 48 | assert args.no_verify is True 49 | assert args.verbose is True 50 | 51 | 52 | def test_setup_logging(): 53 | """Test logging setup.""" 54 | with patch('textbooks_to_rl.cli.logging.basicConfig') as mock_basic_config: 55 | setup_logging(verbose=False) 56 | mock_basic_config.assert_called_once() 57 | 58 | # Check that INFO level was set (not DEBUG) 59 | call_args = mock_basic_config.call_args[1] 60 | assert call_args['level'] == 20 # INFO level 61 | 62 | with patch('textbooks_to_rl.cli.logging.basicConfig') as mock_basic_config: 63 | setup_logging(verbose=True) 64 | mock_basic_config.assert_called_once() 65 | 66 | # Check that DEBUG level was set 67 | call_args = mock_basic_config.call_args[1] 68 | assert call_args['level'] == 10 # DEBUG level 69 | 70 | 71 | def test_difficulty_choices(): 72 | """Test that all difficulty choices are valid.""" 73 | parser = create_parser() 74 | 75 | # Test valid difficulty levels 76 | for difficulty in QuestionDifficulty: 77 | args = parser.parse_args(["--difficulty", difficulty.value]) 78 | assert args.difficulty == difficulty.value 79 | 80 | # Test invalid difficulty level 81 | with pytest.raises(SystemExit): 82 | parser.parse_args(["--difficulty", "invalid_difficulty"]) -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | """Test the data models.""" 2 | 3 | import pytest 4 | from pathlib import Path 5 | 6 | from textbooks_to_rl.textbook_manager.models import TextbookPage, Textbook 7 | from textbooks_to_rl.question_generator.models import ( 8 | QuestionDifficulty, 9 | QuestionDomain, 10 | QuestionAnswer, 11 | ValidationResult, 12 | ) 13 | 14 | 15 | def test_textbook_page(): 16 | """Test TextbookPage model.""" 17 | page = TextbookPage( 18 | content="Sample content", 19 | page_number=1, 20 | textbook_name="test_book", 21 | chapter="Chapter 1", 22 | section="Section 1.1" 23 | ) 24 | 25 | assert page.content == "Sample content" 26 | assert page.page_number == 1 27 | assert page.textbook_name == "test_book" 28 | assert page.chapter == "Chapter 1" 29 | assert page.section == "Section 1.1" 30 | 31 | 32 | def test_textbook(): 33 | """Test Textbook model.""" 34 | test_path = Path("/test/path") 35 | textbook = Textbook( 36 | name="test_book", 37 | path=test_path, 38 | txt_path=test_path / "test.txt", 39 | pages=[], 40 | total_pages=0 41 | ) 42 | 43 | assert textbook.name == "test_book" 44 | assert textbook.path == test_path 45 | assert textbook.total_pages == 0 46 | assert len(textbook.pages) == 0 47 | 48 | 49 | def test_question_difficulty(): 50 | """Test QuestionDifficulty enum.""" 51 | assert QuestionDifficulty.HIGH_SCHOOL.value == "high_school" 52 | assert QuestionDifficulty.UNDERGRAD.value == "undergrad" 53 | assert QuestionDifficulty.GRAD.value == "grad" 54 | 55 | # Test description method 56 | desc = QuestionDifficulty.UNDERGRAD.get_description() 57 | assert "undergraduate" in desc.lower() 58 | 59 | 60 | def test_question_domain(): 61 | """Test QuestionDomain enum.""" 62 | assert QuestionDomain.CALCULUS.value == "calculus" 63 | assert 
QuestionDomain.LINEAR_ALGEBRA.value == "linear_algebra"
64 |     assert QuestionDomain.OTHER.value == "other"
65 |
66 |
67 | def test_question_answer():
68 |     """Test QuestionAnswer model."""
69 |     qa = QuestionAnswer(
70 |         question="What is 2+2?",
71 |         solution="4",
72 |         source="test_source",
73 |         hints=["Think about basic addition"],
74 |         domain=QuestionDomain.OTHER
75 |     )
76 |
77 |     assert qa.question == "What is 2+2?"
78 |     assert qa.solution == "4"
79 |     assert qa.source == "test_source"
80 |     assert qa.hints == ["Think about basic addition"]
81 |     assert qa.domain == QuestionDomain.OTHER
82 |
83 |
84 | def test_validation_result():
85 |     """Test ValidationResult model."""
86 |     result = ValidationResult(
87 |         is_correct=True,
88 |         score=0.95,
89 |         feedback="Correct answer!"
90 |     )
91 |
92 |     assert result.is_correct is True
93 |     assert result.score == 0.95
94 |     assert result.feedback == "Correct answer!"
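For context before the parser below: a hedged sketch of the tagged format that `ResponseParser.extract_qa_pairs` in src/textbooks_to_rl/question_generator/parsers.py is written against (the sample text is made up):

```python
from textbooks_to_rl.question_generator.parsers import ResponseParser

sample = """
<source>Worked example 3.2</source>
<question>Compute the limit of 1/n as n approaches infinity.</question>
<solution>0</solution>
"""

pairs = ResponseParser.extract_qa_pairs(sample)
# Each pair is a QuestionAnswer with question/solution/source populated;
# the <source> block is optional in the pattern.
print(pairs[0].question, "->", pairs[0].solution, "from", pairs[0].source)
```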
"""Extract hints from formatted text.""" 74 | # Remove the tags 75 | text = text.replace("", "").replace("", "") 76 | 77 | # Split by numbered items and clean up 78 | hints = [hint.strip() for hint in re.split(r"\d+\.", text) if hint.strip()] 79 | return hints 80 | -------------------------------------------------------------------------------- /src/textbooks_to_rl/generate_question_for_whole_textbook.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from enum import Enum 3 | from typing import List, Optional 4 | 5 | 6 | async def generate_questions_for_textbook( 7 | textbook_name: str, 8 | textbook_manager, 9 | generator, 10 | questions_per_page: int = 5, 11 | batch_size: int = 50, 12 | difficulty: Optional[Enum] = None, 13 | ) -> List: 14 | """ 15 | Generate questions for an entire textbook using batched async processing. 16 | 17 | Args: 18 | textbook_name: Name of the textbook to process 19 | textbook_manager: TextbookManager instance to handle page retrieval 20 | generator: QuestionGenerator instance 21 | questions_per_page: Number of questions to generate per page 22 | batch_size: Number of pages to process in each batch 23 | difficulty: Difficulty level for questions (QuestionDifficulty enum) 24 | 25 | Returns: 26 | List of generated questions 27 | """ 28 | 29 | async def process_page(page_num: int) -> List: 30 | passage = textbook_manager.get_page(textbook_name, page_number=page_num) 31 | if passage is None: 32 | print(f"Warning: Could not load page {page_num}") 33 | return [] 34 | 35 | print(f"Processing page {page_num}...") 36 | try: 37 | questions = await generator.generate_questions( 38 | passage.content, 39 | num_questions=questions_per_page, 40 | difficulty=difficulty, 41 | verify=False, 42 | src=passage.page_number, 43 | ) 44 | print(f"✓ Completed page {page_num} - generated {len(questions)} questions") 45 | return questions 46 | except Exception as e: 47 | print(f"Error processing page {page_num}: {str(e)}") 48 | return [] 49 | 50 | num_pages = textbook_manager.get_num_pages(textbook_name) 51 | print(f"Number of pages in textbook: {num_pages}") 52 | 53 | all_questions = [] 54 | 55 | # Process pages in batches 56 | for batch_start in range(0, num_pages, batch_size): 57 | batch_end = min(batch_start + batch_size, num_pages) 58 | print(f"\nProcessing batch of pages {batch_start} to {batch_end - 1}...") 59 | 60 | batch_results = await asyncio.gather( 61 | *[process_page(i) for i in range(batch_start, batch_end)] 62 | ) 63 | 64 | # Add batch results to all_questions 65 | for page_questions in batch_results: 66 | all_questions.extend(page_questions) 67 | 68 | print(f"Batch complete. Total questions so far: {len(all_questions)}") 69 | 70 | print( 71 | f"\nFinished processing all pages. Total questions generated: {len(all_questions)}" 72 | ) 73 | return all_questions 74 | 75 | 76 | # Example usage: 77 | """ 78 | # Import necessary dependencies 79 | from your_question_generator import QuestionGenerator, QuestionDifficulty 80 | from your_textbook_manager import TextbookManager 81 | 82 | async def main(): 83 | textbook_manager = TextbookManager() 84 | generator = QuestionGenerator(model_name="gpt-4-mini") 85 | 86 | questions = await generate_questions_for_textbook( 87 | textbook_name="your_textbook", 88 | textbook_manager=textbook_manager, 89 | generator=generator, 90 | questions_per_page=5, 91 | difficulty=QuestionDifficulty.UNDERGRAD 92 | ) 93 | 94 | # Do something with the questions... 
95 | 96 | if __name__ == "__main__": 97 | asyncio.run(main()) 98 | """ 99 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "textbooks-to-rl" 7 | version = "0.1.0" 8 | description = "A tool for generating reinforcement learning training questions from textbooks" 9 | authors = [ 10 | {name = "Toby Simonds", email = "toby@example.com"}, 11 | ] 12 | readme = "README.md" 13 | license = {text = "MIT"} 14 | requires-python = ">=3.9" 15 | classifiers = [ 16 | "Development Status :: 3 - Alpha", 17 | "Intended Audience :: Developers", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | ] 25 | keywords = ["machine learning", "textbook", "question generation", "reinforcement learning"] 26 | 27 | dependencies = [ 28 | "openai>=1.0.0", 29 | "anthropic>=0.7.0", 30 | "aiohttp>=3.8.0", 31 | "requests>=2.25.0", 32 | "tqdm>=4.65.0", 33 | "PyPDF2>=3.0.0", 34 | "pyarrow>=14.0.0", 35 | "datasets>=2.14.0", 36 | "sympy>=1.12.0", 37 | "pymupdf>=1.22.0", 38 | "typing-extensions>=4.7.0", 39 | "numpy>=1.24.0", 40 | "pandas>=2.0.0", 41 | ] 42 | 43 | [project.optional-dependencies] 44 | dev = [ 45 | "pytest>=7.0.0", 46 | "pytest-asyncio>=0.21.0", 47 | "pytest-cov>=4.0.0", 48 | "mypy>=1.7.0", 49 | "ruff>=0.1.0", 50 | "pre-commit>=3.0.0", 51 | "twine>=4.0.0", 52 | ] 53 | 54 | [project.urls] 55 | Homepage = "https://github.com/tobysimonds/textbooks-to-rl" 56 | Repository = "https://github.com/tobysimonds/textbooks-to-rl" 57 | 58 | [project.scripts] 59 | textbooks-to-rl = "textbooks_to_rl.cli:cli_main" 60 | 61 | [tool.hatch.build.targets.wheel] 62 | packages = ["src/textbooks_to_rl"] 63 | 64 | [tool.ruff] 65 | target-version = "py39" 66 | line-length = 88 67 | 68 | [tool.ruff.lint] 69 | select = [ 70 | "E", # pycodestyle errors 71 | "W", # pycodestyle warnings 72 | "F", # pyflakes 73 | "I", # isort 74 | "B", # flake8-bugbear 75 | "C4", # flake8-comprehensions 76 | "UP", # pyupgrade 77 | "SIM", # flake8-simplify 78 | ] 79 | ignore = [ 80 | "E501", # line too long (handled by black) 81 | "B008", # do not perform function calls in argument defaults 82 | "B905", # zip without an explicit strict parameter 83 | ] 84 | 85 | [tool.ruff.lint.per-file-ignores] 86 | "__init__.py" = ["F401"] 87 | 88 | [tool.mypy] 89 | python_version = "3.9" 90 | warn_return_any = true 91 | warn_unused_configs = true 92 | disallow_untyped_defs = true 93 | disallow_incomplete_defs = true 94 | check_untyped_defs = true 95 | disallow_untyped_decorators = true 96 | no_implicit_optional = true 97 | warn_redundant_casts = true 98 | warn_unused_ignores = true 99 | warn_no_return = true 100 | warn_unreachable = true 101 | strict_equality = true 102 | 103 | [[tool.mypy.overrides]] 104 | module = [ 105 | "PyPDF2.*", 106 | "fitz.*", 107 | "sympy.*", 108 | "datasets.*", 109 | "openai.*", 110 | "anthropic.*", 111 | ] 112 | ignore_missing_imports = true 113 | 114 | [tool.black] 115 | line-length = 88 116 | target-version = ['py39'] 117 | include = '\.pyi?$' 118 | 119 | [tool.pytest.ini_options] 120 | testpaths = ["tests"] 121 | python_files = ["test_*.py", "*_test.py"] 122 | python_classes = 
["Test*"] 123 | python_functions = ["test_*"] 124 | addopts = [ 125 | "--strict-markers", 126 | "--strict-config", 127 | "--verbose", 128 | ] 129 | markers = [ 130 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 131 | "unit: marks tests as unit tests", 132 | "integration: marks tests as integration tests", 133 | ] 134 | 135 | [dependency-groups] 136 | dev = [ 137 | "black>=25.1.0", 138 | "mypy>=1.17.1", 139 | "pytest>=8.4.1", 140 | "pytest-asyncio>=1.1.0", 141 | "ruff>=0.12.10", 142 | ] 143 | -------------------------------------------------------------------------------- /src/textbooks_to_rl/filter_question.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | from typing import Dict, Optional, Tuple, Union 4 | 5 | from src.tobyawesomeailibrary.inference import generate_text 6 | 7 | 8 | class QuestionFilter: 9 | """ 10 | Filter unsolvable questions using LLM judgment. 11 | """ 12 | 13 | def __init__(self, model="gpt-4o-mini"): 14 | self.model = model 15 | self.prompt_template = """ 16 | Evaluate if the following math problem is solvable given the information provided. 17 | Question: {question} 18 | 19 | You SHOULD NOT try and solve. Just filter out respones that reference external information not provided in question e.g in figure but not provided in question. 20 | 21 | 22 | Your task: Determine if this question is solvable with ONLY the information provided. 23 | Respond with EXACTLY ONE WORD: either "True" if the question is solvable, or "False" if it's not solvable. 24 | """ 25 | 26 | async def is_solvable(self, question: Union[str, Dict]) -> Tuple[bool, str]: 27 | """ 28 | Determine if a question is solvable based on LLM judgment. 29 | 30 | Args: 31 | question: Either a question string or a dictionary containing the question 32 | 33 | Returns: 34 | Tuple of (is_solvable, response_text) 35 | """ 36 | if isinstance(question, dict): 37 | question_text = question.get("question", "") 38 | if not question_text: 39 | print(f"No question text found in {question}") 40 | return False, "No question text found" 41 | else: 42 | question_text = question 43 | 44 | # Format the prompt with the question 45 | prompt = self.prompt_template.format(question=question_text) 46 | 47 | # Get judgment from LLM 48 | try: 49 | response = await generate_text(model=self.model, prompt=prompt) 50 | response = response.strip().lower() 51 | 52 | # Check if the response contains "true" or "false" 53 | if "true" in response: 54 | return True, response 55 | elif "false" in response: 56 | return False, response 57 | else: 58 | # Default to false if response is unclear 59 | return False, f"Unclear response: {response}" 60 | except Exception as e: 61 | return False, f"Error: {str(e)}" 62 | 63 | 64 | async def filter_question_file(file_path: str, model: str = "gpt-4o-mini") -> Dict: 65 | """ 66 | Filter a single question file to determine if it's solvable. 
67 | 68 | Args: 69 | file_path: Path to the JSON question file 70 | model: The LLM model to use for judgment 71 | 72 | Returns: 73 | Dictionary with the judgment results 74 | """ 75 | try: 76 | import json 77 | 78 | with open(file_path) as f: 79 | question_data = json.load(f) 80 | 81 | filter_instance = QuestionFilter(model=model) 82 | is_solvable, response = await filter_instance.is_solvable(question_data) 83 | 84 | return { 85 | "file_path": file_path, 86 | "is_solvable": is_solvable, 87 | "response": response, 88 | } 89 | 90 | except Exception as e: 91 | return { 92 | "file_path": file_path, 93 | "is_solvable": False, 94 | "response": f"Error: {str(e)}", 95 | } 96 | 97 | 98 | async def filter_directory( 99 | directory_path: str, output_path: Optional[str] = None, model: str = "gpt-4o-mini" 100 | ) -> Dict: 101 | """ 102 | Filter all question files in a directory. 103 | 104 | Args: 105 | directory_path: Path to directory containing question JSON files 106 | output_path: Optional path to save filtered results 107 | model: The LLM model to use for judgment 108 | 109 | Returns: 110 | Dictionary with lists of solvable and unsolvable questions 111 | """ 112 | result = {"solvable": [], "unsolvable": [], "responses": {}} 113 | 114 | tasks = [] 115 | 116 | for filename in os.listdir(directory_path): 117 | if filename.endswith(".json"): 118 | file_path = os.path.join(directory_path, filename) 119 | tasks.append(filter_question_file(file_path, model)) 120 | 121 | # Run all filtering tasks concurrently 122 | judgments = await asyncio.gather(*tasks) 123 | 124 | # Process results 125 | for judgment in judgments: 126 | filename = os.path.basename(judgment["file_path"]) 127 | if judgment["is_solvable"]: 128 | result["solvable"].append(filename) 129 | else: 130 | result["unsolvable"].append(filename) 131 | result["responses"][filename] = judgment["response"] 132 | 133 | # Save results if output path is specified 134 | if output_path: 135 | import json 136 | 137 | with open(output_path, "w") as f: 138 | json.dump(result, f, indent=2) 139 | 140 | return result 141 | -------------------------------------------------------------------------------- /src/tobyawesomeailibrary/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import anthropic 4 | from openai import AsyncOpenAI 5 | 6 | # API Keys 7 | openai_api_key = os.getenv("OPENAI_API_KEY") 8 | anthropic_api_key = os.getenv("ANTHROPIC_API_KEY") 9 | deepseek_api_key = os.getenv("DEEPSEEK_API_KEY") 10 | deepinfra_api_key = os.getenv("DEEPINFRA_API_KEY") 11 | 12 | # Initialize clients only when keys are available 13 | async_openai_client = None 14 | if openai_api_key: 15 | async_openai_client = AsyncOpenAI(api_key=openai_api_key) 16 | 17 | deepseek_client = None 18 | if deepseek_api_key: 19 | deepseek_client = AsyncOpenAI( 20 | api_key=deepseek_api_key, base_url="https://api.deepseek.com" 21 | ) 22 | 23 | deepinfra_client = None 24 | if deepinfra_api_key: 25 | deepinfra_client = AsyncOpenAI( 26 | api_key=deepinfra_api_key, base_url="https://api.deepinfra.com/v1/openai" 27 | ) 28 | 29 | 30 | async def generate_text( 31 | model: str, prompt: str, max_tokens: int = 8000, temperature: float = 0 32 | ) -> str: 33 | """ 34 | Asynchronously generate text using various AI models. 
35 |
36 |     :param model: The name of the model to use (e.g., "gpt-3.5-turbo", "claude-2", "meta-llama/Llama-2-70b-chat-hf")
37 |     :param prompt: The input prompt for text generation
38 |     :param max_tokens: Maximum number of tokens to generate
39 |     :param temperature: Controls randomness in generation (0.0 to 1.0)
40 |     :return: Generated text as a string
41 |     """
42 |
43 |     # OpenAI models; note this branch passes no sampling params (o1-style models reject them)
44 |     if model.startswith("gpt-") or model.startswith("o1"):
45 |         if not async_openai_client:
46 |             raise ValueError(
47 |                 "OpenAI API key not set or invalid. Set the OPENAI_API_KEY environment variable."
48 |             )
49 |
50 |         response = await async_openai_client.chat.completions.create(
51 |             model=model,
52 |             messages=[{"role": "user", "content": prompt}],
53 |         )
54 |         return response.choices[0].message.content.strip()
55 |
56 |     elif model.startswith("ft:gpt"):  # fine-tuned OpenAI models; "o1" is handled above
57 |         if not async_openai_client:
58 |             raise ValueError(
59 |                 "OpenAI API key not set or invalid. Set the OPENAI_API_KEY environment variable."
60 |             )
61 |
62 |         response = await async_openai_client.chat.completions.create(
63 |             model=model,
64 |             messages=[{"role": "user", "content": prompt}],
65 |             max_tokens=max_tokens,
66 |             temperature=temperature,
67 |         )
68 |         return response.choices[0].message.content.strip()
69 |
70 |     # Anthropic (Claude) models
71 |     elif model.startswith("claude-"):
72 |         if not anthropic_api_key:
73 |             raise ValueError(
74 |                 "Anthropic API key not set or invalid. Set the ANTHROPIC_API_KEY environment variable."
75 |             )
76 |
77 |         async def run_anthropic():
78 |             client = anthropic.Anthropic(api_key=anthropic_api_key)
79 |             if model.startswith("claude-3"):
80 |                 response = client.messages.create(
81 |                     model=model,
82 |                     messages=[{"role": "user", "content": prompt}],
83 |                     max_tokens=max_tokens,
84 |                     temperature=temperature,
85 |                 )
86 |                 return response.content[0].text.strip()
87 |             else:
88 |                 response = client.completions.create(
89 |                     model=model,
90 |                     prompt=f"Human: {prompt}\n\nAssistant:",
91 |                     max_tokens_to_sample=max_tokens,
92 |                     temperature=temperature,
93 |                 )
94 |                 return response.completion.strip()
95 |
96 |         return await run_anthropic()
97 |
98 |     # DeepInfra models
99 |     elif (
100 |         model.startswith("meta-llama/")
101 |         or model.startswith("deepseek-ai")
102 |         or model.startswith("Qwen/")
103 |         or model.startswith("Meta-Llama")
104 |     ):
105 |         if not deepinfra_client:
106 |             raise ValueError(
107 |                 "DeepInfra API key not set or invalid. Set the DEEPINFRA_API_KEY environment variable."
108 |             )
109 |
110 |         response = await deepinfra_client.chat.completions.create(
111 |             model=model,
112 |             messages=[{"role": "user", "content": prompt}],
113 |             max_tokens=max_tokens,
114 |             temperature=temperature,
115 |         )
116 |         return response.choices[0].message.content.strip()
117 |
118 |     # DeepSeek models
119 |     elif model.startswith("deepseek-"):
120 |         if not deepseek_client:
121 |             raise ValueError(
122 |                 "DeepSeek API key not set or invalid. Set the DEEPSEEK_API_KEY environment variable."
123 |             )
124 |
125 |         try:
126 |             response = await deepseek_client.chat.completions.create(
127 |                 model=model,
128 |                 messages=[{"role": "user", "content": prompt}],
129 |                 max_tokens=max_tokens,
130 |                 temperature=temperature,
131 |             )
132 |             return response.choices[0].message.content.strip()
133 |         except Exception as e:
134 |             print(f"An error occurred while generating text with DeepSeek model: {e}")
135 |             raise
136 |
137 |     else:
138 |         raise ValueError(f"Unsupported model: {model}")
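A minimal sketch of calling this router (the model name is illustrative and the matching API key environment variable must be set):

```python
import asyncio

from src.tobyawesomeailibrary.inference import generate_text


async def demo() -> None:
    # Routing is prefix-based: "gpt-"/"o1" -> OpenAI, "claude-" -> Anthropic,
    # "deepseek-" -> DeepSeek, and known HF-style prefixes -> DeepInfra.
    reply = await generate_text(
        model="gpt-4o-mini",  # hypothetical choice; requires OPENAI_API_KEY
        prompt="Say hello in one word.",
        max_tokens=16,
        temperature=0,
    )
    print(reply)


asyncio.run(demo())
```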
--------------------------------------------------------------------------------
/scripts/filter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import asyncio
4 | import argparse
5 | from typing import List, Dict
6 | from tqdm import tqdm  # For progress bar
7 | from textbooks_to_rl.filter_question import QuestionFilter
8 |
9 | async def process_batch(files: List[str], output_dir: str, model: str = "gpt-4o-mini") -> None:
10 |     """
11 |     Process a batch of question files and filter them.
12 |
13 |     Args:
14 |         files: List of file paths to process
15 |         output_dir: Directory to save filtered results
16 |         model: Model to use for filtering
17 |     """
18 |     filter_instance = QuestionFilter(model=model)
19 |     tasks = []
20 |
21 |     for file_path in files:
22 |         tasks.append(process_file(file_path, filter_instance, output_dir))
23 |
24 |     await asyncio.gather(*tasks)
25 |
26 | async def process_file(file_path: str, filter_instance: QuestionFilter, output_dir: str) -> None:
27 |     """
28 |     Process a single question file and save it with its filter verdict.
29 |
30 |     Args:
31 |         file_path: Path to the question JSON file
32 |         filter_instance: Instance of QuestionFilter
33 |         output_dir: Directory to save filtered results
34 |     """
35 |     try:
36 |         with open(file_path, 'r') as f:
37 |             question_data = json.load(f)
38 |
39 |         # Check if the question is solvable
40 |         is_solvable, response = await filter_instance.is_solvable(question_data)
41 |
42 |         # Attach the filter verdict; every question is saved, and downstream
43 |         # consumers can select on the "filter_response" flag
44 |         question_data["filtered"] = True
45 |         question_data["filter_response"] = is_solvable
46 |
47 |         # Create output file path
48 |         filename = os.path.basename(file_path)
49 |         output_file_path = os.path.join(output_dir, filename)
50 |
51 |         # Save the filtered question
52 |         with open(output_file_path, 'w') as f:
53 |             json.dump(question_data, f, indent=4)
54 |
55 |     except Exception as e:
56 |         print(f"Error processing {file_path}: {e}")
57 |
58 | async def filter_folder(folder_path: str, output_base_dir: str, model: str = "gpt-4o-mini", batch_size: int = 200) -> Dict[str, int]:
59 |     """
60 |     Filter all question files in a folder.
61 |
62 |     Args:
63 |         folder_path: Path to folder containing question JSON files
64 |         output_base_dir: Base directory for saving filtered results
65 |         model: Model to use for filtering
66 |         batch_size: Number of files to process in a batch
67 |
68 |     Returns:
69 |         Statistics about processing
70 |     """
71 |     # Folder name is the last part of the path
72 |     folder_name = os.path.basename(folder_path)
73 |
74 |     # Create output directory
75 |     output_dir = os.path.join(output_base_dir, f"filtered-{folder_name}")
76 |     os.makedirs(output_dir, exist_ok=True)
77 |
78 |     # Get all JSON files in the folder
79 |     json_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith('.json')]
80 |
81 |     print(f"Processing {len(json_files)} files from {folder_name}")
82 |
83 |     # Process files in batches
84 |     stats = {
85 |         "total": len(json_files),
86 |         "processed": 0,
87 |         "folder": folder_name
88 |     }
89 |
90 |     # Process in batches with progress bar
91 |     for i in tqdm(range(0, len(json_files), batch_size), desc=f"Processing {folder_name}"):
92 |         batch = json_files[i:i+batch_size]
93 |         await process_batch(batch, output_dir, model)
94 |         stats["processed"] += len(batch)
95 |
96 |     # Count how many files were filtered
97 |     filtered_files = [f for f in os.listdir(output_dir) if f.endswith('.json')]
98 |     stats["filtered"] = len(filtered_files)
99 |
100 |     return stats
101 |
102 | async def main():
103 |     """Main entry point for the script."""
104 |     parser = argparse.ArgumentParser(description="Filter unsolvable questions across multiple folders")
105 |     parser.add_argument("--folders", nargs="+", required=True, help="List of folders to process")
106 |     parser.add_argument("--output-dir", default="filtered_questions", help="Base directory for output")
107 |     parser.add_argument("--model", default="gpt-4o-mini", help="Model to use for filtering")
108 |     parser.add_argument("--batch-size", type=int, default=200, help="Batch size for processing")
109 |     args = parser.parse_args()
110 |
111 |     # Create base output directory
112 |     os.makedirs(args.output_dir, exist_ok=True)
113 |
114 |     # Process each folder
115 |     all_stats = []
116 |     for folder in args.folders:
117 |         stats = await filter_folder(folder, args.output_dir, args.model, args.batch_size)
118 |         all_stats.append(stats)
119 |         print(f"Folder {stats['folder']}: Processed {stats['processed']}/{stats['total']}, Filtered: {stats['filtered']}")
120 |
121 |     # Save overall statistics
122 |     with open(os.path.join(args.output_dir, "filter_stats.json"), 'w') as f:
123 |         json.dump(all_stats, f, indent=4)
124 |
125 |     print("Filtering complete!")
126 |     print(f"Total folders processed: {len(all_stats)}")
127 |     print(f"Total questions processed: {sum(s['total'] for s in all_stats)}")
128 |     print(f"Total questions filtered: {sum(s['filtered'] for s in all_stats)}")
129 |
130 | if __name__ == "__main__":
131 |     asyncio.run(main())
--------------------------------------------------------------------------------
/src/textbooks_to_rl/textbook_manager/textbook_manager.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | from pathlib import Path
4 | from typing import Dict, List, Optional
5 |
6 | import PyPDF2
7 |
8 | from .models import Textbook, TextbookPage
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class TextbookManager:
14 |     """Manages textbook content and access."""
15 |
16 |     def __init__(
17 |         self,
18 |         textbooks_dir: str = "textbooks/txt",
19 |         pdf_dir: Optional[str] = None,
20 
| ) -> None: 21 | self.textbooks_dir = Path(textbooks_dir) 22 | self.pdf_dir = Path(pdf_dir) if pdf_dir else None 23 | self.textbooks: Dict[str, Textbook] = {} 24 | 25 | # Create directories if they don't exist 26 | self.textbooks_dir.mkdir(exist_ok=True, parents=True) 27 | if self.pdf_dir: 28 | self.pdf_dir.mkdir(exist_ok=True, parents=True) 29 | 30 | # Load all textbooks on initialization 31 | self._load_textbooks() 32 | 33 | def _load_textbooks(self) -> None: 34 | """Load all textbooks from the txt directory.""" 35 | txt_files = list(self.textbooks_dir.glob("*.txt")) 36 | 37 | for txt_path in txt_files: 38 | name = txt_path.stem 39 | 40 | # Create Textbook object 41 | textbook = Textbook( 42 | name=name, path=txt_path, txt_path=txt_path, pages=[], total_pages=0 43 | ) 44 | 45 | self._load_text_content(textbook) 46 | self.textbooks[name] = textbook 47 | 48 | def _load_text_content(self, textbook: Textbook) -> None: 49 | """Load content from txt file and split into pages.""" 50 | try: 51 | with open(textbook.txt_path, encoding="utf-8") as f: 52 | content = f.read() 53 | 54 | # Split content into roughly page-sized chunks (3000 characters each) 55 | chunk_size = 3000 56 | chunks = [ 57 | content[i : i + chunk_size] for i in range(0, len(content), chunk_size) 58 | ] 59 | 60 | # Create pages 61 | textbook.pages = [ 62 | TextbookPage( 63 | content=chunk, page_number=i + 1, textbook_name=textbook.name 64 | ) 65 | for i, chunk in enumerate(chunks) 66 | ] 67 | textbook.total_pages = len(textbook.pages) 68 | 69 | except Exception as e: 70 | logger.error(f"Error loading textbook {textbook.name}: {str(e)}") 71 | raise 72 | 73 | def get_textbook(self, name: str) -> Optional[Textbook]: 74 | """Get a textbook by name.""" 75 | return self.textbooks.get(name) 76 | 77 | def get_page(self, textbook_name: str, page_number: int) -> Optional[TextbookPage]: 78 | """Get a specific page from a textbook.""" 79 | textbook = self.get_textbook(textbook_name) 80 | if not textbook or page_number < 1 or page_number > textbook.total_pages: 81 | return None 82 | return textbook.pages[page_number - 1] 83 | 84 | def get_random_page(self, textbook_name: str) -> Optional[TextbookPage]: 85 | """Get a random page from a textbook.""" 86 | textbook = self.get_textbook(textbook_name) 87 | if not textbook or not textbook.pages: 88 | return None 89 | return random.choice(textbook.pages) 90 | 91 | def get_all_textbook_names(self) -> List[str]: 92 | """Get list of all textbook names.""" 93 | return list(self.textbooks.keys()) 94 | 95 | def get_all_textbooks(self) -> List[Textbook]: 96 | """Get list of all textbook objects. 
97 |
98 |         Returns:
99 |             List of all Textbook objects managed by this TextbookManager
100 |         """
101 |         logger.debug(f"Textbooks dictionary contains {len(self.textbooks)} items")
102 |         textbooks_list = list(self.textbooks.values())
103 |         logger.debug(f"Returning {len(textbooks_list)} textbooks")
104 |         return textbooks_list
105 |
106 |     def process_directory(self) -> None:
107 |         """Process all PDFs in the pdf directory to text files."""
108 |         if not self.pdf_dir:
109 |             raise ValueError("PDF directory not specified")
110 |
111 |         pdf_files = list(self.pdf_dir.glob("*.pdf"))
112 |         for pdf_path in pdf_files:
113 |             try:
114 |                 # Process PDF to text
115 |                 output_path = self.textbooks_dir / f"{pdf_path.stem}.txt"
116 |                 self._process_pdf(pdf_path, output_path)
117 |                 logger.info(f"Processed {pdf_path.name}")
118 |             except Exception as e:
119 |                 logger.error(f"Error processing {pdf_path.name}: {str(e)}")
120 |                 raise
121 |
122 |     def _process_pdf(self, pdf_path: Path, output_path: Path) -> None:
123 |         """Convert a PDF file to text."""
124 |         try:
125 |             with open(pdf_path, "rb") as pdf_file:
126 |                 # Create PDF reader object
127 |                 pdf_reader = PyPDF2.PdfReader(pdf_file)
128 |
129 |                 # Extract text from all pages
130 |                 text = []
131 |                 for page in pdf_reader.pages:
132 |                     text.append(page.extract_text())
133 |
134 |             # Write combined text to output file
135 |             with open(output_path, "w", encoding="utf-8") as txt_file:
136 |                 txt_file.write("\n".join(text))
137 |
138 |         except Exception as e:
139 |             logger.error(f"Error processing PDF {pdf_path}: {str(e)}")
140 |             raise
141 |
142 |     def get_num_pages(self, textbook_name: str) -> Optional[int]:
143 |         """Get the total number of pages in a textbook.
144 |
145 |         Args:
146 |             textbook_name: Name of the textbook
147 |
148 |         Returns:
149 |             Total number of pages if textbook exists, None otherwise
150 |         """
151 |         textbook = self.get_textbook(textbook_name)
152 |         if not textbook:
153 |             return None
154 |         return textbook.total_pages
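A small sketch of the manager's read path (the textbook name is hypothetical and must match a .txt file in textbooks/txt/):

```python
from textbooks_to_rl.textbook_manager import TextbookManager

manager = TextbookManager(textbooks_dir="textbooks/txt")

# "Pages" here are fixed-size character chunks created at load time, not the
# original PDF pages.
total = manager.get_num_pages("calculus")  # hypothetical textbook name
page = manager.get_page("calculus", page_number=1)
if page is not None:
    print(f"Page {page.page_number}/{total}: {page.content[:80]}...")
```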
43 |     # return f"""
44 |     # I'm trying to quiz myself on this text to help me learn
45 | 
46 |     # Can you generate between {num_questions} and {num_questions * 3} questions and solutions for me from the textbook snippet?
47 | 
48 |     # Can you focus on problems that have numerical answers and can you output the answer in a box?
49 | 
50 |     # Leave the final answer as a fraction; no need to express it as a decimal.
51 | 
52 |     # Make sure each problem has the full context of the problem in the question.
53 | 
54 |     # Don't make up new fake questions, only use the information provided in the textbook snippet. Make sure the solution is present in the textbook snippet.
55 |     # Make sure all the needed information to answer the question is present in the question. I won't have access to the textbook snippet when I take the quiz.
56 |     # Each question should be self-contained and should contain all the context and information needed. I won't see the other questions when I take the quiz; only one at a time, and the order is randomized.
57 |     # You should use all examples given in the textbook snippets as sample problems.
58 | 
59 |     # Focus on very hard examples and problems from the textbook. Don't include super easy ones.
60 |     # Don't make up new problems; use only problems from the textbook.
61 | 
62 |     # For each problem, follow the exact format below:
63 |     # 
64 |     # [Write your unique math problem here. Ensure that the problem statement includes all necessary details and context from the snippet so it is fully self-contained.]
65 |     # 
66 |     # 
67 |     # [Write your detailed solution here. Ensure that if a final numerical answer is provided, it is enclosed in a box.]
68 |     # 
69 |     # 
70 |     # [Write your unique math problem here. Ensure that the problem statement includes all necessary details and context from the snippet so it is fully self-contained.]
71 |     # 
72 |     # 
73 |     # [Write your detailed solution here. Ensure that if a final numerical answer is provided, it is enclosed in a box.]
74 |     # 
75 | 
76 |     # Textbook Snippet:
77 |     # {text_book_snippet}
78 | 
79 |     # """
80 | 
81 |     # return f"""
82 |     # You are given a textbook snippet. Your task is to generate questions and solutions based solely on the provided snippet. Follow these instructions precisely:
83 | 
84 |     # 1. **Source Dependency:**
85 |     #    - Only use information directly from the snippet. Do not infer or include any external information.
86 |     #    - Every question must be verifiable and solvable using only the information contained in the snippet.
87 | 
88 |     # 2. 
**Question and Quantity Requirements:** 89 | # - Generate between {num_questions} and {num_questions * 3} UNIQUE and DIFFERENT questions. 90 | # - If the snippet does not support {num_questions} distinct questions, generate only as many as can be fully supported by the text. 91 | 92 | # 3. **Difficulty Level Specification:** 93 | # - Match the complexity and depth of analysis to {difficulty.value} level. 94 | # - Include appropriate mathematical rigor and analytical reasoning expected at this level. 95 | 96 | # 4. **Question Self-Containment:** 97 | # - Ensure each question is fully self-contained. Include all numerical values, constants, and context from the snippet necessary to solve the question. 98 | # - If the question involves a derivation or an expression, show the initial expression and the steps leading to the final answer. 99 | 100 | # 5. **Answer Presentation:** 101 | # - If the solution yields a final numerical answer, enclose that result in a box (using LaTeX or a similar method). 102 | # - Keep solutions brief and to the point while ensuring all necessary details are included. 103 | 104 | # 6. **Response Format:** 105 | # For each question, repeat the following format EXACTLY: 106 | 107 | # 108 | # [Paste only the relevant part of the textbook snippet that directly supports and contains the solution for this question] 109 | # 110 | # 111 | # [Write your unique question here at {difficulty.value} level] 112 | # 113 | # 114 | # [Write your detailed solution here matching {difficulty.value} level expectations] 115 | # 116 | 117 | # 7. **No Extraneous Information:** 118 | # - Do not include any commentary, introductions, or additional text outside the prescribed format. 119 | # - Every piece of output must strictly adhere to the structure above. 120 | 121 | # Textbook Snippet: 122 | # {text_book_snippet} 123 | 124 | # Begin generating the questions and solutions now. 125 | 126 | # """ 127 | 128 | @staticmethod 129 | def hint_generation( 130 | question: str, difficulty: QuestionDifficulty = QuestionDifficulty.UNDERGRAD 131 | ) -> str: 132 | return f""" 133 | You are a helpful teaching assistant working with {difficulty.value} level students. 134 | A student is struggling with this {difficulty.value} level question. 135 | Generate 2-3 helpful hints that will guide them toward the solution without giving it away. 136 | 137 | The hints should: 138 | 1. Break down the problem-solving approach into steps appropriate for {difficulty.value} level 139 | 2. Point out key concepts or equations to consider 140 | 3. Suggest what to focus on first 141 | 4. Match the theoretical depth expected at {difficulty.value} level 142 | 143 | Question: 144 | {question} 145 | 146 | Format your response as: 147 | 148 | 1. [First hint appropriate for {difficulty.value} level] 149 | 2. [Second hint with {difficulty.value} appropriate complexity] 150 | 3. 
[Optional third hint if needed] 151 | 152 | """ 153 | -------------------------------------------------------------------------------- /scripts/process_pdfs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from pathlib import Path 4 | import fitz # PyMuPDF 5 | from concurrent.futures import ThreadPoolExecutor 6 | from tqdm import tqdm 7 | import traceback 8 | import os 9 | 10 | def setup_logging(): 11 | """Setup logging to both file and console with more detailed formatting.""" 12 | # Clear previous log file if it exists 13 | log_path = Path('pdf_processing.log') 14 | if log_path.exists(): 15 | log_path.unlink() 16 | 17 | # Create formatter 18 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 19 | 20 | # Setup file handler 21 | file_handler = logging.FileHandler(log_path) 22 | file_handler.setFormatter(formatter) 23 | 24 | # Setup console handler 25 | console_handler = logging.StreamHandler(sys.stdout) 26 | console_handler.setFormatter(formatter) 27 | 28 | # Setup logger 29 | logger = logging.getLogger() 30 | logger.setLevel(logging.INFO) 31 | 32 | # Remove existing handlers if any 33 | for handler in logger.handlers[:]: 34 | logger.removeHandler(handler) 35 | 36 | logger.addHandler(file_handler) 37 | logger.addHandler(console_handler) 38 | 39 | return logger 40 | 41 | def extract_text_with_pymupdf(pdf_path: Path, output_path: Path) -> bool: 42 | """Extract text from PDF using PyMuPDF (fitz). 43 | 44 | Args: 45 | pdf_path: Path to the PDF file 46 | output_path: Path where to save the text file 47 | 48 | Returns: 49 | bool: True if successful, False otherwise 50 | """ 51 | try: 52 | # Verify file exists before attempting to open 53 | if not pdf_path.exists(): 54 | logging.error(f"File does not exist: {pdf_path}") 55 | return False 56 | 57 | # Open the PDF 58 | doc = fitz.open(pdf_path) 59 | 60 | extracted_text = [] 61 | # Process each page 62 | for page_num in range(len(doc)): 63 | page = doc[page_num] 64 | # Extract text with improved layout preservation 65 | text = page.get_text() 66 | if text.strip(): # Only add if text is not empty 67 | extracted_text.append(text) 68 | 69 | # Check if we got anything meaningful 70 | if not extracted_text: 71 | logging.warning(f"No text extracted from {pdf_path.name}") 72 | doc.close() 73 | return False 74 | 75 | # Write combined text to output file 76 | with open(output_path, 'w', encoding='utf-8', errors='replace') as txt_file: 77 | txt_file.write('\n\n'.join(extracted_text)) 78 | 79 | doc.close() 80 | return True 81 | 82 | except Exception as e: 83 | logging.error(f"PyMuPDF extraction failed for {pdf_path.name}: {str(e)}") 84 | logging.debug(traceback.format_exc()) 85 | return False 86 | 87 | def process_pdf(pdf_path: Path, output_path: Path) -> bool: 88 | """Process a single PDF file to text. 
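    Tries PyMuPDF extraction first, then sanity-checks that the output file
    contains a non-trivial amount of text.

    Example (illustrative; the file names are hypothetical)::

        process_pdf(Path("textbooks/pdfs/calculus.pdf"), Path("textbooks/txt/calculus.txt"))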
89 | 
90 |     Args:
91 |         pdf_path: Path to the PDF file
92 |         output_path: Path where to save the text file
93 | 
94 |     Returns:
95 |         bool: True if successful, False otherwise
96 |     """
97 |     try:
98 |         logging.info(f"Processing: {pdf_path}")
99 | 
100 |         # First attempt: PyMuPDF
101 |         if extract_text_with_pymupdf(pdf_path, output_path):
102 |             logging.info(f"Successfully processed {pdf_path.name} with PyMuPDF")
103 | 
104 |             # Quick verification - check file size and content
105 |             if output_path.stat().st_size > 100:  # Arbitrary minimum size
106 |                 with open(output_path, 'r', encoding='utf-8', errors='replace') as f:
107 |                     # Check first few lines for actual content
108 |                     sample = f.read(1000)
109 |                     if len(sample.split()) > 20:  # At least some words
110 |                         return True
111 | 
112 |             logging.warning(f"Extraction produced too little content for {pdf_path.name}")
113 | 
114 |         # If we get here, the primary method failed or produced insufficient content
115 |         logging.warning(f"Could not extract meaningful text from {pdf_path.name}")
116 |         return False
117 | 
118 |     except Exception as e:
119 |         logging.error(f"Error processing PDF {pdf_path.name}: {str(e)}")
120 |         logging.debug(traceback.format_exc())
121 |         return False
122 | 
123 | def worker(pdf_file, input_path, output_path):
124 |     """Worker function for thread pool."""
125 |     try:
126 |         pdf_path = input_path / pdf_file.name  # Use name to create proper path
127 |         output_file = output_path / f"{pdf_file.stem}.txt"
128 |         success = process_pdf(pdf_path, output_file)
129 |         return pdf_file.name, success
130 |     except Exception as e:
131 |         logging.error(f"Worker failed for {pdf_file.name}: {str(e)}")
132 |         return pdf_file.name, False
133 | 
134 | def process_pdfs(input_dir: str = "textbooks/pdfs", output_dir: str = "textbooks/txt", max_workers: int = 4):
135 |     """Process all PDFs in input directory to text files in output directory with parallel processing.
136 | 137 | Args: 138 | input_dir: Directory containing PDF files 139 | output_dir: Directory for output text files 140 | max_workers: Maximum number of worker threads 141 | """ 142 | logger = setup_logging() 143 | 144 | # Create Path objects - resolve to absolute paths to avoid path issues 145 | input_path = Path(input_dir).resolve() 146 | output_path = Path(output_dir).resolve() 147 | 148 | # Create output directory if it doesn't exist 149 | output_path.mkdir(parents=True, exist_ok=True) 150 | 151 | # Debug path information 152 | logging.info(f"Input directory (absolute): {input_path}") 153 | logging.info(f"Output directory (absolute): {output_path}") 154 | 155 | try: 156 | # Get all PDF files directly from the directory 157 | pdf_files = list(input_path.glob("*.pdf")) 158 | 159 | if not pdf_files: 160 | logging.warning(f"No PDF files found in {input_dir}") 161 | # Try listing directory contents for debugging 162 | all_files = list(input_path.iterdir()) 163 | logging.info(f"Directory contents: {[f.name for f in all_files]}") 164 | return 165 | 166 | logging.info(f"Found {len(pdf_files)} PDF files to process") 167 | 168 | # Process a single PDF first as a test case 169 | if pdf_files: 170 | test_pdf = pdf_files[0] 171 | logging.info(f"Testing with first PDF: {test_pdf.name}") 172 | test_output = output_path / f"{test_pdf.stem}_test.txt" 173 | success = process_pdf(test_pdf, test_output) 174 | if success: 175 | logging.info("Test processing successful, continuing with batch processing") 176 | else: 177 | logging.warning("Test processing failed, checking for issues before continuing") 178 | 179 | # Process PDFs in parallel 180 | results = {"success": 0, "failed": 0} 181 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 182 | # Submit all tasks and wrap with tqdm for progress bar 183 | futures = {executor.submit(worker, pdf_file, input_path, output_path): pdf_file.name 184 | for pdf_file in pdf_files} 185 | 186 | # Show progress bar 187 | for future in tqdm(futures, desc="Processing PDFs", unit="file"): 188 | pdf_name, success = future.result() 189 | if success: 190 | results["success"] += 1 191 | else: 192 | results["failed"] += 1 193 | 194 | # Log summary 195 | logger.info(f"PDF processing complete! 
Successfully processed: {results['success']}, Failed: {results['failed']}")
196 | 
197 |         # List failed files if any
198 |         if results["failed"] > 0:
199 |             failed_files = [f for f in output_path.glob("*.txt") if f.stat().st_size < 100]
200 |             if failed_files:
201 |                 logger.warning(f"Files with potentially insufficient content: {', '.join(f.name for f in failed_files)}")
202 | 
203 |     except Exception as e:
204 |         logger.error(f"Error during processing: {str(e)}")
205 |         logger.debug(traceback.format_exc())
206 |         raise
207 | 
208 | if __name__ == "__main__":
209 |     import argparse
210 |     parser = argparse.ArgumentParser(description='Process PDFs to text files')
211 |     parser.add_argument('--textbooks-dir', default='textbooks/pdfs', help='Directory containing PDF textbooks')
212 |     parser.add_argument('--parsed-dir', default='textbooks/txt', help='Directory for parsed text files')
213 |     parser.add_argument('--max-workers', type=int, default=4, help='Maximum number of parallel workers')
214 | 
215 |     args = parser.parse_args()
216 |     process_pdfs(args.textbooks_dir, args.parsed_dir, args.max_workers)
--------------------------------------------------------------------------------
/src/textbooks_to_rl/cli.py:
--------------------------------------------------------------------------------
1 | """
2 | Command-line interface for textbooks-to-rl.
3 | """
4 | 
5 | import argparse
6 | import asyncio
7 | import logging
8 | import sys
9 | from pathlib import Path
10 | 
11 | from . import __version__
12 | from .question_generator import QuestionGenerator
13 | from .question_generator.models import QuestionDifficulty
14 | from .textbook_manager import TextbookManager
15 | 
16 | 
17 | def setup_logging(verbose: bool = False) -> None:
18 |     """Set up logging configuration."""
19 |     level = logging.DEBUG if verbose else logging.INFO
20 |     logging.basicConfig(
21 |         level=level,
22 |         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
23 |         datefmt="%Y-%m-%d %H:%M:%S",
24 |     )
25 | 
26 | 
27 | def create_parser() -> argparse.ArgumentParser:
28 |     """Create and return the argument parser."""
29 |     parser = argparse.ArgumentParser(
30 |         prog="textbooks-to-rl",
31 |         description="Generate questions from textbooks for reinforcement learning",
32 |         formatter_class=argparse.RawDescriptionHelpFormatter,
33 |     )
34 | 
35 |     parser.add_argument(
36 |         "--version",
37 |         action="version",
38 |         version=f"%(prog)s {__version__}",
39 |     )
40 | 
41 |     parser.add_argument(
42 |         "--model",
43 |         default="Qwen/QwQ-32B",
44 |         help="Model to use for generation (default: %(default)s)",
45 |     )
46 | 
47 |     parser.add_argument(
48 |         "--output-dir",
49 |         type=Path,
50 |         default=Path("generated_questions"),
51 |         help="Directory to save generated questions (default: %(default)s)",
52 |     )
53 | 
54 |     parser.add_argument(
55 |         "--textbooks-dir",
56 |         type=Path,
57 |         default=Path("textbooks/txt"),
58 |         help="Directory containing textbook txt files (default: %(default)s)",
59 |     )
60 | 
61 |     parser.add_argument(
62 |         "--pages-per-group",
63 |         type=int,
64 |         default=3,
65 |         help="Number of pages to process together (default: %(default)s)",
66 |     )
67 | 
68 |     parser.add_argument(
69 |         "--batch-size",
70 |         type=int,
71 |         default=100,
72 |         help="Number of page groups to process in parallel (default: %(default)s)",
73 |     )
74 | 
75 |     parser.add_argument(
76 |         "--questions-per-chunk",
77 |         type=int,
78 |         default=10,
79 |         help="Number of questions to generate per chunk (default: %(default)s)",
80 |     )
81 | 
82 |     parser.add_argument(
83 |         "--difficulty",
84 |         type=str,
85 |         choices=[d.value for d in QuestionDifficulty],
86 |         default=QuestionDifficulty.UNDERGRAD.value,
87 |         help="Difficulty level for generated questions (default: %(default)s)",
88 |     )
89 | 
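    # Illustrative invocation (editorial sketch; the difficulty literal must
    # be one of the QuestionDifficulty values):
    #   textbooks-to-rl --textbooks-dir textbooks/txt --pages-per-group 3 \
    #       --batch-size 100 --questions-per-chunk 10 --difficulty <level>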
90 |     parser.add_argument(
91 |         "--no-verify",
92 |         action="store_true",
93 |         help="Skip solution verification step",
94 |     )
95 | 
96 |     parser.add_argument(
97 |         "--verbose",
98 |         "-v",
99 |         action="store_true",
100 |         help="Enable verbose logging",
101 |     )
102 | 
103 |     return parser
104 | 
105 | 
106 | async def process_pages(
107 |     textbook_name: str,
108 |     start_page: int,
109 |     num_pages: int,
110 |     textbook_manager: TextbookManager,
111 |     generator: QuestionGenerator,
112 |     difficulty: QuestionDifficulty,
113 |     questions_per_chunk: int,
114 |     verify: bool,
115 | ) -> list:
116 |     """Process multiple consecutive pages and generate questions from their combined content."""
117 |     logger = logging.getLogger(__name__)
118 | 
119 |     page_contents = []
120 | 
121 |     # Combine the content from multiple pages. Copy the text out of each page
122 |     # rather than appending to current_passage.content in place, which would
123 |     # mutate the TextbookPage objects cached by the TextbookManager.
124 |     for page_num in range(start_page, start_page + num_pages):
125 |         current_passage = textbook_manager.get_page(textbook_name, page_number=page_num)
126 |         if current_passage:
127 |             page_contents.append(current_passage.content)
128 |         else:
129 |             logger.warning(f"Could not find page {page_num} in {textbook_name}")
130 | 
131 |     if not page_contents:
132 |         logger.warning(f"Could not load any pages starting from {start_page}")
133 |         return []
134 | 
135 |     combined_content = "\n\n".join(page_contents)
136 |     logger.info(f"Processing pages {start_page}-{start_page + num_pages - 1}...")
137 | 
138 |     try:
139 |         questions = await generator.generate_questions(
140 |             combined_content,
141 |             num_questions=questions_per_chunk,
142 |             difficulty=difficulty,
143 |             verify=verify,
144 |             src=f"{textbook_name}_pages_{start_page}-{start_page + num_pages - 1}",
145 |         )
146 |         logger.info(
147 |             f"✓ Completed pages {start_page}-{start_page + num_pages - 1} "
148 |             f"- generated {len(questions)} questions"
149 |         )
150 |         return questions
151 |     except Exception as e:
152 |         logger.error(
153 |             f"Error processing pages {start_page}-{start_page + num_pages - 1}: {e}"
154 |         )
155 |         return []
156 | 
157 | 
158 | async def main() -> int:
159 |     """Main async function to process textbooks and generate questions."""
160 |     parser = create_parser()
161 |     args = parser.parse_args()
162 | 
163 |     setup_logging(args.verbose)
164 |     logger = logging.getLogger(__name__)
165 | 
166 |     # Validate directories
167 |     if not args.textbooks_dir.exists():
168 |         logger.error(f"Textbooks directory not found: {args.textbooks_dir}")
169 |         return 1
170 | 
171 |     # Create output directory if it doesn't exist
172 |     args.output_dir.mkdir(parents=True, exist_ok=True)
173 | 
174 |     # Initialize components
175 |     try:
176 |         textbook_manager = TextbookManager(str(args.textbooks_dir))
177 |         generator = QuestionGenerator(
178 |             model_name=args.model, output_dir=str(args.output_dir)
179 |         )
180 |         difficulty = QuestionDifficulty(args.difficulty)
181 |     except Exception as e:
182 |         logger.error(f"Failed to initialize components: {e}")
183 |         return 1
184 | 
185 |     # Get available textbooks
186 |     textbook_names = textbook_manager.get_all_textbook_names()
187 |     if not textbook_names:
188 |         logger.error("No textbooks found in the specified directory")
189 |         return 1
190 | 
191 |     logger.info(f"Found {len(textbook_names)} textbooks: {textbook_names}")
192 | 
193 |     all_questions = []
194 |     verify = not args.no_verify
195 | 
196 |     for 
textbook_name in textbook_names: 197 | num_pages = textbook_manager.get_num_pages(textbook_name) 198 | if num_pages is None: 199 | logger.warning(f"Could not determine page count for {textbook_name}") 200 | continue 201 | 202 | logger.info(f"Processing textbook: {textbook_name} ({num_pages} pages)") 203 | 204 | textbook_questions = [] 205 | 206 | # Process pages in groups and batches 207 | for batch_start in range(0, num_pages, args.batch_size * args.pages_per_group): 208 | batch_end = min( 209 | batch_start + args.batch_size * args.pages_per_group, num_pages 210 | ) 211 | logger.info( 212 | f"Processing batch of pages {batch_start} to {batch_end - 1}..." 213 | ) 214 | 215 | tasks = [] 216 | for group_start in range(batch_start, batch_end, args.pages_per_group): 217 | if group_start < num_pages: 218 | actual_pages = min(args.pages_per_group, num_pages - group_start) 219 | tasks.append( 220 | process_pages( 221 | textbook_name, 222 | group_start, 223 | actual_pages, 224 | textbook_manager, 225 | generator, 226 | difficulty, 227 | args.questions_per_chunk, 228 | verify, 229 | ) 230 | ) 231 | 232 | batch_results = await asyncio.gather(*tasks, return_exceptions=True) 233 | 234 | for result in batch_results: 235 | if isinstance(result, Exception): 236 | logger.error(f"Batch processing error: {result}") 237 | else: 238 | textbook_questions.extend(result) 239 | 240 | logger.info( 241 | f"Batch complete. Questions for {textbook_name} so far: {len(textbook_questions)}" 242 | ) 243 | 244 | all_questions.extend(textbook_questions) 245 | logger.info( 246 | f"Finished processing {textbook_name}. Total questions: {len(textbook_questions)}" 247 | ) 248 | 249 | logger.info( 250 | f"Finished processing all textbooks. Total questions generated: {len(all_questions)}" 251 | ) 252 | return 0 253 | 254 | 255 | def cli_main() -> None: 256 | """Entry point for the CLI.""" 257 | try: 258 | exit_code = asyncio.run(main()) 259 | sys.exit(exit_code) 260 | except KeyboardInterrupt: 261 | print("\nOperation cancelled by user") 262 | sys.exit(1) 263 | except Exception as e: 264 | print(f"Unexpected error: {e}") 265 | sys.exit(1) 266 | 267 | 268 | if __name__ == "__main__": 269 | cli_main() 270 | -------------------------------------------------------------------------------- /src/tobyawesomeailibrary/get_dataset_question.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from functools import lru_cache 4 | 5 | from datasets import load_dataset 6 | 7 | 8 | @lru_cache(maxsize=1) 9 | def load_cached_dataset(): 10 | return load_dataset("lighteval/MATH", split="train") 11 | 12 | 13 | @lru_cache(maxsize=1) 14 | def load_cached_competition_dataset(): 15 | return load_dataset( 16 | "hendrycks/competition_math", split="train", trust_remote_code=True 17 | ) 18 | 19 | 20 | @lru_cache(maxsize=1) 21 | def load_cached_test_dataset(): 22 | return load_dataset("lighteval/MATH", split="test") 23 | 24 | 25 | @lru_cache(maxsize=1) 26 | def load_cached_numina_dataset(): 27 | return load_dataset("AI-MO/NuminaMath-CoT", split="train") 28 | 29 | 30 | @lru_cache(maxsize=1) 31 | def load_cached_skunkworks_dataset(): 32 | return load_dataset("SkunkworksAI/reasoning-0.01", split="train") 33 | 34 | 35 | @lru_cache(maxsize=1) 36 | def load_cached_mathinstruct_dataset(): 37 | return load_dataset("TIGER-Lab/MathInstruct", split="train") 38 | 39 | 40 | @lru_cache(maxsize=1) 41 | def load_cached_aslawliet_dataset(): 42 | return load_dataset("aslawliet/olympiads", split="train") 43 
| 
44 | 
45 | def get_random_math_question():
46 |     # Load the MATH dataset from LightEval (cached)
47 |     dataset = load_cached_dataset()
48 | 
49 |     # Get a random index
50 |     random_index = random.randint(0, len(dataset) - 1)
51 | 
52 |     # Get the random question
53 |     random_question = dataset[random_index]
54 | 
55 |     # Extract the problem and solution from the question
56 |     problem = random_question["problem"]
57 |     solution = random_question["solution"]
58 | 
59 |     return problem, solution
60 | 
61 | 
62 | def get_hard_math_question():
63 |     # Load the MATH dataset from LightEval (cached)
64 |     dataset = load_cached_dataset()
65 | 
66 |     # Filter the dataset for Level 5 questions
67 |     hard_questions = [q for q in dataset if q["level"] == "Level 5"]
68 | 
69 |     if not hard_questions:
70 |         raise ValueError("No Level 5 questions found in the dataset")
71 | 
72 |     # Get a random hard question
73 |     random_question = random.choice(hard_questions)
74 | 
75 |     # Extract the problem and solution from the question
76 |     problem = random_question["problem"]
77 |     solution = random_question["solution"]
78 | 
79 |     return problem, solution
80 | 
81 | 
82 | def get_medium_math_question():
83 |     # Load the MATH dataset from LightEval (cached)
84 |     dataset = load_cached_dataset()
85 | 
86 |     # Filter the dataset for Level 3-5 questions
87 |     medium_questions = [
88 |         q for q in dataset if q["level"] in ["Level 3", "Level 4", "Level 5"]
89 |     ]
90 | 
91 |     if not medium_questions:
92 |         raise ValueError("No Level 3-5 questions found in the dataset")
93 | 
94 |     # Get a random medium-or-harder question
95 |     random_question = random.choice(medium_questions)
96 | 
97 |     # Extract the problem and solution from the question
98 |     problem = random_question["problem"]
99 |     solution = random_question["solution"]
100 | 
101 |     return problem, solution
102 | 
103 | 
104 | def get_competition_math_problem():
105 |     # Load the competition math dataset (cached)
106 |     dataset = load_cached_competition_dataset()
107 | 
108 |     # Get a random index
109 |     random_index = random.randint(0, len(dataset) - 1)
110 | 
111 |     # Get the random question
112 |     random_question = dataset[random_index]
113 | 
114 |     # Extract the problem and solution from the question
115 |     problem = random_question["problem"]
116 |     solution = random_question["solution"]
117 | 
118 |     return problem, solution
119 | 
120 | 
121 | def get_hard_competition_math_problem():
122 |     # Load the competition math dataset (cached)
123 |     dataset = load_cached_competition_dataset()
124 | 
125 |     # Filter the dataset for Level 5 questions
126 |     hard_questions = [q for q in dataset if q["level"] == "Level 5"]
127 | 
128 |     if not hard_questions:
129 |         raise ValueError("No Level 5 questions found in the dataset")
130 | 
131 |     # Get a random hard question
132 |     random_question = random.choice(hard_questions)
133 | 
134 |     # Extract the problem and solution from the question
135 |     problem = random_question["problem"]
136 |     solution = random_question["solution"]
137 | 
138 |     return problem, solution
139 | 
140 | 
141 | def get_gpqa_question():
142 |     # Load the GPQA questions from JSON file
143 |     with open("lib/unique_gpqa_questions.json") as f:
144 |         questions = json.load(f)
145 | 
146 |     # Get a random question
147 |     random_question = random.choice(questions)
148 | 
149 |     # Extract the problem and solution from the question
150 |     problem = random_question["question"]
151 |     solution = random_question["correct_answer"]
152 |     explanation = random_question["explanation"]
153 |     returned_solution = f"{solution}\n\n{explanation}"
154 |     return problem, 
returned_solution 155 | 156 | 157 | def get_test_math_question(): 158 | # Load the MATH test dataset (cached) 159 | dataset = load_cached_test_dataset() 160 | 161 | # Get a random index 162 | random_index = random.randint(0, len(dataset) - 1) 163 | 164 | # Get the random question 165 | random_question = dataset[random_index] 166 | 167 | # Extract the problem and solution from the question 168 | problem = random_question["problem"] 169 | solution = random_question["solution"] 170 | 171 | return problem, solution 172 | 173 | 174 | def get_test_gpqa_question(): 175 | # Load the GPQA test questions from JSON file 176 | with open("lib/unique_gpqa_test_questions.json") as f: 177 | questions = json.load(f) 178 | 179 | # Get a random question 180 | random_question = random.choice(questions) 181 | 182 | # Extract the problem and solution from the question 183 | problem = random_question["question"] 184 | solution = random_question["correct_answer"] 185 | explanation = random_question["explanation"] 186 | returned_solution = f"{solution}\n\n{explanation}" 187 | return problem, returned_solution 188 | 189 | 190 | def get_custom_question(): 191 | # Load the custom questions from JSON file 192 | try: 193 | with open("lib/custom_questions.json") as f: 194 | data = json.load(f) 195 | questions = data["questions"] 196 | problem = random.choice(questions) 197 | solution = "No solution provided. Please verify your answer independently." 198 | 199 | return problem, solution 200 | except FileNotFoundError: 201 | raise FileNotFoundError("custom_questions.json not found in lib directory") 202 | except KeyError: 203 | raise KeyError("custom_questions.json must contain a 'questions' array") 204 | except json.JSONDecodeError: 205 | raise ValueError("custom_questions.json is not valid JSON") 206 | 207 | 208 | def get_numina_math_question(): 209 | # Load the NuminaMath dataset (cached) 210 | dataset = load_cached_numina_dataset() 211 | 212 | # Get a random index 213 | random_index = random.randint(0, len(dataset) - 1) 214 | 215 | # Get the random question 216 | random_question = dataset[random_index] 217 | 218 | # Extract the problem and solution from the question 219 | problem = random_question["problem"] 220 | solution = random_question["solution"] 221 | 222 | return problem, solution 223 | 224 | 225 | def get_skunkworks_question(): 226 | # Load the Skunkworks dataset (cached) 227 | dataset = load_cached_skunkworks_dataset() 228 | 229 | # Get a random index 230 | random_index = random.randint(0, len(dataset) - 1) 231 | 232 | # Get the random question 233 | random_question = dataset[random_index] 234 | 235 | # Extract the problem and solution from the question 236 | problem = random_question["instruction"] 237 | solution = random_question["reasoning"] 238 | 239 | return problem, solution 240 | 241 | 242 | def get_mathinstruct_question(): 243 | # Load the MathInstruct dataset (cached) 244 | dataset = load_cached_mathinstruct_dataset() 245 | 246 | # Get a random index 247 | random_index = random.randint(0, len(dataset) - 1) 248 | 249 | # Get the random question 250 | random_question = dataset[random_index] 251 | 252 | # Extract the problem and solution from the question 253 | problem = random_question["instruction"] 254 | solution = random_question["output"] 255 | 256 | return problem, solution 257 | 258 | 259 | @lru_cache(maxsize=1) 260 | def get_cached_olympiad_questions(): 261 | dataset = load_cached_numina_dataset() 262 | return [q for q in dataset if q["source"] == "olympiads"] 263 | 264 | 265 | def 
get_numina_olympiad_question(count=1):
266 |     # Use cached filtered questions
267 |     olympiad_questions = get_cached_olympiad_questions()
268 | 
269 |     if not olympiad_questions:
270 |         raise ValueError("No olympiad questions found in the dataset")
271 | 
272 |     # Get random olympiad questions
273 |     if count == 1:
274 |         random_question = random.choice(olympiad_questions)
275 |         return random_question["problem"], random_question["solution"]
276 |     else:
277 |         sample_size = min(count, len(olympiad_questions))
278 |         random_questions = random.sample(olympiad_questions, sample_size)
279 |         return [(q["problem"], q["solution"]) for q in random_questions]
280 | 
281 | 
282 | def get_aslawliet_olympiad_question():
283 |     # Load the aslawliet/olympiads dataset (cached)
284 |     dataset = load_cached_aslawliet_dataset()
285 | 
286 |     # Get a random index
287 |     random_index = random.randint(0, len(dataset) - 1)
288 | 
289 |     # Get the random question
290 |     random_question = dataset[random_index]
291 | 
292 |     # Extract the problem and solution from the question
293 |     # Note: the columns read here are 'problem' and 'solution'
294 |     problem = random_question["problem"]
295 |     solution = random_question["solution"]
296 | 
297 |     return problem, solution
298 | 
299 | 
300 | # Example usage
301 | if __name__ == "__main__":
302 |     random_problem, random_solution = get_random_math_question()
303 |     print("Random MATH question:")
304 |     print("Problem:", random_problem)
305 |     print("\nSolution:", random_solution)
306 | 
307 |     print("\nHard MATH question:")
308 |     hard_problem, hard_solution = get_hard_math_question()
309 |     print("Problem:", hard_problem)
310 |     print("\nSolution:", hard_solution)
311 | 
312 |     print("\nCompetition MATH question:")
313 |     comp_problem, comp_solution = get_competition_math_problem()
314 |     print("Problem:", comp_problem)
315 |     print("\nSolution:", comp_solution)
316 | 
317 |     print("\nHard Competition MATH question:")
318 |     hard_comp_problem, hard_comp_solution = get_hard_competition_math_problem()
319 |     print("Problem:", hard_comp_problem)
320 |     print("\nSolution:", hard_comp_solution)
321 | 
322 |     print("\nTest set MATH question:")
323 |     test_problem, test_solution = get_test_math_question()
324 |     print("Problem:", test_problem)
325 |     print("\nSolution:", test_solution)
326 | 
327 |     print("\nGPQA test question:")
328 |     gpqa_test_problem, gpqa_test_solution = get_test_gpqa_question()
329 |     print("Problem:", gpqa_test_problem)
330 |     print("\nSolution:", gpqa_test_solution)
331 | 
332 |     print("\nCustom question:")
333 |     custom_problem, custom_solution = get_custom_question()
334 |     print("Problem:", custom_problem)
335 |     print("\nSolution:", custom_solution)
336 | 
337 |     print("\nNuminaMath question:")
338 |     numina_problem, numina_solution = get_numina_math_question()
339 |     print("Problem:", numina_problem)
340 |     print("\nSolution:", numina_solution)
341 | 
342 |     print("\nSkunkworks reasoning question:")
343 |     skunk_problem, skunk_solution = get_skunkworks_question()
344 |     print("Problem:", skunk_problem)
345 |     print("\nSolution:", skunk_solution)
346 | 
347 |     print("\nMathInstruct question:")
348 |     math_inst_problem, math_inst_solution = get_mathinstruct_question()
349 |     print("Problem:", math_inst_problem)
350 |     print("\nSolution:", math_inst_solution)
351 | 
352 |     print("\nNumina Math Olympiad question:")
353 |     olympiad_problem, olympiad_solution = get_numina_olympiad_question()
354 |     print("Problem:", olympiad_problem)
355 |     print("\nSolution:", olympiad_solution)
356 | 
357 |     print("\nAslawliet Olympiad question:")
358 |     aslawliet_problem, aslawliet_solution = get_aslawliet_olympiad_question()
359 |     print("Problem:", aslawliet_problem)
360 |     print("\nSolution:", aslawliet_solution)
361 | 
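# Design note (editorial sketch, not part of the original module): every
# getter above repeats the same load-then-sample pattern, which could be
# factored into a single helper; the names below are hypothetical.
#
#   def _random_pair(load_fn, problem_key="problem", solution_key="solution"):
#       dataset = load_fn()
#       row = dataset[random.randint(0, len(dataset) - 1)]
#       return row[problem_key], row[solution_key]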
--------------------------------------------------------------------------------
/src/textbooks_to_rl/question_generator/generator.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import logging
4 | import os
5 | import sys
6 | import unittest
7 | from datetime import datetime
8 | from typing import List, Optional, Union
9 | 
10 | from sympy import simplify, sympify
11 | from sympy.parsing.latex import parse_latex
12 | 
13 | # Adjust the path so that the inference/evaluation modules can be imported.
14 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
15 | 
16 | try:
17 |     from src.tobyawesomeailibrary.eval_response import evaluate_text
18 |     from src.tobyawesomeailibrary.inference import generate_text
19 | except ImportError:
20 |     # Fallback for when running as installed package
21 |     import os
22 |     import sys
23 |     sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
24 |     from src.tobyawesomeailibrary.eval_response import evaluate_text
25 |     from src.tobyawesomeailibrary.inference import generate_text
26 | 
27 | from .models import QuestionAnswer, QuestionDifficulty, QuestionDomain, ValidationResult
28 | from .parsers import ResponseParser
29 | from .prompt_templates import QuestionPromptTemplates
30 | 
31 | logger = logging.getLogger(__name__)
32 | 
33 | # -----------------------------------------------------------------------------
34 | # Utility Functions for LaTeX Processing and Math Equivalence Checking
35 | # -----------------------------------------------------------------------------
36 | 
37 | 
38 | def extract_boxed_expression(latex_str: str) -> str:
39 |     r"""
40 |     Extract the expression inside the \boxed{...} command, handling nested braces.
41 |     If not found, returns the entire string.
42 |     """
43 |     latex_str = latex_str.strip()
44 | 
45 |     if "\\boxed{" not in latex_str:
46 |         return latex_str.strip()
47 | 
48 |     start_idx = latex_str.index("\\boxed{") + 7  # len("\\boxed{")
49 |     brace_count = 1
50 |     end_idx = start_idx
51 | 
52 |     while brace_count > 0 and end_idx < len(latex_str):
53 |         if latex_str[end_idx] == "{":
54 |             brace_count += 1
55 |         elif latex_str[end_idx] == "}":
56 |             brace_count -= 1
57 |         end_idx += 1
58 | 
59 |     if brace_count == 0:
60 |         return latex_str[start_idx : end_idx - 1].strip()
61 |     return latex_str.strip()
62 | 
63 | 
64 | def _parse_math_expr(expr_str: str) -> Union[object, None]:
65 |     """
66 |     Try parsing the math expression using sympy's LaTeX parser.
67 |     If that fails (e.g. antlr4 isn't installed), fall back to sympy.sympify
68 |     after replacing '^' with '**' for exponentiation.
69 |     """
70 |     try:
71 |         return parse_latex(expr_str)
72 |     except Exception:
73 |         expr_str_modified = expr_str.replace("^", "**")
74 |         return sympify(expr_str_modified)
75 | 
76 | 
77 | def check_math_equivalence(expected_latex: str, verification_latex: str) -> bool:
78 |     r"""
79 |     Check if two LaTeX math expressions are mathematically equivalent.
80 |     First, ensure both expressions have a \boxed{...} wrapper.
81 |     Then extract the inner expression and attempt to parse using Sympy.
82 |     If parsing fails, fall back to a normalized string comparison.
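
    Examples (mirroring the unit tests at the bottom of this module)::

        check_math_equivalence(r"\boxed{(x+1)^2}", r"\boxed{x^2+2*x+1}")  # True
        check_math_equivalence(r"\boxed{1+2}", r"\boxed{4}")              # False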
83 | """ 84 | has_boxed_expected = "\\boxed{" in expected_latex 85 | has_boxed_verification = "\\boxed{" in verification_latex 86 | if has_boxed_expected != has_boxed_verification: 87 | return False 88 | 89 | # Extract boxed expressions first 90 | expected_inner = extract_boxed_expression(expected_latex) 91 | verification_inner = extract_boxed_expression(verification_latex) 92 | 93 | # Handle equals signs by taking the right-most part 94 | if "=" in expected_inner: 95 | expected_inner = expected_inner.split("=")[-1].strip() 96 | if "=" in verification_inner: 97 | verification_inner = verification_inner.split("=")[-1].strip() 98 | 99 | try: 100 | expected_expr = _parse_math_expr(expected_inner) 101 | verification_expr = _parse_math_expr(verification_inner) 102 | except Exception: 103 | normalized_expected = expected_inner.replace(" ", "").lower() 104 | normalized_verification = verification_inner.replace(" ", "").lower() 105 | return normalized_expected == normalized_verification 106 | 107 | difference = simplify(expected_expr - verification_expr) 108 | return difference == 0 109 | 110 | 111 | # ----------------------------------------------------------------------------- 112 | # Main Class for Generating and Managing Questions 113 | # ----------------------------------------------------------------------------- 114 | 115 | 116 | class QuestionGenerator: 117 | """Main class for generating and managing questions.""" 118 | 119 | def __init__( 120 | self, 121 | model_name: str = "gpt-4o-mini", 122 | output_dir: str = "generated_questions", 123 | verification_model: Optional[str] = None, 124 | ) -> None: 125 | self.model_name = model_name 126 | self.verification_model = verification_model or model_name 127 | self.parser = ResponseParser() 128 | self.templates = QuestionPromptTemplates() 129 | self.output_dir = output_dir 130 | os.makedirs(self.output_dir, exist_ok=True) 131 | 132 | async def generate_questions( 133 | self, 134 | passage: str, 135 | num_questions: int = 3, 136 | difficulty: QuestionDifficulty = QuestionDifficulty.UNDERGRAD, 137 | verify: bool = True, 138 | verification_threshold: float = 0.8, 139 | src: Optional[str] = None, 140 | save_json: bool = True, 141 | add_hints: bool = False, 142 | classify_domain: bool = False, 143 | verification_model: Optional[str] = None, 144 | ) -> List[QuestionAnswer]: 145 | """Generate questions from a passage and return structured data.""" 146 | current_verification_model = verification_model or self.verification_model 147 | 148 | # Generate questions using the prompt template. 
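        # (Editorial note on the flow below: one completion is requested per
        # passage, ResponseParser.extract_qa_pairs splits it into
        # QuestionAnswer objects, and each pair is then optionally verified
        # by re-solving the question and comparing boxed answers.)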
149 | prompt = self.templates.question_generation(passage, num_questions, difficulty) 150 | response = await generate_text(model=self.model_name, prompt=prompt) 151 | 152 | qa_pairs = self.parser.extract_qa_pairs(response) 153 | logger.info( 154 | f"Generated {len(qa_pairs)} initial questions at {difficulty.value} level" 155 | ) 156 | 157 | verified_pairs = [] 158 | for qa in qa_pairs: 159 | if verify: 160 | is_valid = await self._verify_question_solution( 161 | qa, 162 | passage, 163 | verification_threshold, 164 | current_verification_model, 165 | verification_attempts=3, 166 | ) 167 | qa.is_valid = is_valid 168 | if is_valid: 169 | if add_hints: 170 | await self._add_hints(qa, difficulty) 171 | if classify_domain: 172 | await self._classify_domain(qa) 173 | if src: 174 | qa.source = src 175 | verified_pairs.append(qa) 176 | else: 177 | qa.is_valid = None 178 | if add_hints: 179 | await self._add_hints(qa, difficulty) 180 | if classify_domain: 181 | await self._classify_domain(qa) 182 | if src: 183 | qa.source = src 184 | verified_pairs.append(qa) 185 | 186 | # Save questions to JSON files. 187 | if save_json: 188 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 189 | diff_serializable = ( 190 | difficulty.value if hasattr(difficulty, "value") else difficulty 191 | ) 192 | 193 | # Calculate and print validation percentage 194 | if verify: 195 | valid_count = len(verified_pairs) 196 | total_count = len(qa_pairs) 197 | valid_percentage = ( 198 | (valid_count / total_count) * 100 if total_count > 0 else 0 199 | ) 200 | logger.info( 201 | f"Valid questions: {valid_count}/{total_count} ({valid_percentage:.1f}%)" 202 | ) 203 | 204 | for i, qa in enumerate(verified_pairs): 205 | boxed_solution = None 206 | if "\\boxed{" in qa.solution: 207 | try: 208 | last_start = qa.solution.rindex("\\boxed{") + 7 209 | brace_count = 1 210 | last_end = last_start 211 | while brace_count > 0 and last_end < len(qa.solution): 212 | if qa.solution[last_end] == "{": 213 | brace_count += 1 214 | elif qa.solution[last_end] == "}": 215 | brace_count -= 1 216 | last_end += 1 217 | boxed_solution = qa.solution[last_start : last_end - 1] 218 | except ValueError: 219 | pass 220 | 221 | question_dict = { 222 | "source": qa.source, 223 | "question": qa.question, 224 | "solution": qa.solution, 225 | "hints": qa.hints, 226 | "difficulty": diff_serializable, 227 | "domain": qa.domain.value if qa.domain else None, 228 | "timestamp": timestamp, 229 | "difficulty_description": difficulty.get_description(), 230 | "model": self.model_name, 231 | "boxed_solution": boxed_solution, 232 | "validated": qa.is_valid, 233 | } 234 | 235 | filename = f"question_{timestamp}_{i + 1}.json" 236 | filepath = os.path.join(self.output_dir, filename) 237 | with open(filepath, "w") as f: 238 | json.dump(question_dict, f, indent=4) 239 | 240 | return verified_pairs 241 | 242 | async def _classify_domain(self, qa: QuestionAnswer) -> None: 243 | """Classify the domain of a question.""" 244 | domains = [d.value for d in QuestionDomain] 245 | domains_str = ", ".join(domains) 246 | classification_prompt = f""" 247 | Classify this question into exactly one of these domains: {domains_str} 248 | 249 | Question: {qa.question} 250 | Solution: {qa.solution} 251 | 252 | Respond with ONLY the domain name from the list above that best matches. 253 | Do not include any other text in your response. 
254 |         """
255 |         domain_str = await generate_text(
256 |             model=self.model_name, prompt=classification_prompt
257 |         )
258 |         domain_str = domain_str.strip().lower()
259 |         try:
260 |             qa.domain = QuestionDomain(domain_str)
261 |         except ValueError:
262 |             logger.warning(f"Invalid domain '{domain_str}', defaulting to OTHER")
263 |             qa.domain = QuestionDomain.OTHER
264 | 
265 |     async def _verify_question_solution(
266 |         self,
267 |         qa: QuestionAnswer,
268 |         source_text: str,
269 |         threshold: float,
270 |         verification_model: str,
271 |         verification_attempts: int = 3,
272 |     ) -> bool:
273 |         """
274 |         Verify a question's solution with multiple concurrent attempts.
275 |         Uses the improved math equivalence check to compare solutions.
276 |         """
277 |         verification_prompt = f"""
278 | Reference text:
279 | {source_text}
280 | 
281 | Your answer must be precise and include a final answer in a \\boxed{{...}} format.
282 | 
283 | Question:
284 | {qa.question}
285 | 
286 | Solve the problem step by step.
287 | """
288 |         verification_tasks = [
289 |             generate_text(model=verification_model, prompt=verification_prompt)
290 |             for _ in range(verification_attempts)
291 |         ]
292 |         model_solutions = await asyncio.gather(*verification_tasks)
293 |         original_answer = qa.solution
294 | 
295 |         logger.debug("=== Solution Comparison ===")
296 |         logger.debug(f"Original boxed: {extract_boxed_expression(original_answer)}")
297 | 
298 |         for i, model_solution in enumerate(model_solutions, 1):
299 |             is_equivalent = check_math_equivalence(original_answer, model_solution)
300 |             logger.debug(
301 |                 f"Verification {i} boxed: {extract_boxed_expression(model_solution)}"
302 |             )
303 |             logger.debug(f"Equivalent: {is_equivalent}")
304 | 
305 |             if is_equivalent:
306 |                 return True
307 |         return False
308 | 
309 |     async def _add_hints(
310 |         self,
311 |         qa: QuestionAnswer,
312 |         difficulty: QuestionDifficulty = QuestionDifficulty.UNDERGRAD,
313 |     ) -> None:
314 |         """Add hints to a QuestionAnswer object."""
315 |         prompt = self.templates.hint_generation(qa.question, difficulty)
316 |         hints_text = await generate_text(model=self.model_name, prompt=prompt)
317 |         qa.hints = self.parser.extract_hints(hints_text)
318 | 
319 |     async def validate_solution(
320 |         self, question: str, student_solution: str, correct_solution: str
321 |     ) -> ValidationResult:
322 |         """Validate a student's solution against the correct solution."""
323 |         eval_result = await evaluate_text(
324 |             self.model_name, student_solution, correct_solution
325 |         )
326 |         boxed_solution = None
327 |         if "\\boxed{" in correct_solution:
328 |             # Slice from the last \boxed{...} and reuse the
329 |             # nested-brace-aware helper above; the previous naive
330 |             # index("}") search truncated answers like
331 |             # \boxed{\frac{1}{2}} at the first closing brace.
332 |             last_start = correct_solution.rindex("\\boxed{")
333 |             boxed_solution = extract_boxed_expression(correct_solution[last_start:])
334 | 
335 |         return ValidationResult(
336 |             is_correct=eval_result[0] == 1,
337 |             score=float(eval_result[0]),
338 |             feedback=None,
339 |             boxed_solution=boxed_solution,
340 |         )
341 | 
342 | 
343 | # -----------------------------------------------------------------------------
344 | # Unit Tests for the Math Equivalence Checker
345 | # -----------------------------------------------------------------------------
346 | 
347 | 
348 | class TestMathEquivalence(unittest.TestCase):
349 |     def test_equivalent_simple(self):
350 |         expr1 = r"\boxed{1+2}"
351 |         expr2 = r"\boxed{3}"
352 |         self.assertTrue(check_math_equivalence(expr1, expr2))
353 | 
354 |     def test_equivalent_with_spaces(self):
355 |         expr1 = r"\boxed{1 + 2}"
356 |         expr2 = r"\boxed{ 3 }"
357 | 
self.assertTrue(check_math_equivalence(expr1, expr2)) 358 | 359 | def test_non_equivalent(self): 360 | expr1 = r"\boxed{1+2}" 361 | expr2 = r"\boxed{4}" 362 | self.assertFalse(check_math_equivalence(expr1, expr2)) 363 | 364 | def test_equivalent_complex(self): 365 | expr1 = r"\boxed{(x+1)^2}" 366 | expr2 = r"\boxed{x^2+2*x+1}" 367 | self.assertTrue(check_math_equivalence(expr1, expr2)) 368 | 369 | def test_malformed_latex(self): 370 | # If one expression lacks \boxed{}, they should not be equivalent. 371 | expr1 = "1+2" 372 | expr2 = r"\boxed{3}" 373 | self.assertFalse(check_math_equivalence(expr1, expr2)) 374 | 375 | def test_equivalent_with_equals(self): 376 | expr1 = r"\boxed{x^2 + 1}" 377 | expr2 = r"\boxed{f(x) = x^2 + 1}" 378 | self.assertTrue(check_math_equivalence(expr1, expr2)) 379 | 380 | def test_equivalent_multiple_equals(self): 381 | expr1 = r"\boxed{x = y = 1}" 382 | expr2 = r"\boxed{1}" 383 | self.assertTrue(check_math_equivalence(expr1, expr2)) 384 | 385 | 386 | # ----------------------------------------------------------------------------- 387 | # Run Unit Tests if executed as main 388 | # ----------------------------------------------------------------------------- 389 | 390 | if __name__ == "__main__": 391 | unittest.main() 392 | --------------------------------------------------------------------------------