├── tests
    ├── __init__.py
    ├── e2e_tests
│   ├── __init__.py
│   ├── test_generate_cot.py
│   ├── test_generate_vqa.py
│   ├── test_generate_atomic.py
│   ├── test_generate_aggregated.py
│   ├── test_generate_multi_hop.py
│   └── conftest.py
    └── integration_tests
│   ├── __init__.py
│   ├── models
│   ├── splitter
│   │   ├── test_character_splitter.py
│   │   ├── test_markdown_splitter.py
│   │   └── test_recursive_character_splitter.py
│   ├── reader
│   │   └── test_mineru_parser.py
│   ├── llm
│   │   └── local
│   │   │   └── test_hf_wrapper.py
│   └── partitioner
│   │   └── test_bfs_partitioner.py
│   └── test_engine.py
├── webui
    ├── __init__.py
    ├── utils
│   ├── __init__.py
│   ├── cache.py
│   ├── preview_file.py
│   └── count_tokens.py
    ├── test_api.py
    ├── base.py
    └── examples
│   ├── txt_demo.txt
│   ├── csv_demo.csv
│   ├── jsonl_demo.jsonl
│   ├── json_demo.json
│   └── vqa_demo.json
├── baselines
    ├── __init__.py
    ├── BDS
│   ├── __init__.py
│   └── README.md
    ├── Genie
│   ├── __init__.py
│   └── README.md
    ├── Wrap
│   ├── __init__.py
│   └── README.md
    ├── EntiGraph
│   ├── __init__.py
│   ├── tasks
│   │   ├── __init.py
│   │   ├── task_abc.py
│   │   └── baseline_task.py
│   ├── README.md
│   └── inference
│   │   └── devapi.py
    ├── LongForm
│   ├── __init__.py
│   └── README.md
    └── SELF-QA
│   ├── __init__.py
│   └── README.md
├── graphgen
    ├── __init__.py
    ├── models
│   ├── vis
│   │   ├── __init__.py
│   │   └── community_visualizer.py
│   ├── llm
│   │   ├── api
│   │   │   └── __init__.py
│   │   ├── local
│   │   │   ├── __init__.py
│   │   │   ├── trt_wrapper.py
│   │   │   └── tgi_wrapper.py
│   │   └── __init__.py
│   ├── searcher
│   │   ├── __init__.py
│   │   ├── db
│   │   │   └── __init__.py
│   │   ├── kg
│   │   │   ├── __init__.py
│   │   │   └── wiki_search.py
│   │   └── web
│   │   │   ├── __init__.py
│   │   │   ├── bing_search.py
│   │   │   └── google_search.py
│   ├── storage
│   │   ├── kv
│   │   │   ├── __init__.py
│   │   │   ├── json_storage.py
│   │   │   └── rocksdb_storage.py
│   │   ├── graph
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   └── rocksdb_cache.py
│   ├── extractor
│   │   ├── key_information_extractor.py
│   │   └── __init__.py
│   ├── kg_builder
│   │   └── __init__.py
│   ├── splitter
│   │   ├── __init__.py
│   │   ├── character_splitter.py
│   │   └── markdown_splitter.py
│   ├── evaluator
│   │   ├── __init__.py
│   │   ├── length_evaluator.py
│   │   ├── base_evaluator.py
│   │   └── mtld_evaluator.py
│   ├── partitioner
│   │   ├── __init__.py
│   │   ├── dfs_partitioner.py
│   │   └── bfs_partitioner.py
│   ├── reader
│   │   ├── __init__.py
│   │   ├── csv_reader.py
│   │   ├── txt_reader.py
│   │   ├── parquet_reader.py
│   │   └── json_reader.py
│   ├── generator
│   │   ├── __init__.py
│   │   ├── atomic_generator.py
│   │   ├── multi_hop_generator.py
│   │   └── quiz_generator.py
│   ├── tokenizer
│   │   ├── tiktoken_tokenizer.py
│   │   ├── hf_tokenizer.py
│   │   └── __init__.py
│   └── __init__.py
    ├── operators
│   ├── evaluate
│   │   └── __init__.py
│   ├── preprocess
│   │   ├── __init__.py
│   │   └── resolute_coreference.py
│   ├── read
│   │   └── __init__.py
│   ├── quiz
│   │   └── __init__.py
│   ├── search
│   │   └── __init__.py
│   ├── chunk
│   │   └── __init__.py
│   ├── judge
│   │   └── __init__.py
│   ├── build_kg
│   │   ├── __init__.py
│   │   ├── build_text_kg.py
│   │   ├── build_mm_kg.py
│   │   └── build_kg_service.py
│   ├── extract
│   │   ├── __init__.py
│   │   └── extract_service.py
│   ├── generate
│   │   ├── __init__.py
│   │   └── generate_service.py
│   ├── partition
│   │   └── __init__.py
│   └── __init__.py
    ├── common
│   └── __init__.py
    ├── templates
│   ├── extraction
│   │   ├── __init__.py
│   │   ├── schemas
│   │   │   └── legal_contract.json
│   │   └── schema_guided_extraction.py
│   ├── kg
│   │   ├── __init__.py
│   │   └── kg_summarization.py
│   ├── generation
│   │   ├── __init__.py
│   │   ├── atomic_generation.py
│   │   └── multi_hop_generation.py
│   ├── statement_judgement.py
│   ├── __init__.py
│   ├── coreference_resolution.py
│   ├── question_generation.py
│   └── search_judgement.py
    ├── utils
│   ├── wrap.py
│   ├── __init__.py
│   ├── hash.py
│   ├── loop.py
│   ├── detect_lang.py
│   ├── run_concurrent.py
│   ├── help_nltk.py
│   └── device.py
    ├── bases
│   ├── __init__.py
│   ├── base_searcher.py
│   ├── base_extractor.py
│   ├── base_tokenizer.py
│   ├── base_kg_builder.py
│   ├── base_operator.py
│   ├── base_partitioner.py
│   └── base_llm_wrapper.py
    └── _version.py
├── requirements-dev.txt
├── examples
    ├── generate
│   ├── generate_vqa
│   │   ├── README.md
│   │   ├── generate_vqa.sh
│   │   └── vqa_config.yaml
│   ├── generate_cot_qa
│   │   ├── README.md
│   │   ├── generate_cot.sh
│   │   └── cot_config.yaml
│   ├── generate_multi_hop_qa
│   │   ├── README.md
│   │   ├── generate_multi_hop.sh
│   │   └── multi_hop_config.yaml
│   ├── generate_atomic_qa
│   │   ├── generate_atomic.sh
│   │   ├── README.md
│   │   └── atomic_config.yaml
│   └── generate_aggregated_qa
│   │   ├── generate_aggregated.sh
│   │   ├── README.md
│   │   └── aggregated_config.yaml
    ├── extract
│   └── extract_schema_guided
│   │   ├── README.md
│   │   ├── extract_schema_guided.sh
│   │   └── schema_guided_extraction_config.yaml
    ├── input_examples
│   ├── pdf_demo.pdf
│   ├── images
│   │   ├── 0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg
│   │   ├── 390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg
│   │   ├── 4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg
│   │   ├── 8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg
│   │   ├── cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg
│   │   └── eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg
│   ├── search_rna_demo.jsonl
│   ├── search_dna_demo.jsonl
│   ├── search_protein_demo.jsonl
│   ├── extract_demo.txt
│   ├── txt_demo.txt
│   ├── csv_demo.csv
│   ├── jsonl_demo.jsonl
│   └── json_demo.json
    ├── search
│   ├── search_dna.sh
│   ├── search_rna.sh
│   ├── search_uniprot.sh
│   ├── search_rna_config.yaml
│   ├── search_protein_config.yaml
│   ├── search_dna_config.yaml
│   └── build_db
│   │   └── build_protein_blast_db.sh
    ├── baselines
│   ├── generate_bds.sh
│   ├── generate_wrap.sh
│   ├── generate_genie.sh
│   ├── generate_selfqa.sh
│   ├── generate_longform.sh
│   ├── generate_entigraph.sh
│   └── generate_all_baselines.sh
    ├── evaluate
│   └── evaluate.sh
    └── output_examples
│   └── atomic_alpaca.json
├── uv.lock
├── resources
    ├── images
│   ├── flow.png
│   └── logo.png
    ├── nltk_data
│   ├── tokenizers
│   │   └── punkt_tab
│   │   │   └── english
│   │   │   ├── sent_starters.txt
│   │   │   ├── collocations.tab
│   │   │   └── abbrev_types.txt
│   └── corpora
│   │   └── stopwords
│   │   └── english
    └── README.md
├── MANIFEST.in
├── pyproject.toml
├── .pre-commit-config.yaml
├── .github
    ├── sync-config.yml
    ├── ISSUE_TEMPLATE
│   ├── bug_report.md
│   └── feature_request.md
    ├── workflows
│   ├── pylint.yml
│   ├── pypi.yml
│   ├── push-to-hf.yml
│   ├── push-to-ms.yml
│   └── sync-demo.yml
    └── contributing.md
├── requirements.txt
├── CITATION.cff
├── Dockerfile
├── .env.example
└── setup.py

/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/webui/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/baselines/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/baselines/BDS/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/baselines/Genie/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/baselines/Wrap/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
pytest
--------------------------------------------------------------------------------
/tests/e2e_tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/baselines/EntiGraph/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/baselines/LongForm/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/baselines/SELF-QA/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/models/vis/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/baselines/EntiGraph/tasks/__init.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/models/llm/api/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/models/llm/local/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/models/searcher/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/models/storage/kv/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/tests/integration_tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/models/searcher/db/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/models/searcher/kg/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/models/searcher/web/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/models/storage/graph/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/operators/evaluate/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/graphgen/operators/preprocess/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/examples/generate/generate_vqa/README.md:
--------------------------------------------------------------------------------
# Generate VQAs
--------------------------------------------------------------------------------
/graphgen/operators/read/__init__.py:
--------------------------------------------------------------------------------
from .read import read
--------------------------------------------------------------------------------
/baselines/Wrap/README.md:
--------------------------------------------------------------------------------
# Wrap
https://arxiv.org/abs/2401.16380
--------------------------------------------------------------------------------
/examples/generate/generate_cot_qa/README.md:
--------------------------------------------------------------------------------
# Generate CoT QAs
--------------------------------------------------------------------------------
/baselines/Genie/README.md:
--------------------------------------------------------------------------------
# Genie
https://arxiv.org/pdf/2401.14367
--------------------------------------------------------------------------------
/baselines/SELF-QA/README.md:
--------------------------------------------------------------------------------
# SELF-QA
https://arxiv.org/abs/2305.11952
--------------------------------------------------------------------------------
/uv.lock:
--------------------------------------------------------------------------------
version = 1
revision = 3
requires-python = ">=3.10"
--------------------------------------------------------------------------------
/examples/generate/generate_multi_hop_qa/README.md:
--------------------------------------------------------------------------------
# Generate Multi-hop QAs
--------------------------------------------------------------------------------
/graphgen/models/extractor/key_information_extractor.py:
--------------------------------------------------------------------------------
# TODO: text2json
--------------------------------------------------------------------------------
/graphgen/operators/quiz/__init__.py:
--------------------------------------------------------------------------------
from .quiz_service import QuizService
--------------------------------------------------------------------------------
/graphgen/operators/search/__init__.py:
--------------------------------------------------------------------------------
from .search_all import search_all
--------------------------------------------------------------------------------
/graphgen/operators/chunk/__init__.py:
--------------------------------------------------------------------------------
from .chunk_service import ChunkService
--------------------------------------------------------------------------------
/graphgen/operators/judge/__init__.py:
--------------------------------------------------------------------------------
from .judge_service import JudgeService
--------------------------------------------------------------------------------
/graphgen/operators/build_kg/__init__.py:
--------------------------------------------------------------------------------
from .build_kg_service import BuildKGService
--------------------------------------------------------------------------------
/graphgen/operators/extract/__init__.py:
--------------------------------------------------------------------------------
from .extract_service import ExtractService
--------------------------------------------------------------------------------
/graphgen/operators/generate/__init__.py:
--------------------------------------------------------------------------------
from .generate_service import GenerateService
--------------------------------------------------------------------------------
/graphgen/operators/partition/__init__.py:
--------------------------------------------------------------------------------
from .partition_service import PartitionService
--------------------------------------------------------------------------------
/examples/extract/extract_schema_guided/README.md:
--------------------------------------------------------------------------------
# Extract Schema-Guided Information from Documents
--------------------------------------------------------------------------------
/graphgen/models/extractor/__init__.py:
--------------------------------------------------------------------------------
from .schema_guided_extractor import SchemaGuidedExtractor
--------------------------------------------------------------------------------
/graphgen/common/__init__.py:
--------------------------------------------------------------------------------
from .init_llm import init_llm
from .init_storage import init_storage
--------------------------------------------------------------------------------
/resources/images/flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternScience/GraphGen/HEAD/resources/images/flow.png
--------------------------------------------------------------------------------
/resources/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternScience/GraphGen/HEAD/resources/images/logo.png
--------------------------------------------------------------------------------
/graphgen/templates/extraction/__init__.py:
--------------------------------------------------------------------------------
from .schema_guided_extraction import SCHEMA_GUIDED_EXTRACTION_PROMPT
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include requirements.txt
include graphgen/configs/*.yaml
recursive-include graphgen/configs *.yaml
--------------------------------------------------------------------------------
/baselines/LongForm/README.md:
--------------------------------------------------------------------------------
# LongForm
https://arxiv.org/pdf/2304.08460
https://github.com/akoksal/LongForm/tree/main
--------------------------------------------------------------------------------
/examples/input_examples/pdf_demo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternScience/GraphGen/HEAD/examples/input_examples/pdf_demo.pdf
--------------------------------------------------------------------------------
/graphgen/models/kg_builder/__init__.py:
--------------------------------------------------------------------------------
from .light_rag_kg_builder import LightRAGKGBuilder
from .mm_kg_builder import MMKGBuilder
--------------------------------------------------------------------------------
/baselines/EntiGraph/README.md:
--------------------------------------------------------------------------------
# EntiGraph
https://arxiv.org/abs/2409.07431
https://github.com/zitongyang/synthetic_continued_pretraining
--------------------------------------------------------------------------------
/examples/search/search_dna.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.run \
  --config_file graphgen/configs/search_dna_config.yaml \
  --output_dir cache/
--------------------------------------------------------------------------------
/examples/search/search_rna.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.run \
  --config_file graphgen/configs/search_rna_config.yaml \
  --output_dir cache/
--------------------------------------------------------------------------------
/baselines/BDS/README.md:
--------------------------------------------------------------------------------
# Bottom-up Domain-specific Superintelligence: A Reliable Knowledge Graph is What We Need
https://arxiv.org/pdf/2507.13966
--------------------------------------------------------------------------------
/examples/search/search_uniprot.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.run \
  --config_file graphgen/configs/search_protein_config.yaml \
  --output_dir cache/
--------------------------------------------------------------------------------
/examples/generate/generate_vqa/generate_vqa.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.run \
  --config_file examples/generate/generate_vqa/vqa_config.yaml \
  --output_dir cache/
--------------------------------------------------------------------------------
/graphgen/models/splitter/__init__.py:
--------------------------------------------------------------------------------
from .recursive_character_splitter import (
    ChineseRecursiveTextSplitter,
    RecursiveCharacterSplitter,
)
--------------------------------------------------------------------------------
/webui/utils/__init__.py:
--------------------------------------------------------------------------------
from .cache import cleanup_workspace, setup_workspace
from .count_tokens import count_tokens
from .preview_file import preview_file
--------------------------------------------------------------------------------
/examples/generate/generate_cot_qa/generate_cot.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.run \
  --config_file examples/generate/generate_cot_qa/cot_config.yaml \
  --output_dir cache/
--------------------------------------------------------------------------------
/examples/generate/generate_atomic_qa/generate_atomic.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.run \
  --config_file examples/generate/generate_atomic_qa/atomic_config.yaml \
  --output_dir cache/
--------------------------------------------------------------------------------
/examples/generate/generate_multi_hop_qa/generate_multi_hop.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.run \
  --config_file examples/generate/generate_multi_hop_qa/multi_hop_config.yaml \
  --output_dir cache/
--------------------------------------------------------------------------------
/examples/baselines/generate_bds.sh:
--------------------------------------------------------------------------------
python3 -m baselines.BDS.bds --input_file resources/input_examples/graphml_demo.graphml \
  --output_file cache/data/bds.json \
--------------------------------------------------------------------------------
/examples/generate/generate_aggregated_qa/generate_aggregated.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.run \
  --config_file examples/generate/generate_aggregated_qa/aggregated_config.yaml \
  --output_dir cache/
--------------------------------------------------------------------------------
/graphgen/templates/kg/__init__.py:
--------------------------------------------------------------------------------
from .kg_extraction import KG_EXTRACTION_PROMPT
from .kg_summarization import KG_SUMMARIZATION_PROMPT
from .mm_kg_extraction import MMKG_EXTRACTION_PROMPT
--------------------------------------------------------------------------------
/examples/extract/extract_schema_guided/extract_schema_guided.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.run \
  --config_file examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml \
  --output_dir cache/
--------------------------------------------------------------------------------
/examples/generate/generate_atomic_qa/README.md:
--------------------------------------------------------------------------------
# Generate Atomic QAs

Atomic mode generates question-answer pairs that test basic, isolated knowledge from individual facts or relationships in the knowledge graph.
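
Below is an illustrative output record (a sketch only; the Alpaca-style field names are an assumption based on `examples/output_examples/atomic_alpaca.json`, and the content is invented):

```json
[
  {
    "instruction": "What enzyme unwinds the DNA double helix during replication?",
    "input": "",
    "output": "DNA helicase unwinds the double helix at the replication fork."
  }
]
```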
--------------------------------------------------------------------------------
/tests/e2e_tests/test_generate_cot.py:
--------------------------------------------------------------------------------
from pathlib import Path

from .conftest import run_generate_test


def test_generate_cot(tmp_path: Path):
    run_generate_test(tmp_path, "cot_config.yaml")
--------------------------------------------------------------------------------
/tests/e2e_tests/test_generate_vqa.py:
--------------------------------------------------------------------------------
from pathlib import Path

from .conftest import run_generate_test


def test_generate_vqa(tmp_path: Path):
    run_generate_test(tmp_path, "vqa_config.yaml")
--------------------------------------------------------------------------------
/graphgen/models/evaluator/__init__.py:
--------------------------------------------------------------------------------
from .length_evaluator import LengthEvaluator
from .mtld_evaluator import MTLDEvaluator
from .reward_evaluator import RewardEvaluator
from .uni_evaluator import UniEvaluator
--------------------------------------------------------------------------------
/graphgen/models/llm/__init__.py:
--------------------------------------------------------------------------------
from .api.http_client import HTTPClient
from .api.ollama_client import OllamaClient
from .api.openai_client import OpenAIClient
from .local.hf_wrapper import HuggingFaceWrapper
--------------------------------------------------------------------------------
/tests/e2e_tests/test_generate_atomic.py:
--------------------------------------------------------------------------------
from pathlib import Path

from .conftest import run_generate_test


def test_generate_atomic(tmp_path: Path):
    run_generate_test(tmp_path, "atomic_config.yaml")
--------------------------------------------------------------------------------
/tests/e2e_tests/test_generate_aggregated.py:
--------------------------------------------------------------------------------
from pathlib import Path

from .conftest import run_generate_test


def test_generate_aggregated(tmp_path: Path):
    run_generate_test(tmp_path, "aggregated_config.yaml")
--------------------------------------------------------------------------------
/tests/e2e_tests/test_generate_multi_hop.py:
--------------------------------------------------------------------------------
from pathlib import Path

from .conftest import run_generate_test


def test_generate_multi_hop(tmp_path: Path):
    run_generate_test(tmp_path, "multi_hop_config.yaml")
--------------------------------------------------------------------------------
/examples/baselines/generate_wrap.sh:
--------------------------------------------------------------------------------
python3 -m baselines.Wrap.wrap --input_file resources/input_examples/raw_demo.jsonl \
  --data_type raw \
  --output_file cache/data/wrap.json \
--------------------------------------------------------------------------------
/examples/baselines/generate_genie.sh:
--------------------------------------------------------------------------------
python3 -m baselines.Genie.genie --input_file resources/input_examples/raw_demo.jsonl \
  --data_type raw \
  --output_file cache/data/genie.json \
--------------------------------------------------------------------------------
/examples/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternScience/GraphGen/HEAD/examples/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg
--------------------------------------------------------------------------------
/examples/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternScience/GraphGen/HEAD/examples/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg
--------------------------------------------------------------------------------
/examples/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternScience/GraphGen/HEAD/examples/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg
--------------------------------------------------------------------------------
/examples/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternScience/GraphGen/HEAD/examples/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg
--------------------------------------------------------------------------------
/examples/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternScience/GraphGen/HEAD/examples/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg
--------------------------------------------------------------------------------
/examples/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternScience/GraphGen/HEAD/examples/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg
--------------------------------------------------------------------------------
/examples/baselines/generate_selfqa.sh:
--------------------------------------------------------------------------------
python3 -m baselines.SELF-QA.self-qa --input_file resources/input_examples/raw_demo.jsonl \
  --data_type raw \
  --output_file cache/data/self-qa.json \
--------------------------------------------------------------------------------
/examples/baselines/generate_longform.sh:
--------------------------------------------------------------------------------
python3 -m baselines.LongForm.longform --input_file resources/input_examples/raw_demo.jsonl \
  --data_type raw \
  --output_file cache/data/longform.json \
--------------------------------------------------------------------------------
/examples/baselines/generate_entigraph.sh:
--------------------------------------------------------------------------------
python3 -m baselines.EntiGraph.entigraph --input_file resources/input_examples/raw_demo.jsonl \
  --data_type raw \
  --output_file cache/data/entigraph.json \
--------------------------------------------------------------------------------
/graphgen/models/partitioner/__init__.py:
--------------------------------------------------------------------------------
from .anchor_bfs_partitioner import AnchorBFSPartitioner
from .bfs_partitioner import BFSPartitioner
from .dfs_partitioner import DFSPartitioner
from .ece_partitioner import ECEPartitioner
from .leiden_partitioner import LeidenPartitioner
--------------------------------------------------------------------------------
/examples/evaluate/evaluate.sh:
--------------------------------------------------------------------------------
python3 -m graphgen.evaluate --folder cache/data \
  --output cache/output \
  --reward "OpenAssistant/reward-model-deberta-v3-large-v2,BAAI/IndustryCorpus2_DataRater" \
  --uni MingZhong/unieval-sum \
--------------------------------------------------------------------------------
/examples/generate/generate_aggregated_qa/README.md:
--------------------------------------------------------------------------------
# Generate Aggregated QAs

Aggregated mode is one of three question-answering scenarios in GraphGen (alongside atomic and multi-hop), designed to generate synthetic training data that incorporates complex, integrated knowledge from multiple sources. An illustrative output record is sketched below, after the generation template exports.
--------------------------------------------------------------------------------
/examples/baselines/generate_all_baselines.sh:
--------------------------------------------------------------------------------
# generate all baselines in one go

bash examples/baselines/generate_wrap.sh
bash examples/baselines/generate_selfqa.sh
bash examples/baselines/generate_longform.sh
bash examples/baselines/generate_genie.sh
bash examples/baselines/generate_entigraph.sh
--------------------------------------------------------------------------------
/graphgen/models/reader/__init__.py:
--------------------------------------------------------------------------------
from .csv_reader import CSVReader
from .json_reader import JSONReader
from .parquet_reader import ParquetReader
from .pdf_reader import PDFReader
from .pickle_reader import PickleReader
from .rdf_reader import RDFReader
from .txt_reader import TXTReader
--------------------------------------------------------------------------------
/graphgen/models/generator/__init__.py:
--------------------------------------------------------------------------------
from .aggregated_generator import AggregatedGenerator
from .atomic_generator import AtomicGenerator
from .cot_generator import CoTGenerator
from .multi_hop_generator import MultiHopGenerator
from .quiz_generator import QuizGenerator
from .vqa_generator import VQAGenerator
--------------------------------------------------------------------------------
/graphgen/templates/generation/__init__.py:
--------------------------------------------------------------------------------
from .aggregated_generation import AGGREGATED_GENERATION_PROMPT
from .atomic_generation import ATOMIC_GENERATION_PROMPT
from .cot_generation import COT_GENERATION_PROMPT
from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
from .vqa_generation import VQA_GENERATION_PROMPT
--------------------------------------------------------------------------------
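/[editor's illustration, not a repository file] aggregated_qa_example.json:
--------------------------------------------------------------------------------
A hypothetical record for the aggregated mode described in
examples/generate/generate_aggregated_qa/README.md above. The Alpaca-style
field names and the content are invented purely for illustration.

[
  {
    "instruction": "Explain how helicase and DNA polymerase cooperate during DNA replication.",
    "input": "",
    "output": "Helicase unwinds the double helix at the replication fork, and DNA polymerase then synthesizes complementary strands along each exposed template, so a single answer integrates several related facts from the graph."
  }
]
--------------------------------------------------------------------------------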
/graphgen/models/storage/__init__.py:
--------------------------------------------------------------------------------
from graphgen.models.storage.graph.kuzu_storage import KuzuStorage
from graphgen.models.storage.graph.networkx_storage import NetworkXStorage
from graphgen.models.storage.kv.json_storage import JsonKVStorage
from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage

from .rocksdb_cache import RocksDBCache
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.black]
line-length = 88  # black's default is 88
include = '\.pyi?$'

[tool.isort]
profile = "black"  # adopt black-compatible settings in one step
line_length = 88  # keep consistent with black
multi_line_output = 3  # black's preferred parenthesized wrapping style
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
--------------------------------------------------------------------------------
/graphgen/utils/wrap.py:
--------------------------------------------------------------------------------
from functools import wraps
from typing import Any, Callable

from .loop import create_event_loop


def async_to_sync_method(func: Callable) -> Callable:
    """Decorator that runs an async instance method synchronously via an event loop."""

    @wraps(func)
    def wrapper(self, *args, **kwargs) -> Any:
        loop = create_event_loop()
        return loop.run_until_complete(func(self, *args, **kwargs))

    return wrapper
--------------------------------------------------------------------------------
/graphgen/templates/statement_judgement.py:
--------------------------------------------------------------------------------
TEMPLATE: str = """Please determine if the following statement is correct.

Note:
1. If the statement is correct, please reply with 'yes'; otherwise, reply with 'no'.
2. The answer should be either 'yes' or 'no'; do not output any other content.

Statement:
{statement}
Judgement: """

STATEMENT_JUDGEMENT_PROMPT = {
    "TEMPLATE": TEMPLATE
}
--------------------------------------------------------------------------------
/resources/nltk_data/tokenizers/punkt_tab/english/sent_starters.txt:
--------------------------------------------------------------------------------
most
he
since
so
both
these
it
nevertheless
this
indeed
however
instead
under
similarly
some
though
while
when
in
despite
although
nonetheless
thus
there
if
the
nor
separately
moreover
but
they
yet
many
according
sales
among
meanwhile
even
i
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/psf/black
    rev: 22.10.0
    hooks:
      - id: black
        name: black
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: isort
  - repo: https://github.com/PyCQA/pylint
    rev: v2.15.5
    hooks:
      - id: pylint
        name: pylint
        entry: pylint
        language: system
        types: [python]
--------------------------------------------------------------------------------
/.github/sync-config.yml:
--------------------------------------------------------------------------------
sync:
  - source: graphgen/
    dest: graphgen/
  - source: resources/nltk_data/
    dest: resources/nltk_data/
  - source: resources/examples/
    dest: resources/examples/
  - source: resources/images/logo.png
    dest: resources/images/logo.png
  - source: webui/
    dest: webui/
  - source: webui/app.py
    dest: app.py
  - source: requirements.txt
    dest: requirements.txt
  - source: LICENSE
    dest: LICENSE
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tqdm
openai
python-dotenv
numpy
networkx
graspologic
tiktoken
pyecharts
wikipedia
tenacity
nltk
jieba
plotly
pandas
gradio==5.44.1
kaleido
pyyaml
langcodes
requests
fastapi
trafilatura
aiohttp
socksio
pydantic
ray==2.52.1

leidenalg
igraph
python-louvain

# storage
rocksdict

# KG
rdflib

# Bioinformatics
biopython

# For visualization
matplotlib
--------------------------------------------------------------------------------
/graphgen/models/tokenizer/tiktoken_tokenizer.py:
--------------------------------------------------------------------------------
from typing import List

import tiktoken

from graphgen.bases import BaseTokenizer


class TiktokenTokenizer(BaseTokenizer):
    def __init__(self, model_name: str = "cl100k_base"):
        super().__init__(model_name)
        self.enc = tiktoken.get_encoding(self.model_name)

    def encode(self, text: str) -> List[int]:
        return self.enc.encode(text)

    def decode(self, token_ids: List[int]) -> str:
        return self.enc.decode(token_ids)
--------------------------------------------------------------------------------
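/[editor's illustration, not a repository file] tokenizer_usage_example.py:
--------------------------------------------------------------------------------
# A minimal sketch of round-tripping text through the TiktokenTokenizer defined
# above. The import path and the "cl100k_base" default come from that file;
# the sample string and this script are invented for illustration.
from graphgen.models.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

tokenizer = TiktokenTokenizer()  # defaults to the "cl100k_base" encoding

ids = tokenizer.encode("GraphGen builds knowledge graphs.")
assert tokenizer.decode(ids) == "GraphGen builds knowledge graphs."

# Token counts like len(ids) are the kind of quantity LengthEvaluator reports.
print(len(ids))
--------------------------------------------------------------------------------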
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Additional context**
Add any other context about the problem here.
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
  - family-names: "Chen"
    given-names: "Zihong"
  - family-names: "Jiang"
    given-names: "Wanli"
  - family-names: "Li"
    given-names: "Jingzhe"
  - family-names: "Yuan"
    given-names: "Zhonghang"
  - family-names: "Wang"
    given-names: "Chenyang"
  - family-names: "Kong"
    given-names: "Huanjun"
  - family-names: "Dong"
    given-names: "Nanqing"
title: "GraphGen"
date-released: 2025-04-21
url: "https://github.com/open-sciencelab/GraphGen"
--------------------------------------------------------------------------------
/graphgen/bases/__init__.py:
--------------------------------------------------------------------------------
from .base_extractor import BaseExtractor
from .base_generator import BaseGenerator
from .base_kg_builder import BaseKGBuilder
from .base_llm_wrapper import BaseLLMWrapper
from .base_operator import BaseOperator
from .base_partitioner import BasePartitioner
from .base_reader import BaseReader
from .base_searcher import BaseSearcher
from .base_splitter import BaseSplitter
from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
from .base_tokenizer import BaseTokenizer
from .datatypes import Chunk, Config, Node, QAPair, Token
--------------------------------------------------------------------------------
/webui/utils/cache.py:
--------------------------------------------------------------------------------
import os
import shutil
import uuid


def setup_workspace(folder):
    """Create a per-request working directory plus a matching log file."""
    request_id = str(uuid.uuid4())
    os.makedirs(folder, exist_ok=True)

    working_dir = os.path.join(folder, request_id)
    os.makedirs(working_dir, exist_ok=True)

    log_dir = os.path.join(folder, "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"{request_id}.log")

    return log_file, working_dir


def cleanup_workspace(folder):
    """Remove the workspace directory if it exists."""
    if os.path.exists(folder):
        shutil.rmtree(folder)
--------------------------------------------------------------------------------
/graphgen/bases/base_searcher.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseSearcher(ABC):
    """
    Abstract base class for searching and retrieving data.
    """

    @abstractmethod
    async def search(self, query: str, **kwargs) -> List[Dict[str, Any]]:
        """
        Search for data based on the given query.

        :param query: The search query.
        :param kwargs: Additional keyword arguments for the searcher.
        :return: List of dictionaries containing the search results.
        """
--------------------------------------------------------------------------------
/graphgen/models/tokenizer/hf_tokenizer.py:
--------------------------------------------------------------------------------
from typing import List

from transformers import AutoTokenizer

from graphgen.bases import BaseTokenizer


class HFTokenizer(BaseTokenizer):
    def __init__(self, model_name: str = "cl100k_base"):
        super().__init__(model_name)
        self.enc = AutoTokenizer.from_pretrained(self.model_name)

    def encode(self, text: str) -> List[int]:
        return self.enc.encode(text, add_special_tokens=False)

    def decode(self, token_ids: List[int]) -> str:
        return self.enc.decode(token_ids, skip_special_tokens=True)
--------------------------------------------------------------------------------
/graphgen/bases/base_extractor.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import Any

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper


class BaseExtractor(ABC):
    """
    Extract information from given text.
    """

    def __init__(self, llm_client: BaseLLMWrapper):
        self.llm_client = llm_client

    @abstractmethod
    async def extract(self, chunk: dict) -> Any:
        """Extract information from the given text"""

    @abstractmethod
    def build_prompt(self, text: str) -> str:
        """Build prompt for LLM based on the given text"""
--------------------------------------------------------------------------------
/graphgen/operators/__init__.py:
--------------------------------------------------------------------------------
from .build_kg import BuildKGService
from .chunk import ChunkService
from .extract import ExtractService
from .generate import GenerateService
from .judge import JudgeService
from .partition import PartitionService
from .quiz import QuizService
from .read import read
from .search import search_all

operators = {
    "read": read,
    "chunk": ChunkService,
    "build_kg": BuildKGService,
    "quiz": QuizService,
    "judge": JudgeService,
    "extract": ExtractService,
    "search": search_all,
    "partition": PartitionService,
    "generate": GenerateService,
}
--------------------------------------------------------------------------------
/graphgen/templates/__init__.py:
--------------------------------------------------------------------------------
from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
from .extraction import SCHEMA_GUIDED_EXTRACTION_PROMPT
from .generation import (
    AGGREGATED_GENERATION_PROMPT,
    ATOMIC_GENERATION_PROMPT,
    COT_GENERATION_PROMPT,
    MULTI_HOP_GENERATION_PROMPT,
    VQA_GENERATION_PROMPT,
)
from .kg import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT, MMKG_EXTRACTION_PROMPT
from .question_generation import QUESTION_GENERATION_PROMPT
from .search_judgement import SEARCH_JUDGEMENT_PROMPT
from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
--------------------------------------------------------------------------------
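/[editor's illustration, not a repository file] operator_dispatch_example.py:
--------------------------------------------------------------------------------
# A minimal sketch of how the `operators` registry exported above could resolve
# the `op_name` entries in the example pipeline configs. The registry and the
# config file are real; this runner loop is hypothetical and ignores the
# dependency graph, batching, and Ray execution that the actual engine handles.
import yaml

from graphgen.operators import operators

CONFIG = "examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml"

with open(CONFIG, encoding="utf-8") as f:
    config = yaml.safe_load(f)

for node in config.get("nodes", []):
    op = operators[node["op_name"]]  # e.g. "read" -> read, "extract" -> ExtractService
    print(f"{node['id']} -> {op!r}")
--------------------------------------------------------------------------------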
/.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /webui/test_api.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | import gradio as gr 3 | 4 | def test_api_connection(api_base, api_key, model_name): 5 | client = OpenAI(api_key=api_key, base_url=api_base) 6 | try: 7 | response = client.chat.completions.create( 8 | model=model_name, 9 | messages=[{"role": "user", "content": "test"}], 10 | max_tokens=1 11 | ) 12 | if not response.choices or not response.choices[0].message: 13 | raise gr.Error(f"{model_name}: Invalid response from API") 14 | gr.Success(f"{model_name}: API connection successful") 15 | except Exception as e: 16 | raise gr.Error(f"{model_name}: API connection failed: {str(e)}") 17 | -------------------------------------------------------------------------------- /examples/search/search_rna_config.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | - name: read_step 3 | op_key: read 4 | params: 5 | input_file: resources/input_examples/search_rna_demo.jsonl # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples 6 | 7 | - name: search_step 8 | op_key: search 9 | deps: [read_step] # search_step depends on read_step 10 | params: 11 | data_sources: [rnacentral] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral 12 | rnacentral_params: 13 | use_local_blast: true # whether to use local blast for RNA search 14 | local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD # path to local BLAST database (without .nhr extension) 15 | -------------------------------------------------------------------------------- /graphgen/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .calculate_confidence import yes_no_loss_entropy 2 | from .detect_lang import detect_if_chinese, detect_main_language 3 | from .device import pick_device 4 | from .format import ( 5 | handle_single_entity_extraction, 6 | handle_single_relationship_extraction, 7 | load_json, 8 | pack_history_conversations, 9 | split_string_by_multi_markers, 10 | write_json, 11 | ) 12 | from .hash import ( 13 | compute_args_hash, 14 | compute_content_hash, 15 | compute_dict_hash, 16 | compute_mm_hash, 17 | ) 18 | from .help_nltk import NLTKHelper 19 | from .log import CURRENT_LOGGER_VAR, logger, set_logger 20 | from .loop import create_event_loop 21 | from .run_concurrent import run_concurrent 22 | from .wrap import async_to_sync_method 23 | -------------------------------------------------------------------------------- /resources/nltk_data/tokenizers/punkt_tab/english/collocations.tab: -------------------------------------------------------------------------------- 1 | ##number## international 2 | ##number## rj 3 | ##number## commodities 4 | ##number## cooper 5 | b stewart 6 | ##number## genentech 7 | ##number## wedgestone 8 | i toussie 9 | ##number## pepper 10 | j fialka 11 | o ludcke 12 | ##number## insider 13 | ##number## aes 14 | i magnin 15 | ##number## credit 16 | ##number## corrections 17 | ##number## financing 18 | ##number## henley 19 | ##number## business 20 | ##number## pay-fone 21 | b wigton 22 | b edelman 23 | b levine 24 | ##number## leisure 25 | b smith 26 | j walter 27 | ##number## pegasus 28 | ##number## dividend 29 | j aron 30 | ##number## review 31 | ##number## abreast 32 | ##number## who 33 | ##number## letters 34 | ##number## colgate 35 | ##number## cbot 36 | ##number## notable 37 | ##number## zimmer -------------------------------------------------------------------------------- /examples/search/search_protein_config.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | - name: read_step 3 | op_key: read 4 | params: 5 | input_file: resources/input_examples/search_protein_demo.jsonl # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples 6 | 7 | - name: search_step 8 | op_key: search 9 | deps: [read_step] # search_step depends on read_step 10 | params: 11 | data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot 12 | uniprot_params: 13 | use_local_blast: true # whether to use local blast for uniprot search 14 | local_blast_db: /your_path/2024_01/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot 15 | # options: uniprot_sprot (recommended, high quality), uniprot_trembl, or uniprot_${RELEASE} (merged database) 16 | -------------------------------------------------------------------------------- /examples/search/search_dna_config.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | - name: read_step 3 | op_key: read 4 | params: 5 | input_file: resources/input_examples/search_dna_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples 6 | 7 | - name: search_step 8 | op_key: search 9 | deps: [read_step] # search_step depends on read_step 10 | params: 11 | data_sources: [ncbi] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral 12 | ncbi_params: 13 | email: test@example.com # NCBI requires an email address 14 | tool: GraphGen # tool name for NCBI API 15 | use_local_blast: true # whether to use local blast for DNA search 16 | local_blast_db: refseq_release/refseq_release # path to local BLAST database (without .nhr extension) 17 | 18 | -------------------------------------------------------------------------------- /graphgen/models/llm/local/trt_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | from graphgen.bases import BaseLLMWrapper 4 | from graphgen.bases.datatypes import Token 5 | 6 | 7 | # TODO: implement TensorRTWrapper methods 8 | class TensorRTWrapper(BaseLLMWrapper): 9 | """ 10 | Async inference backend based on TensorRT-LLM 11 | """ 12 | 13 | async def generate_answer( 14 | self, text: str, history: Optional[List[str]] = None, **extra: Any 15 | ) -> str: 16 | pass 17 | 18 | async def generate_topk_per_token( 19 | self, text: str, history: Optional[List[str]] = None, **extra: Any 20 | ) -> List[Token]: 21 | pass 22 | 23 | async def generate_inputs_prob( 24 | self, text: str, history: Optional[List[str]] = None, **extra: Any 25 | ) -> List[Token]: 26 | pass 27 | -------------------------------------------------------------------------------- /graphgen/_version.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | __version__ = "0.1.0.post20250930" 4 | short_version = __version__ 5 | 6 | 7 | def parse_version_info(version_str: str) -> Tuple: 8 | """Parse version from a string. 9 | 10 | Args: 11 | version_str (str): A string represents a version info. 12 | 13 | Returns: 14 | tuple: A sequence of integer and string represents version. 
15 | """ 16 | _version_info = [] 17 | for x in version_str.split("."): 18 | if x.isdigit(): 19 | _version_info.append(int(x)) 20 | elif x.find("rc") != -1: 21 | patch_version = x.split("rc") 22 | _version_info.append(int(patch_version[0])) 23 | _version_info.append(f"rc{patch_version[1]}") 24 | return tuple(_version_info) 25 | 26 | 27 | version_info = parse_version_info(__version__) 28 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: "Pylint" 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | - 'master' 8 | - 'release-*' 9 | pull_request: 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: ["3.10", "3.11", "3.12"] 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Setup Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | cache: 'pip' 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install -r requirements.txt 30 | pip install pylint 31 | 32 | - name: Run Pylint 33 | run: | 34 | pylint --rcfile=.pylintrc baselines/ graphgen/ webui/ 35 | -------------------------------------------------------------------------------- /graphgen/models/evaluator/length_evaluator.py: -------------------------------------------------------------------------------- 1 | from graphgen.bases.datatypes import QAPair 2 | from graphgen.models.evaluator.base_evaluator import BaseEvaluator 3 | from graphgen.models.tokenizer import Tokenizer 4 | from graphgen.utils import create_event_loop 5 | 6 | 7 | class LengthEvaluator(BaseEvaluator): 8 | def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100): 9 | super().__init__(max_concurrent) 10 | self.tokenizer_name = tokenizer_name 11 | self.tokenizer = Tokenizer(model_name=self.tokenizer_name) 12 | 13 | async def evaluate_single(self, pair: QAPair) -> float: 14 | loop = create_event_loop() 15 | return await loop.run_in_executor(None, self._calculate_length, pair.answer) 16 | 17 | def _calculate_length(self, text: str) -> float: 18 | tokens = self.tokenizer.encode(text) 19 | return len(tokens) 20 | -------------------------------------------------------------------------------- /graphgen/models/reader/csv_reader.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import ray 4 | from ray.data import Dataset 5 | 6 | from graphgen.bases.base_reader import BaseReader 7 | 8 | 9 | class CSVReader(BaseReader): 10 | """ 11 | Reader for CSV files. 12 | Columns: 13 | - type: The type of the document (e.g., "text", "image", etc.) 14 | - if type is "text", "content" column must be present. 15 | """ 16 | 17 | def read(self, input_path: Union[str, List[str]]) -> Dataset: 18 | """ 19 | Read CSV files and return Ray Dataset. 20 | 21 | :param input_path: Path to CSV file or list of CSV files. 22 | :return: Ray Dataset containing validated and filtered data. 
23 | """ 24 | 25 | ds = ray.data.read_csv(input_path) 26 | ds = ds.map_batches(self._validate_batch, batch_format="pandas") 27 | ds = ds.filter(self._should_keep_item) 28 | return ds 29 | -------------------------------------------------------------------------------- /webui/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | 5 | @dataclass 6 | class WebuiParams: 7 | """ 8 | GraphGen parameters 9 | """ 10 | 11 | if_trainee_model: bool 12 | upload_file: Any # gr.File 13 | tokenizer: str 14 | synthesizer_model: str 15 | synthesizer_url: str 16 | trainee_model: str 17 | trainee_url: str 18 | api_key: str 19 | trainee_api_key: str 20 | chunk_size: int 21 | chunk_overlap: int 22 | quiz_samples: int 23 | partition_method: str 24 | dfs_max_units: int 25 | bfs_max_units: int 26 | leiden_max_size: int 27 | leiden_use_lcc: bool 28 | leiden_random_seed: int 29 | ece_max_units: int 30 | ece_min_units: int 31 | ece_max_tokens: int 32 | ece_unit_sampling: str 33 | mode: str 34 | data_format: str 35 | rpm: int 36 | tpm: int 37 | token_counter: Any 38 | -------------------------------------------------------------------------------- /examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml: -------------------------------------------------------------------------------- 1 | global_params: 2 | working_dir: cache 3 | kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv 4 | 5 | nodes: 6 | - id: read 7 | op_name: read 8 | type: source 9 | dependencies: [] 10 | params: 11 | input_path: 12 | - examples/input_examples/extract_demo.txt 13 | 14 | - id: chunk 15 | op_name: chunk 16 | type: map_batch 17 | dependencies: 18 | - read 19 | execution_params: 20 | replicas: 4 21 | params: 22 | chunk_size: 20480 # larger chunk size for better context 23 | chunk_overlap: 2000 24 | 25 | - id: extract 26 | op_name: extract 27 | type: map_batch 28 | dependencies: 29 | - chunk 30 | execution_params: 31 | replicas: 1 32 | batch_size: 128 33 | params: 34 | method: schema_guided 35 | schema_path: graphgen/templates/extraction/schemas/legal_contract.json 36 | -------------------------------------------------------------------------------- /graphgen/models/reader/txt_reader.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import ray 4 | from ray.data import Dataset 5 | 6 | from graphgen.bases.base_reader import BaseReader 7 | 8 | 9 | class TXTReader(BaseReader): 10 | def read( 11 | self, 12 | input_path: Union[str, List[str]], 13 | ) -> Dataset: 14 | """ 15 | Read text files from the specified input path. 16 | :param input_path: Path to the input text file or list of text files. 17 | :return: Ray Dataset containing the read text data. 
18 | """ 19 | docs_ds = ray.data.read_binary_files( 20 | input_path, 21 | include_paths=False, 22 | ) 23 | 24 | docs_ds = docs_ds.map( 25 | lambda row: { 26 | "type": "text", 27 | self.text_column: row["bytes"].decode("utf-8"), 28 | } 29 | ) 30 | 31 | docs_ds = docs_ds.filter(self._should_keep_item) 32 | return docs_ds 33 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | build-n-publish: 10 | name: Build and publish Python 🐍 distributions 📦 to PyPI 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: '3.10' 18 | - name: Install pypa/build 19 | run: >- 20 | python -m 21 | pip install 22 | build 23 | --user 24 | - name: Build a binary wheel and a source tarball 25 | run: >- 26 | python -m 27 | build 28 | --sdist 29 | --wheel 30 | --outdir dist/ 31 | - name: Publish distribution 📦 to PyPI 32 | if: startsWith(github.ref, 'refs/tags') 33 | uses: pypa/gh-action-pypi-publish@release/v1 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.pypi_password }} -------------------------------------------------------------------------------- /tests/integration_tests/models/splitter/test_character_splitter.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from graphgen.models.splitter.character_splitter import CharacterSplitter 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "text,chunk_size,chunk_overlap,expected", 8 | [ 9 | ( 10 | "This is a test.\n\nThis is only a test.\n\nIn the event of an actual emergency...", 11 | 25, 12 | 5, 13 | [ 14 | "This is a test.", 15 | "This is only a test.", 16 | "In the event of an actual emergency...", 17 | ], 18 | ), 19 | ], 20 | ) 21 | def test_character_splitter(text, chunk_size, chunk_overlap, expected): 22 | splitter = CharacterSplitter( 23 | separator="\n\n", 24 | is_separator_regex=False, 25 | chunk_size=chunk_size, 26 | chunk_overlap=chunk_overlap, 27 | keep_separator=False, 28 | ) 29 | chunks = splitter.split_text(text) 30 | assert chunks == expected 31 | -------------------------------------------------------------------------------- /graphgen/utils/hash.py: -------------------------------------------------------------------------------- 1 | from hashlib import md5 2 | 3 | 4 | def compute_args_hash(*args): 5 | return md5(str(args).encode()).hexdigest() 6 | 7 | 8 | def compute_content_hash(content, prefix: str = ""): 9 | return prefix + md5(content.encode()).hexdigest() 10 | 11 | 12 | def compute_mm_hash(item, prefix: str = ""): 13 | if item.get("type") == "text" and item.get("text"): 14 | content = item["text"].strip() 15 | elif item.get("type") == "image" and item.get("img_path"): 16 | content = f"image:{item['img_path']}" 17 | elif item.get("type") == "table" and item.get("table_body"): 18 | content = f"table:{item['table_body']}" 19 | elif item.get("type") == "equation" and item.get("text"): 20 | content = f"equation:{item['text']}" 21 | else: 22 | content = str(item) 23 | return prefix + md5(content.encode()).hexdigest() 24 | 25 | 26 | def compute_dict_hash(d: dict, prefix: str = ""): 27 | items = tuple(sorted(d.items())) 28 | return prefix + md5(str(items).encode()).hexdigest() 29 | 
-------------------------------------------------------------------------------- /graphgen/templates/coreference_resolution.py: -------------------------------------------------------------------------------- 1 | TEMPLATE_ZH: str = """请根据参考文本识别并消解文本中的指代词,明确每个代词所指代的具体实体,并直接输出消解后的文本。 2 | 3 | -示例- 4 | 输入: 5 | 小明和小红一起去公园。她们玩得很开心。之后,他们去吃冰淇淋。 6 | 输出: 7 | 小明和小红一起去公园。小明和小红玩得很开心。之后,小明和小红去吃冰淇淋。 8 | 9 | -真实数据- 10 | 参考文本: 11 | {reference} 12 | 输入: 13 | {input_sentence} 14 | 请直接输出改写后的句子,不要输出任何额外信息。 15 | 输出: 16 | """ 17 | 18 | TEMPLATE_EN: str = """Please identify and resolve the pronouns in the reference text, \ 19 | specify the specific entities referred to by each pronoun, and directly output the resolved text. 20 | 21 | -Example- 22 | Input: 23 | John and Mary went to the park. They had a great time. Later, they went to eat ice cream. 24 | Output: 25 | John and Mary went to the park. John and Mary had a great time. Later, John and Mary went to eat ice cream. 26 | 27 | -Real Data- 28 | Reference text: 29 | {reference} 30 | Input: 31 | {input_sentence} 32 | Please directly output the rewritten sentence without any additional information. 33 | Output: 34 | """ 35 | 36 | COREFERENCE_RESOLUTION_PROMPT = {"en": TEMPLATE_EN, "zh": TEMPLATE_ZH} 37 | -------------------------------------------------------------------------------- /graphgen/utils/loop.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from .log import logger 4 | 5 | 6 | def create_event_loop() -> asyncio.AbstractEventLoop: 7 | """ 8 | Ensure that there is always an event loop available. 9 | 10 | This function tries to get the current event loop. If the current event loop is closed or does not exist, 11 | it creates a new event loop and sets it as the current event loop. 12 | 13 | Returns: 14 | asyncio.AbstractEventLoop: The current or newly created event loop. 15 | """ 16 | try: 17 | # Try to get the current event loop 18 | current_loop = asyncio.get_event_loop() 19 | if current_loop.is_closed(): 20 | raise RuntimeError("Event loop is closed.") 21 | return current_loop 22 | 23 | except RuntimeError: 24 | # If no event loop exists or it is closed, create a new one 25 | logger.info("Creating a new event loop in main thread.") 26 | new_loop = asyncio.new_event_loop() 27 | asyncio.set_event_loop(new_loop) 28 | return new_loop 29 | -------------------------------------------------------------------------------- /webui/utils/preview_file.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | 4 | import gradio as gr 5 | import pandas as pd 6 | 7 | 8 | def preview_file(file): 9 | if file is None: 10 | return gr.update(visible=False), gr.update(visible=False) 11 | 12 | path = file.name 13 | ext = os.path.splitext(path)[1].lower() 14 | 15 | try: 16 | if ext == ".csv": 17 | df = pd.read_csv(path, nrows=10) 18 | return gr.update(visible=False), gr.update(value=df, visible=True) 19 | with codecs.open(path, "r", encoding="utf-8") as f: 20 | text = f.read(5000) 21 | if len(text) == 5000: 22 | text += "\n\n... 
(truncated at 5000 chars)" 23 | return gr.update( 24 | value=text, visible=True, language="json" if ext != ".txt" else None 25 | ), gr.update(visible=False) 26 | except Exception as e: # pylint: disable=broad-except 27 | return gr.update( 28 | value=f"Preview failed: {e}", visible=True, language=None 29 | ), gr.update(visible=False) 30 | -------------------------------------------------------------------------------- /graphgen/models/reader/parquet_reader.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import ray 4 | from ray.data import Dataset 5 | 6 | from graphgen.bases.base_reader import BaseReader 7 | 8 | 9 | class ParquetReader(BaseReader): 10 | """ 11 | Read parquet files, requiring the schema to be restored to List[Dict[str, Any]]. 12 | Columns: 13 | - type: The type of the document (e.g., "text", "image", etc.) 14 | - if type is "text", "content" column must be present. 15 | """ 16 | 17 | def read(self, input_path: Union[str, List[str]]) -> Dataset: 18 | """ 19 | Read Parquet files using Ray Data. 20 | 21 | :param input_path: Path to Parquet file or list of Parquet files. 22 | :return: Ray Dataset containing validated documents. 23 | """ 24 | if not ray.is_initialized(): 25 | ray.init() 26 | 27 | ds = ray.data.read_parquet(input_path) 28 | ds = ds.map_batches(self._validate_batch, batch_format="pandas") 29 | ds = ds.filter(self._should_keep_item) 30 | return ds 31 | -------------------------------------------------------------------------------- /graphgen/operators/preprocess/resolute_coreference.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from graphgen.bases.datatypes import Chunk 4 | from graphgen.models import OpenAIClient 5 | from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT 6 | from graphgen.utils import detect_main_language 7 | 8 | 9 | async def resolute_coreference( 10 | llm_client: OpenAIClient, chunks: List[Chunk] 11 | ) -> List[Chunk]: 12 | """ 13 | Resolve coreference in the given chunks, using the first chunk as the reference text 14 | 15 | :param llm_client: LLM model 16 | :param chunks: List of chunks 17 | :return: List of chunks 18 | """ 19 | 20 | if len(chunks) == 0: 21 | return chunks 22 | 23 | results = [chunks[0]] 24 | 25 | for _, chunk in enumerate(chunks[1:]): 26 | language = detect_main_language(chunk.content) 27 | result = await llm_client.generate_answer( 28 | COREFERENCE_RESOLUTION_PROMPT[language].format( 29 | reference=results[0].content, input_sentence=chunk.content 30 | ) 31 | ) 32 | results.append(Chunk(id=chunk.id, content=result)) 33 | 34 | return results 35 | -------------------------------------------------------------------------------- /graphgen/models/splitter/character_splitter.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, List 3 | 4 | from graphgen.bases.base_splitter import BaseSplitter 5 | 6 | 7 | class CharacterSplitter(BaseSplitter): 8 | """Split text on a single separator string.""" 9 | 10 | def __init__( 11 | self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any 12 | ) -> None: 13 | """Create a new CharacterSplitter.""" 14 | super().__init__(**kwargs) 15 | self._separator = separator 16 | self._is_separator_regex = is_separator_regex 17 | 18 | def split_text(self, text: str) -> List[str]: 19 | """Split incoming text and return chunks.""" 20 | # First we naively chunk the large input into a bunch of smaller ones.
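# (the separator is re.escape()-d below unless is_separator_regex is True, so literal separators are matched verbatim)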
21 | separator = ( 22 | self._separator if self._is_separator_regex else re.escape(self._separator) 23 | ) 24 | splits = self._split_text_with_regex(text, separator, self.keep_separator) 25 | _separator = "" if self.keep_separator else self._separator 26 | return self._merge_splits(splits, _separator) 27 | -------------------------------------------------------------------------------- /graphgen/templates/question_generation.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C0301 2 | 3 | 4 | # TODO: revise this prompt 5 | TEMPLATE_MULTI_EN = """You are an assistant that helps read an article and then rephrase it in a question-answering format. The user will provide you with an article and its content. You need to generate a paraphrase of the same article in question and answer format, with one tag of "Question: ..." followed by "Answer: ...". Remember to keep the meaning and all content of the article intact. 6 | 7 | Here is the format you should follow for your response: 8 | Question: <question> 9 | Answer: <answer> 10 | 11 | Here is the article you need to rephrase: 12 | {doc} 13 | """ 14 | 15 | TEMPLATE_MULTI_ZH = """你是一位助手,帮助阅读一篇文章,然后以问答格式重述它。用户将为您提供一篇带有内容的文章。你需要以一个标签"问题:..."为开头,接着是"答案:...",生成一篇与原文章相同的问答格式的重述。请确保保持文章的意义和每个内容不变。 16 | 17 | 以下是你应该遵循的响应格式: 18 | 问题: <问题> 19 | 答案: <答案> 20 | 21 | 以下是你需要重述的文章: 22 | {doc} 23 | """ 24 | 25 | QUESTION_GENERATION_PROMPT = { 26 | "English": { 27 | "MULTI_TEMPLATE": TEMPLATE_MULTI_EN, 28 | }, 29 | "Chinese": { 30 | "MULTI_TEMPLATE": TEMPLATE_MULTI_ZH, 31 | }, 32 | } 33 | -------------------------------------------------------------------------------- /examples/input_examples/search_rna_demo.jsonl: -------------------------------------------------------------------------------- 1 | {"type": "text", "content": "hsa-let-7a-1"} 2 | {"type": "text", "content": "XIST regulator"} 3 | {"type": "text", "content": "URS0000123456"} 4 | {"type": "text", "content": "URS0000000001"} 5 | {"type": "text", "content": "URS0000000787"} 6 | {"type": "text", "content": "GCAGTTCTCAGCCATGACAGATGGGAGTTTCGGCCCAATTGACCAGTATTCCTTACTGATAAGAGACACTGACCATGGAGTGGTTCTGGTGAGATGACATGACCCTCGTGAAGGGGCCTGAAGCTTCATTGTGTTTGTGTATGTTTCTCTCTTCAAAAATATTCATGACTTCTCCTGTAGCTTGATAAATATGTATATTTACACACTGCA"} 7 | {"type": "text", "content": ">query\nCUCCUUUGACGUUAGCGGCGGACGGGUUAGUAACACGUGGGUAACCUACCUAUAAGACUGGGAUAACUUCGGGAAACCGGAGCUAAUACCGGAUAAUAUUUCGAACCGCAUGGUUCGAUAGUGAAAGAUGGUUUUGCUAUCACUUAUAGAUGGACCCGCGCCGUAUUAGCUAGUUGGUAAGGUAACGGCUUACCAAGGCGACGAUACGUAGCCGACCUGAGAGGGUGAUCGGCCACACUGGAACUGAGACACGGUCCAGACUCCUACGGGAGGCAGCAGGGG"} 8 | {"type": "text", "content": "CUCCUUUGACGUUAGCGGCGGACGGGUUAGUAACACGUGGGUAACCUACCUAUAAGACUGGGAUAACUUCGGGAAACCGGAGCUAAUACCGGAUAAUAUUUCGAACCGCAUGGUUCGAUAGUGAAAGAUGGUUUUGCUAUCACUUAUAGAUGGACCCGCGCCGUAUUAGCUAGUUGGUAAGGUAACGGCUUACCAAGGCGACGAUACGUAGCCGACCUGAGAGGGUGAUCGGCCACACUGGAACUGAGACACGGUCCAGACUCCUACGGGAGGCAGCAGGGG"} 9 | -------------------------------------------------------------------------------- /.github/contributing.md: -------------------------------------------------------------------------------- 1 | ## Contribution Guide 2 | Here are the steps to contribute to this project: 3 | 4 | 1. Star this repository. 5 | 2. Fork this repository. 6 | 7 | Type the following command in the Git Bash console: 8 | ```bash 9 | git clone https://github.com/open-sciencelab/GraphGen.git 10 | ``` 11 | 12 | 3.
Create a new branch 13 | 14 | Now, before making changes to the files, go to your terminal in the repo you just cloned and type the following: 15 | 16 | ```bash 17 | git checkout -b add-my-name 18 | ``` 19 | 20 | By running the above command, you just created a new branch called add-my-name and checked it out. The new branch starts from the commit history of the branch you were on previously (usually master). 21 | 22 | 4. Make your changes and push your code. 23 | 24 | ``` 25 | git add . 26 | git commit -m "describe your change" 27 | git push -u origin add-my-name 28 | ``` 29 | 30 | This commits your changes and pushes the new branch to your fork. 31 | 32 | 5. Now open a pull request and give it a descriptive title. 33 | 34 | Sit back and relax while your pull request is being reviewed and merged. 35 | -------------------------------------------------------------------------------- /graphgen/utils/detect_lang.py: -------------------------------------------------------------------------------- 1 | def detect_main_language(text): 2 | """ 3 | Detect the main language of the text, 'zh' for Chinese, 'en' for English 4 | 5 | :param text: input string 6 | :return: "zh" if at least half of the counted characters are Chinese, otherwise "en" 7 | """ 8 | assert isinstance(text, str) 9 | 10 | def is_chinese_char(char): 11 | return "\u4e00" <= char <= "\u9fff" 12 | 13 | def is_english_char(char): 14 | return char.isascii() and char.isalpha() 15 | 16 | text = "".join(char for char in text if char.strip()) 17 | 18 | chinese_count = sum(1 for char in text if is_chinese_char(char)) 19 | english_count = sum(1 for char in text if is_english_char(char)) 20 | 21 | total = chinese_count + english_count 22 | if total == 0: 23 | return "en" 24 | 25 | chinese_ratio = chinese_count / total 26 | 27 | if chinese_ratio >= 0.5: 28 | return "zh" 29 | return "en" 30 | 31 | 32 | def detect_if_chinese(text): 33 | """ 34 | Detect if the text contains any Chinese characters 35 | 36 | :param text: input string 37 | :return: True if any Chinese character is present 38 | """ 39 | 40 | assert isinstance(text, str) 41 | return any("\u4e00" <= char <= "\u9fff" for char in text) 42 | -------------------------------------------------------------------------------- /graphgen/models/llm/local/tgi_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | from graphgen.bases import BaseLLMWrapper 4 | from graphgen.bases.datatypes import Token 5 | 6 | 7 | # TODO: implement TGIWrapper methods 8 | class TGIWrapper(BaseLLMWrapper): 9 | """ 10 | Async inference backend based on TGI (Text-Generation-Inference) 11 | """ 12 | 13 | def __init__( 14 | self, 15 | model_url: str, # e.g.
"http://localhost:8080" 16 | temperature: float = 0.0, 17 | top_p: float = 1.0, 18 | topk: int = 5, 19 | **kwargs: Any 20 | ): 21 | super().__init__(temperature=temperature, top_p=top_p, **kwargs) 22 | 23 | async def generate_answer( 24 | self, text: str, history: Optional[List[str]] = None, **extra: Any 25 | ) -> str: 26 | pass 27 | 28 | async def generate_topk_per_token( 29 | self, text: str, history: Optional[List[str]] = None, **extra: Any 30 | ) -> List[Token]: 31 | pass 32 | 33 | async def generate_inputs_prob( 34 | self, text: str, history: Optional[List[str]] = None, **extra: Any 35 | ) -> List[Token]: 36 | pass 37 | -------------------------------------------------------------------------------- /examples/input_examples/search_dna_demo.jsonl: -------------------------------------------------------------------------------- 1 | {"type": "text", "content": "NG_033923"} 2 | {"type": "text", "content": "NG_056118"} 3 | {"type": "text", "content": ">query\nACTCAATTGTCCCAGCAGCATCTACCGAAAAGCCCCCTTGCTGTTCCTGCCAACTTGAAGCCCGGAGGCCTGCTGGGAGGAGGAATTCTAAATGACAAGTATGCCTGGAAAGCTGTGGTCCAAGGCCGTTTTTGCCGTCAGCAGGATCTCCAGAACCAAAGGGAGGACACAGCTCTTCTTAAAACTGAAGGTATTTATGGCTGACATAAAATGAGATTTGATTTGGGCAGGAAATGCGCTTATGTGTACAAAGAATAATACTGACTCCTGGCAGCAAACCAAACAAAACCAGAGTAAGGTGGAGAAAGGTAACGTGTGCCCACGGAAACAGTGGCACAATGTGTGCCTAATTCCAAAGCAGCCGTCCTGCTTAGGCCACTAGTCACGGCGGCTCTGTGATGCTGTACTCCTCAAGGATTTGAACTAATGAAAAGTAAATAAATACCAGTAAAAGTGGATTTGTAAAAAGAAAAGAAAAATGATAGGAAAAGCCCCTTTACCATATGTCAAGGGTTTATGCTG"} 4 | {"type": "text", "content": "ACTCAATTGTCCCAGCAGCATCTACCGAAAAGCCCCCTTGCTGTTCCTGCCAACTTGAAGCCCGGAGGCCTGCTGGGAGGAGGAATTCTAAATGACAAGTATGCCTGGAAAGCTGTGGTCCAAGGCCGTTTTTGCCGTCAGCAGGATCTCCAGAACCAAAGGGAGGACACAGCTCTTCTTAAAACTGAAGGTATTTATGGCTGACATAAAATGAGATTTGATTTGGGCAGGAAATGCGCTTATGTGTACAAAGAATAATACTGACTCCTGGCAGCAAACCAAACAAAACCAGAGTAAGGTGGAGAAAGGTAACGTGTGCCCACGGAAACAGTGGCACAATGTGTGCCTAATTCCAAAGCAGCCGTCCTGCTTAGGCCACTAGTCACGGCGGCTCTGTGATGCTGTACTCCTCAAGGATTTGAACTAATGAAAAGTAAATAAATACCAGTAAAAGTGGATTTGTAAAAAGAAAAGAAAAATGATAGGAAAAGCCCCTTTACCATATGTCAAGGGTTTATGCTG"} 5 | -------------------------------------------------------------------------------- /graphgen/utils/run_concurrent.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Awaitable, Callable, List, TypeVar 3 | 4 | from tqdm.asyncio import tqdm as tqdm_async 5 | 6 | from graphgen.utils.log import logger 7 | 8 | from .loop import create_event_loop 9 | 10 | T = TypeVar("T") 11 | R = TypeVar("R") 12 | 13 | 14 | def run_concurrent( 15 | coro_fn: Callable[[T], Awaitable[R]], 16 | items: List[T], 17 | *, 18 | desc: str = "processing", 19 | unit: str = "item", 20 | ) -> List[R]: 21 | async def _run_all(): 22 | tasks = [asyncio.create_task(coro_fn(item)) for item in items] 23 | 24 | results = [] 25 | pbar = tqdm_async(total=len(items), desc=desc, unit=unit) 26 | 27 | for future in asyncio.as_completed(tasks): 28 | try: 29 | result = await future 30 | results.append(result) 31 | except Exception as e: 32 | logger.exception("Task failed: %s", e) 33 | results.append(e) 34 | 35 | pbar.update(1) 36 | 37 | pbar.close() 38 | return [res for res in results if not isinstance(res, Exception)] 39 | 40 | loop = create_event_loop() 41 | try: 42 | return loop.run_until_complete(_run_all()) 43 | finally: 44 | loop.close() 45 | -------------------------------------------------------------------------------- /graphgen/models/splitter/markdown_splitter.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from graphgen.models.splitter.recursive_character_splitter import ( 4 | RecursiveCharacterSplitter, 5 | ) 6 | 7 | 8 | class MarkdownTextRefSplitter(RecursiveCharacterSplitter): 9 | """Attempts to chunk the text along Markdown-formatted headings.""" 10 | 11 | def __init__(self, **kwargs: Any) -> None: 12 | """Initialize a MarkdownTextRefSplitter.""" 13 | separators = [ 14 | # First, try to chunk along Markdown headings (levels 1-6) 15 | "\n#{1,6} ", 16 | # Note the alternative syntax for headings (below) is not handled here 17 | # Heading level 2 18 | # --------------- 19 | # End of code block 20 | "```\n", 21 | # Horizontal lines 22 | "\n\\*\\*\\*+\n", 23 | "\n---+\n", 24 | "\n___+\n", 25 | # Note: horizontal lines defined by three or more of ***, ---, or ___ 26 | # are handled by the regexes above, but alternative syntaxes (e.g., with spaces) 27 | # are not handled. 28 | "\n\n", 29 | "\n", 30 | " ", 31 | "", 32 | ] 33 | super().__init__(separators=separators, **kwargs) 34 | -------------------------------------------------------------------------------- /graphgen/templates/generation/atomic_generation.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C0301 2 | TEMPLATE_EN: str = """You are given a text passage. Your task is to generate a question and answer (QA) pair based on the content of that text. 3 | The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text. 4 | For example: 5 | Question: What is the effect of overexpressing the BG1 gene on grain size and development? 6 | Answer: Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development. 7 | 8 | Question: What role does TAC4 play in the gravitropism of rice shoots? 9 | Answer: TAC4 is a key regulator of gravitropism in rice shoots, promoting the bending of shoots towards the gravity vector. 10 | 11 | Here is the text passage you need to generate a QA pair for: 12 | {context} 13 | """ 14 | 15 | TEMPLATE_ZH: str = """给定一个文本段落。你的任务是根据该文本的内容生成一个问答(QA)对。 16 | 答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。 17 | 例如: 18 | 问题:过表达BG1基因对谷粒大小和发育有什么影响? 19 | 答案:BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。 20 | 21 | 问题:TAC4在水稻茎的重力性状中扮演什么角色?
22 | 答案:TAC4是水稻茎重力性状的关键调节因子,促进茎向重力矢量弯曲。 23 | 24 | 以下是你需要为其生成QA对的文本段落: 25 | {context} 26 | """ 27 | 28 | 29 | ATOMIC_GENERATION_PROMPT = { 30 | "en": TEMPLATE_EN, 31 | "zh": TEMPLATE_ZH, 32 | } 33 | -------------------------------------------------------------------------------- /examples/generate/generate_atomic_qa/atomic_config.yaml: -------------------------------------------------------------------------------- 1 | global_params: 2 | working_dir: cache 3 | graph_backend: kuzu # graph database backend, support: kuzu, networkx 4 | kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv 5 | 6 | nodes: 7 | - id: read 8 | op_name: read 9 | type: source 10 | dependencies: [] 11 | params: 12 | input_path: 13 | - examples/input_examples/json_demo.json 14 | 15 | - id: chunk 16 | op_name: chunk 17 | type: map_batch 18 | dependencies: 19 | - read 20 | execution_params: 21 | replicas: 4 22 | params: 23 | chunk_size: 1024 24 | chunk_overlap: 100 25 | 26 | - id: build_kg 27 | op_name: build_kg 28 | type: map_batch 29 | execution_params: 30 | replicas: 1 31 | batch_size: 128 32 | dependencies: 33 | - chunk 34 | 35 | - id: partition 36 | op_name: partition 37 | type: aggregate 38 | dependencies: 39 | - build_kg 40 | params: 41 | method: dfs 42 | method_params: 43 | max_units_per_community: 1 44 | 45 | - id: generate 46 | op_name: generate 47 | type: map_batch 48 | dependencies: 49 | - partition 50 | execution_params: 51 | replicas: 1 52 | batch_size: 128 53 | params: 54 | method: atomic 55 | data_format: Alpaca 56 | -------------------------------------------------------------------------------- /graphgen/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator 2 | from .generator import ( 3 | AggregatedGenerator, 4 | AtomicGenerator, 5 | CoTGenerator, 6 | MultiHopGenerator, 7 | QuizGenerator, 8 | VQAGenerator, 9 | ) 10 | from .kg_builder import LightRAGKGBuilder, MMKGBuilder 11 | from .llm import HTTPClient, OllamaClient, OpenAIClient 12 | from .partitioner import ( 13 | AnchorBFSPartitioner, 14 | BFSPartitioner, 15 | DFSPartitioner, 16 | ECEPartitioner, 17 | LeidenPartitioner, 18 | ) 19 | from .reader import ( 20 | CSVReader, 21 | JSONReader, 22 | ParquetReader, 23 | PDFReader, 24 | PickleReader, 25 | RDFReader, 26 | TXTReader, 27 | ) 28 | from .searcher.db.ncbi_searcher import NCBISearch 29 | from .searcher.db.rnacentral_searcher import RNACentralSearch 30 | from .searcher.db.uniprot_searcher import UniProtSearch 31 | from .searcher.kg.wiki_search import WikiSearch 32 | from .searcher.web.bing_search import BingSearch 33 | from .searcher.web.google_search import GoogleSearch 34 | from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter 35 | from .storage import ( 36 | JsonKVStorage, 37 | KuzuStorage, 38 | NetworkXStorage, 39 | RocksDBCache, 40 | RocksDBKVStorage, 41 | ) 42 | from .tokenizer import Tokenizer 43 | -------------------------------------------------------------------------------- /graphgen/models/storage/rocksdb_cache.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Iterator, Optional 3 | 4 | # rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it 5 | # pylint: disable=no-name-in-module 6 | from rocksdict import Rdict 7 | 8 | 9 | class RocksDBCache: 10 | def __init__(self, cache_dir: str): 
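"""Open (or create) a RocksDB key-value store rooted at ``cache_dir``; Rdict is assumed to create the directory if it does not exist."""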
11 | self.db_path = Path(cache_dir) 12 | self.db = Rdict(str(self.db_path)) 13 | 14 | def get(self, key: str) -> Optional[Any]: 15 | return self.db.get(key) 16 | 17 | def set(self, key: str, value: Any): 18 | self.db[key] = value 19 | 20 | def delete(self, key: str): 21 | try: 22 | del self.db[key] 23 | except KeyError: 24 | # If the key does not exist, do nothing (deletion is idempotent for caches) 25 | pass 26 | 27 | def close(self): 28 | if hasattr(self, "db") and self.db is not None: 29 | self.db.close() 30 | self.db = None 31 | 32 | def __del__(self): 33 | # Ensure the database is closed when the object is destroyed 34 | self.close() 35 | 36 | def __enter__(self): 37 | return self 38 | 39 | def __exit__(self, exc_type, exc_val, exc_tb): 40 | self.close() 41 | 42 | def __iter__(self) -> Iterator[str]: 43 | return iter(self.db.keys()) 44 | -------------------------------------------------------------------------------- /baselines/EntiGraph/inference/devapi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dotenv 3 | from openai import AsyncOpenAI 4 | 5 | dotenv.load_dotenv() 6 | 7 | async def gptqa(prompt: str, 8 | openai_model_name: str, 9 | system_message: str, 10 | json_format: bool = False, 11 | temp: float = 1.0): 12 | client = AsyncOpenAI( 13 | api_key=os.getenv("SYNTHESIZER_API_KEY"), 14 | base_url=os.getenv("SYNTHESIZER_BASE_URL") 15 | ) 16 | openai_model_name = openai_model_name or os.getenv("SYNTHESIZER_MODEL") 17 | 18 | if json_format: 19 | completion = await client.chat.completions.create( 20 | model=openai_model_name, 21 | temperature=temp, 22 | response_format={ "type": "json_object" }, 23 | messages=[ 24 | {"role": "system", 25 | "content": system_message}, 26 | {"role": "user", 27 | "content": prompt}, 28 | ]) 29 | else: 30 | completion = await client.chat.completions.create( 31 | model=openai_model_name, 32 | temperature=temp, 33 | messages=[ 34 | {"role": "system", 35 | "content": system_message}, 36 | {"role": "user", 37 | "content": prompt}, 38 | ]) 39 | return completion.choices[0].message.content 40 | -------------------------------------------------------------------------------- /examples/generate/generate_cot_qa/cot_config.yaml: -------------------------------------------------------------------------------- 1 | global_params: 2 | working_dir: cache 3 | graph_backend: kuzu # graph database backend, support: kuzu, networkx 4 | kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv 5 | 6 | nodes: 7 | - id: read 8 | op_name: read 9 | type: source 10 | dependencies: [] 11 | params: 12 | input_path: 13 | - examples/input_examples/txt_demo.txt 14 | 15 | - id: chunk 16 | op_name: chunk 17 | type: map_batch 18 | dependencies: 19 | - read 20 | execution_params: 21 | replicas: 4 22 | params: 23 | chunk_size: 1024 24 | chunk_overlap: 100 25 | 26 | - id: build_kg 27 | op_name: build_kg 28 | type: map_batch 29 | execution_params: 30 | replicas: 1 31 | batch_size: 128 32 | dependencies: 33 | - chunk 34 | 35 | - id: partition 36 | op_name: partition 37 | type: aggregate 38 | dependencies: 39 | - build_kg 40 | params: 41 | method: leiden 42 | method_params: 43 | max_size: 20 44 | use_lcc: false 45 | random_seed: 42 46 | 47 | - id: generate 48 | op_name: generate 49 | type: map_batch 50 | dependencies: 51 | - partition 52 | execution_params: 53 | replicas: 1 54 | batch_size: 128 55 | params: 56 | method: cot 57 | data_format: Sharegpt 58 | 
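60 | # To run this pipeline end to end, the repo ships a helper script (assumed to pass this file as its config): 61 | #   bash examples/generate/generate_cot_qa/generate_cot.sh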
-------------------------------------------------------------------------------- /tests/integration_tests/models/splitter/test_markdown_splitter.py: -------------------------------------------------------------------------------- 1 | from graphgen.models.splitter.markdown_splitter import MarkdownTextRefSplitter 2 | 3 | 4 | def test_split_markdown_structures(): 5 | md = ( 6 | "# Header1\n\n" 7 | "Some introduction here.\n\n" 8 | "## Header2\n\n" 9 | "```python\nprint('hello')\n```\n" 10 | "Paragraph under code block.\n\n" 11 | "***\n" 12 | "### Header3\n\n" 13 | "More text after horizontal rule.\n\n" 14 | "#### Header4\n\n" 15 | "Final paragraph." 16 | ) 17 | 18 | splitter = MarkdownTextRefSplitter( 19 | chunk_size=120, 20 | chunk_overlap=0, 21 | keep_separator=True, 22 | is_separator_regex=True, 23 | ) 24 | chunks = splitter.split_text(md) 25 | assert len(chunks) > 1 26 | 27 | for chk in chunks: 28 | assert len(chk) <= 120 29 | 30 | assert any("## Header2" in c for c in chunks) 31 | assert any("***" in c for c in chunks) 32 | assert any("```" in c for c in chunks) 33 | 34 | 35 | def test_split_size_less_than_single_char(): 36 | """Edge case: chunk_size is smaller than any segment; should still return the original text.""" 37 | short = "# A\n\nB" 38 | splitter = MarkdownTextRefSplitter(chunk_size=1, chunk_overlap=0) 39 | chunks = splitter.split_text(short) 40 | assert "".join(chunks) == short 41 | -------------------------------------------------------------------------------- /graphgen/templates/kg/kg_summarization.py: -------------------------------------------------------------------------------- 1 | TEMPLATE_EN = """You are an NLP expert responsible for generating a comprehensive summary of the data provided below. 2 | You are given one entity or relationship, and a list of descriptions, all related to the same entity or relationship. 3 | Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. 4 | If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. 5 | Make sure it is written in third person, and include the entity names so that we have the full context. 6 | Use English as output language.
7 | 8 | ####### 9 | -Data- 10 | Entities: {entity_name} 11 | Description List: {description_list} 12 | ####### 13 | Output: 14 | """ 15 | 16 | TEMPLATE_ZH = """你是一个NLP专家,负责根据以下提供的数据生成综合摘要。 17 | 给定一个实体或关系,以及一系列描述,所有描述都与同一实体或关系相关。 18 | 请将所有这些描述整合成一个综合描述。确保包含所有描述中收集的信息。 19 | 如果提供的描述是矛盾的,请解决这些矛盾并提供一个连贯的总结。 20 | 确保以第三人称写作,并包含实体名称,以便我们有完整的上下文。 21 | 使用中文作为输出语言。 22 | 23 | ####### 24 | -数据- 25 | 实体:{entity_name} 26 | 描述列表:{description_list} 27 | ####### 28 | 输出: 29 | """ 30 | 31 | 32 | KG_SUMMARIZATION_PROMPT = { 33 | "zh": {"TEMPLATE": TEMPLATE_ZH}, 34 | "en": {"TEMPLATE": TEMPLATE_EN}, 35 | "FORMAT": { 36 | "tuple_delimiter": "<|>", 37 | "record_delimiter": "##", 38 | "completion_delimiter": "<|COMPLETE|>", 39 | }, 40 | } 41 | -------------------------------------------------------------------------------- /graphgen/utils/help_nltk.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List, Optional 3 | import nltk 4 | import jieba 5 | 6 | resource_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "resources") 7 | 8 | 9 | class NLTKHelper: 10 | _stopwords: Dict[str, Optional[List[str]]] = { 11 | "english": None, 12 | "chinese": None, 13 | } 14 | 15 | def __init__(self): 16 | jieba.initialize() 17 | 18 | def get_stopwords(self, lang: str) -> List[str]: 19 | nltk.data.path.append(os.path.join(resource_path, "nltk_data")) 20 | if self._stopwords[lang] is None: 21 | try: 22 | nltk.data.find("corpora/stopwords") 23 | except LookupError: 24 | nltk.download("stopwords", download_dir=os.path.join(resource_path, "nltk_data")) 25 | 26 | self._stopwords[lang] = nltk.corpus.stopwords.words(lang) 27 | return self._stopwords[lang] 28 | 29 | @staticmethod 30 | def word_tokenize(text: str, lang: str) -> List[str]: 31 | if lang == "zh": 32 | return jieba.lcut(text) 33 | nltk.data.path.append(os.path.join(resource_path, "nltk_data")) 34 | try: 35 | nltk.data.find("tokenizers/punkt_tab") 36 | except LookupError: 37 | nltk.download("punkt_tab", download_dir=os.path.join(resource_path, "nltk_data")) 38 | 39 | return nltk.word_tokenize(text) 40 | -------------------------------------------------------------------------------- /examples/generate/generate_vqa/vqa_config.yaml: -------------------------------------------------------------------------------- 1 | global_params: 2 | working_dir: cache 3 | graph_backend: kuzu # graph database backend, support: kuzu, networkx 4 | kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv 5 | 6 | nodes: 7 | - id: read 8 | op_name: read 9 | type: source 10 | dependencies: [] 11 | params: 12 | input_path: 13 | - examples/input_examples/vqa_demo.json 14 | modalities: 15 | - text 16 | - image 17 | 18 | - id: chunk 19 | op_name: chunk 20 | type: map_batch 21 | dependencies: 22 | - read 23 | execution_params: 24 | replicas: 4 25 | params: 26 | chunk_size: 1024 27 | chunk_overlap: 100 28 | 29 | - id: build_kg 30 | op_name: build_kg 31 | type: map_batch 32 | dependencies: 33 | - chunk 34 | execution_params: 35 | replicas: 1 36 | batch_size: 128 37 | 38 | - id: partition 39 | op_name: partition 40 | type: aggregate 41 | dependencies: 42 | - build_kg 43 | params: 44 | method: anchor_bfs 45 | method_params: 46 | anchor_type: image 47 | max_units_per_community: 10 48 | 49 | - id: generate 50 | op_name: generate 51 | type: map_batch 52 | dependencies: 53 | - partition 54 | execution_params: 55 | replicas: 1 56 | batch_size: 128 57 | params: 58 | method: vqa 59 | data_format: ChatML 
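60 | # Note: anchor_bfs grows each community outward from an image node (anchor_type: image), so every generated VQA sample should be tied to at least one image. This reading is inferred from the partitioner name (AnchorBFSPartitioner); consult graphgen/models/partitioner for the exact semantics.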
-------------------------------------------------------------------------------- /tests/integration_tests/models/reader/test_mineru_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from graphgen.models.reader.pdf_reader import MinerUParser 5 | 6 | 7 | def test_check_bin(): 8 | """Ensure mineru CLI is available.""" 9 | MinerUParser() 10 | 11 | 12 | def test_parse_pdf(): 13 | """Parse a real PDF and verify basic structure.""" 14 | repo_root = Path(__file__).resolve().parents[4] 15 | 16 | sample_pdf = os.path.join(repo_root, "examples", "input_examples", "pdf_demo.pdf") 17 | parser = MinerUParser() 18 | blocks = parser.parse_pdf(sample_pdf, device="cpu", method="auto") 19 | 20 | assert isinstance(blocks, list) 21 | assert blocks, "At least one block expected" 22 | 23 | text_blocks = [b for b in blocks if b.get("type") == "text"] 24 | assert text_blocks, "No text block found" 25 | 26 | first = text_blocks[0] 27 | assert "text" in first 28 | assert isinstance(first["text"], str) 29 | assert first["text"].strip(), "Empty text content" 30 | 31 | 32 | def test_empty_pdf(tmp_path: Path) -> None: 33 | """Gracefully handle blank PDF.""" 34 | empty = tmp_path / "empty.pdf" 35 | empty.write_bytes(b"%PDF-1.4\n%%EOF\n") # syntactically valid, no content 36 | 37 | parser = MinerUParser() 38 | blocks = parser.parse_pdf(empty, device="cpu") 39 | 40 | # Empty list or list with empty text block are both acceptable 41 | assert isinstance(blocks, list) 42 | -------------------------------------------------------------------------------- /graphgen/bases/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List 5 | 6 | 7 | class BaseTokenizer(ABC): 8 | def __init__(self, model_name: str = "cl100k_base"): 9 | self.model_name = model_name 10 | 11 | @abstractmethod 12 | def encode(self, text: str) -> List[int]: 13 | """Encode text -> token ids.""" 14 | raise NotImplementedError 15 | 16 | @abstractmethod 17 | def decode(self, token_ids: List[int]) -> str: 18 | """Decode token ids -> text.""" 19 | raise NotImplementedError 20 | 21 | def count_tokens(self, text: str) -> int: 22 | return len(self.encode(text)) 23 | 24 | def chunk_by_token_size( 25 | self, 26 | content: str, 27 | *, 28 | overlap_token_size: int = 128, 29 | max_token_size: int = 1024, 30 | ) -> List[dict]: 31 | tokens = self.encode(content) 32 | results = [] 33 | step = max_token_size - overlap_token_size 34 | for index, start in enumerate(range(0, len(tokens), step)): 35 | chunk_ids = tokens[start : start + max_token_size] 36 | results.append( 37 | { 38 | "tokens": len(chunk_ids), 39 | "content": self.decode(chunk_ids).strip(), 40 | "chunk_order_index": index, 41 | } 42 | ) 43 | return results 44 | -------------------------------------------------------------------------------- /graphgen/models/searcher/kg/wiki_search.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import wikipedia 4 | from wikipedia import set_lang 5 | 6 | from graphgen.utils import detect_main_language, logger 7 | 8 | 9 | class WikiSearch: 10 | @staticmethod 11 | def set_language(language: str): 12 | assert language in ["en", "zh"], "Only support English and Chinese" 13 | set_lang(language) 14 | 15 | async def search(self, query: str, num_results: int = 1)
-> Union[List[str], None]: 16 | self.set_language(detect_main_language(query)) 17 | return wikipedia.search(query, results=num_results, suggestion=False) 18 | 19 | async def summary(self, query: str) -> Union[str, None]: 20 | self.set_language(detect_main_language(query)) 21 | try: 22 | result = wikipedia.summary(query, auto_suggest=False, redirect=False) 23 | except wikipedia.exceptions.DisambiguationError as e: 24 | logger.error("DisambiguationError: %s", e) 25 | result = None 26 | return result 27 | 28 | async def page(self, query: str) -> Union[str, None]: 29 | self.set_language(detect_main_language(query)) 30 | try: 31 | result = wikipedia.page(query, auto_suggest=False, redirect=False).content 32 | except wikipedia.exceptions.DisambiguationError as e: 33 | logger.error("DisambiguationError: %s", e) 34 | result = None 35 | return result 36 | -------------------------------------------------------------------------------- /graphgen/bases/base_kg_builder.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections import defaultdict 3 | from typing import Dict, List, Tuple 4 | 5 | from graphgen.bases.base_llm_wrapper import BaseLLMWrapper 6 | from graphgen.bases.base_storage import BaseGraphStorage 7 | from graphgen.bases.datatypes import Chunk 8 | 9 | 10 | class BaseKGBuilder(ABC): 11 | def __init__(self, llm_client: BaseLLMWrapper): 12 | self.llm_client = llm_client 13 | self._nodes: Dict[str, List[dict]] = defaultdict(list) 14 | self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list) 15 | 16 | @abstractmethod 17 | async def extract( 18 | self, chunk: Chunk 19 | ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]: 20 | """Extract nodes and edges from a single chunk.""" 21 | raise NotImplementedError 22 | 23 | @abstractmethod 24 | async def merge_nodes( 25 | self, 26 | node_data: tuple[str, List[dict]], 27 | kg_instance: BaseGraphStorage, 28 | ) -> None: 29 | """Merge extracted nodes into the knowledge graph.""" 30 | raise NotImplementedError 31 | 32 | @abstractmethod 33 | async def merge_edges( 34 | self, 35 | edges_data: tuple[Tuple[str, str], List[dict]], 36 | kg_instance: BaseGraphStorage, 37 | ) -> None: 38 | """Merge extracted edges into the knowledge graph.""" 39 | raise NotImplementedError 40 | -------------------------------------------------------------------------------- /examples/generate/generate_multi_hop_qa/multi_hop_config.yaml: -------------------------------------------------------------------------------- 1 | global_params: 2 | working_dir: cache 3 | graph_backend: kuzu # graph database backend, support: kuzu, networkx 4 | kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv 5 | 6 | nodes: 7 | - id: read 8 | op_name: read 9 | type: source 10 | dependencies: [] 11 | params: 12 | input_path: 13 | - examples/input_examples/csv_demo.csv 14 | 15 | - id: chunk 16 | op_name: chunk 17 | type: map_batch 18 | dependencies: 19 | - read 20 | execution_params: 21 | replicas: 4 22 | params: 23 | chunk_size: 1024 24 | chunk_overlap: 100 25 | 26 | - id: build_kg 27 | op_name: build_kg 28 | type: map_batch 29 | dependencies: 30 | - chunk 31 | execution_params: 32 | replicas: 1 33 | batch_size: 128 34 | 35 | - id: partition 36 | op_name: partition 37 | type: aggregate 38 | dependencies: 39 | - build_kg 40 | params: 41 | method: ece 42 | method_params: 43 | max_units_per_community: 3 44 | min_units_per_community: 3 45 | max_tokens_per_community: 10240 46 
| unit_sampling: random 47 | 48 | - id: generate 49 | op_name: generate 50 | type: map_batch 51 | dependencies: 52 | - partition 53 | execution_params: 54 | replicas: 1 55 | batch_size: 128 56 | params: 57 | method: multi_hop 58 | data_format: ChatML 59 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use a slim version of Python 3.10 as the base image 2 | FROM python:3.10-slim 3 | 4 | # Set environment variables to prevent Python from writing .pyc files 5 | # and to ensure output is logged directly 6 | ENV PYTHONDONTWRITEBYTECODE=1 7 | ENV PYTHONUNBUFFERED=1 8 | 9 | # Set the working directory inside the container 10 | WORKDIR /app 11 | 12 | # Install system dependencies required for building Python packages 13 | RUN apt-get update && \ 14 | apt-get install -y --no-install-recommends \ 15 | git \ 16 | build-essential \ 17 | && apt-get clean && \ 18 | rm -rf /var/lib/apt/lists/* 19 | 20 | # Create a non-root user and switch to that user 21 | RUN useradd -m appuser 22 | 23 | # Copy requirements file and install Python dependencies 24 | COPY requirements.txt . 25 | RUN pip install --upgrade pip && \ 26 | pip install --no-cache-dir -r requirements.txt 27 | 28 | # Copy the rest of the application code into the container 29 | COPY . . 30 | 31 | # Create necessary directories 32 | RUN mkdir -p cache/data/graphgen cache/logs 33 | 34 | # Environment variables for application config 35 | ENV SYNTHESIZER_MODEL="" 36 | ENV SYNTHESIZER_BASE_URL="" 37 | ENV SYNTHESIZER_API_KEY="" 38 | ENV TRAINEE_MODEL="" 39 | ENV TRAINEE_BASE_URL="" 40 | ENV TRAINEE_API_KEY="" 41 | 42 | # Expose the port the app will run on 43 | EXPOSE 7860 44 | 45 | # Switch to the non-root user 46 | USER appuser 47 | 48 | # Command to run the application 49 | CMD ["python", "webui/app.py"] 50 | -------------------------------------------------------------------------------- /tests/integration_tests/models/llm/local/test_hf_wrapper.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | import pytest 4 | 5 | from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def mock_hf(monkeypatch): 10 | mock_tokenizer = MagicMock() 11 | mock_tokenizer.pad_token = None 12 | mock_tokenizer.eos_token = "" 13 | mock_tokenizer.eos_token_id = 0 14 | mock_tokenizer.decode.return_value = "hello" 15 | mock_tokenizer.encode.return_value = [1, 2, 3] 16 | monkeypatch.setattr( 17 | "graphgen.models.llm.local.hf_wrapper.AutoTokenizer.from_pretrained", 18 | lambda *a, **kw: mock_tokenizer, 19 | ) 20 | 21 | mock_model = MagicMock() 22 | mock_model.device = "cpu" 23 | mock_model.generate.return_value = MagicMock( 24 | __getitem__=lambda s, k: [0, 1, 2, 3], shape=(1, 4) 25 | ) 26 | mock_model.eval.return_value = None 27 | monkeypatch.setattr( 28 | "graphgen.models.llm.local.hf_wrapper.AutoModelForCausalLM.from_pretrained", 29 | lambda *a, **kw: mock_model, 30 | ) 31 | 32 | monkeypatch.setattr( 33 | "graphgen.models.llm.local.hf_wrapper.torch.no_grad", MagicMock() 34 | ) 35 | 36 | return mock_tokenizer, mock_model 37 | 38 | 39 | @pytest.mark.asyncio 40 | async def test_generate_answer(): 41 | wrapper = HuggingFaceWrapper("fake-model") 42 | result = await wrapper.generate_answer("hi") 43 | assert isinstance(result, str) 44 | 
-------------------------------------------------------------------------------- /graphgen/models/vis/community_visualizer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict 3 | 4 | import matplotlib.pyplot as plt 5 | import networkx as nx 6 | 7 | 8 | @dataclass 9 | class Visualizer: 10 | """ 11 | Class for visualizing graphs using NetworkX and Matplotlib. 12 | """ 13 | 14 | graph: nx.Graph = None 15 | communities: Dict[str, int] = None 16 | layout: str = "spring" 17 | max_nodes: int = 1000 18 | node_size: int = 10 19 | alpha: float = 0.6 20 | 21 | def visualize(self, save_path: str = None): 22 | n = self.graph.number_of_nodes() 23 | if self.layout == "spring": 24 | k = max(0.1, 1.0 / (n**0.5)) 25 | pos = nx.spring_layout(self.graph, k=k, seed=42) 26 | else: 27 | raise ValueError(f"Unknown layout: {self.layout}") 28 | 29 | plt.figure(figsize=(10, 10)) 30 | 31 | node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()] 32 | 33 | nx.draw_networkx_nodes( 34 | self.graph, 35 | pos, 36 | node_size=self.node_size, 37 | node_color=node_colors, 38 | cmap=plt.cm.tab20, 39 | alpha=self.alpha, 40 | ) 41 | nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2) 42 | plt.axis("off") 43 | 44 | if save_path: 45 | plt.savefig(save_path, dpi=300, bbox_inches="tight") 46 | print("Saved to", save_path) 47 | else: 48 | plt.show() 49 | -------------------------------------------------------------------------------- /graphgen/models/searcher/web/bing_search.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from fastapi import HTTPException 3 | 4 | from graphgen.utils import logger 5 | 6 | BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search" 7 | BING_MKT = "en-US" 8 | 9 | 10 | class BingSearch: 11 | """ 12 | Bing Search client to searcher with Bing. 13 | """ 14 | 15 | def __init__(self, subscription_key: str): 16 | self.subscription_key = subscription_key 17 | 18 | def search(self, query: str, num_results: int = 1): 19 | """ 20 | Search with Bing and return the contexts. 21 | :param query: The searcher query. 22 | :param num_results: The number of results to return. 23 | :return: A list of searcher results. 24 | """ 25 | params = {"q": query, "mkt": BING_MKT, "count": num_results} 26 | response = requests.get( 27 | BING_SEARCH_V7_ENDPOINT, 28 | headers={"Ocp-Apim-Subscription-Key": self.subscription_key}, 29 | params=params, 30 | timeout=10, 31 | ) 32 | if not response.ok: 33 | logger.error("Search engine error: %s", response.text) 34 | raise HTTPException(response.status_code, "Search engine error.") 35 | json_content = response.json() 36 | try: 37 | contexts = json_content["webPages"]["value"][:num_results] 38 | except KeyError: 39 | logger.error("Error encountered: %s", json_content) 40 | return [] 41 | return contexts 42 | -------------------------------------------------------------------------------- /graphgen/utils/device.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import subprocess 3 | import sys 4 | 5 | 6 | def pick_device() -> str: 7 | """Return the best available device string for MinerU.""" 8 | # 1. 
NVIDIA GPU 9 | if shutil.which("nvidia-smi") is not None: 10 | try: 11 | # check if there's any free GPU memory 12 | out = subprocess.check_output( 13 | [ 14 | "nvidia-smi", 15 | "--query-gpu=memory.free", 16 | "--format=csv,noheader,nounits", 17 | ], 18 | text=True, 19 | ) 20 | if any(int(line) > 0 for line in out.strip().splitlines()): 21 | return "cuda:0" 22 | except Exception: # pylint: disable=broad-except 23 | pass 24 | 25 | # 2. Apple Silicon 26 | if sys.platform == "darwin" and shutil.which("sysctl"): 27 | try: 28 | brand = subprocess.check_output( 29 | ["sysctl", "-n", "machdep.cpu.brand_string"], text=True 30 | ) 31 | if "Apple" in brand: 32 | return "mps" 33 | except Exception: # pylint: disable=broad-except 34 | pass 35 | 36 | # 3. Ascend NPU 37 | if shutil.which("npu-smi") is not None: 38 | try: 39 | subprocess.check_call(["npu-smi", "info"], stdout=subprocess.DEVNULL) 40 | return "npu" 41 | except Exception: # pylint: disable=broad-except 42 | pass 43 | 44 | return "cpu" 45 | -------------------------------------------------------------------------------- /resources/nltk_data/tokenizers/punkt_tab/english/abbrev_types.txt: -------------------------------------------------------------------------------- 1 | ct 2 | m.j 3 | t 4 | a.c 5 | n.h 6 | ms 7 | p.a.m 8 | dr 9 | pa 10 | p.m 11 | u.k 12 | st 13 | dec 14 | u.s.a 15 | lt 16 | g.k 17 | adm 18 | p 19 | h.m 20 | ga 21 | tenn 22 | yr 23 | sen 24 | n.c 25 | j.j 26 | d.h 27 | s.g 28 | inc 29 | vs 30 | s.p.a 31 | a.t 32 | n 33 | feb 34 | sr 35 | jan 36 | s.a.y 37 | n.y 38 | col 39 | g.f 40 | c.o.m.b 41 | d 42 | ft 43 | va 44 | r.k 45 | e.f 46 | chg 47 | r.i 48 | a.g 49 | minn 50 | a.h 51 | k 52 | n.j 53 | m 54 | l.f 55 | f.j 56 | gen 57 | i.m.s 58 | s.a 59 | aug 60 | j.p 61 | okla 62 | m.d.c 63 | ltd 64 | oct 65 | s 66 | vt 67 | r.a 68 | j.c 69 | ariz 70 | w.w 71 | b.v 72 | ore 73 | h 74 | w.r 75 | e.h 76 | mrs 77 | cie 78 | corp 79 | w 80 | n.v 81 | a.d 82 | r.j 83 | ok 84 | . . 
85 | e.m 86 | w.c 87 | ill 88 | nov 89 | u.s 90 | prof 91 | conn 92 | u.s.s.r 93 | mg 94 | f.g 95 | ph.d 96 | g 97 | calif 98 | messrs 99 | h.f 100 | wash 101 | tues 102 | sw 103 | bros 104 | u.n 105 | l 106 | wis 107 | mr 108 | sep 109 | d.c 110 | ave 111 | e.l 112 | co 113 | s.s 114 | reps 115 | c 116 | r.t 117 | h.c 118 | r 119 | wed 120 | a.s 121 | v 122 | fla 123 | jr 124 | r.h 125 | c.v 126 | m.b.a 127 | rep 128 | a.a 129 | e 130 | c.i.t 131 | l.a 132 | b.f 133 | j.b 134 | d.w 135 | j.k 136 | ala 137 | f 138 | w.va 139 | sept 140 | mich 141 | n.m 142 | j.r 143 | l.p 144 | s.c 145 | colo 146 | fri 147 | a.m 148 | g.d 149 | kan 150 | maj 151 | ky 152 | a.m.e 153 | n.d 154 | t.j 155 | cos 156 | nev -------------------------------------------------------------------------------- /examples/input_examples/search_protein_demo.jsonl: -------------------------------------------------------------------------------- 1 | {"type": "text", "content": "P01308"} 2 | {"type": "text", "content": "P68871"} 3 | {"type": "text", "content": "P02768"} 4 | {"type": "text", "content": "P04637"} 5 | {"type": "text", "content": "insulin"} 6 | {"type": "text", "content": "hemoglobin"} 7 | {"type": "text", "content": "p53"} 8 | {"type": "text", "content": "BRCA1"} 9 | {"type": "text", "content": "albumin"} 10 | {"type": "text", "content": "MHHHHHHSSGVDLGTENLYFQSNAMDFPQQLEACVKQANQALSRFIAPLPFQNTPVVETMQYGALLGGKRLRPFLVYATGHMFGVSTNTLDAPAAAVECIHAYSLIHDDLPAMDDDDLRRGLPTCHVKFGEANAILAGDALQTLAFSILSDANMPEVSDRDRISMISELASASGIAGMCGGQALDLDAEGKHVPLDALERIHRHKTGALIRAAVRLGALSAGDKGRRALPVLDKYAESIGLAFQVQDDILDVVGDTATLGKRQGADQQLGKSTYPALLGLEQARKKARDLIDDARQALKQLAEQSLDTSALEALADYIIQRNK"} 11 | {"type": "text", "content": "MGSSHHHHHHSQDLENLYFQGSMNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNTNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRAALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSRWYNQTPNRTKRVITTFRTGTWDAYKNLRKKLEQLYNRYKDPQDENKIGIDGIQQFCDDLALDPASISVLIIAWKFRAATQCEFSKQEFMDGMTELGCDSIEKLKAQIPKMEQELKEPGRFKDFYQFTFNFAKNPGQKGLDLEMAIAYWNLVLNGRFKFLDLWNKFLLEHHKRSIPKDTWNLLLDFSTMIADDMSNYDEEGAWPVLIDDFVEFARPQIAGTKSTTV"} 12 | {"type": "text", "content": "MAKREPIHDNSIRTEWEAKIAKLTSVDQATKFIQDFRLAYTSPFRKSYDIDVDYQYIERKIEEKLSVLKTEKLPVADLITKATTGEDAAAVEATWIAKIKAAKSKYEAEAIHIEFRQLYKPPVLPVNVFLRTDAALGTVLMEIRNTDYYGTPLEGLRKERGVKVLHLQA"} 13 | {"type": "text", "content": "MARVTVQDAVEKIGNRFDLVLVAARRARQMQVGGKDPLVPEENDKTTVIALREIEEGLINNQILDVRERQEQQEQEAAELQAVTAIAEGRR"} 14 | {"type": "text", "content": "GSHMLCAISGKVPRRPVLSPKSRTIFEKSLLEQYVKDTGNDPITNEPLSIEEIVEIVPSAQ"} 15 | -------------------------------------------------------------------------------- /.github/workflows/push-to-hf.yml: -------------------------------------------------------------------------------- 1 | name: Push demo branch to Hugging Face 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | ref: 7 | required: false 8 | default: demo 9 | type: string 10 | secrets: 11 | HF_TOKEN: 12 | required: true 13 | 14 | jobs: 15 | push-hf: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v4 20 | with: 21 | ref: ${{ inputs.ref }} 22 | token: ${{ secrets.GITHUB_TOKEN }} 23 | 24 | - name: Configure Git identity 25 | run: | 26 | git config --global user.email "actions@github.com" 27 | git config --global user.name "github-actions[bot]" 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install huggingface_hub 33 | 34 | - name: Push to Hugging Face 35 | env: 36 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 37 | HF_REPO_TYPE: spaces 38 | HF_REPO_ID: 
chenzihong/GraphGen 39 | run: | 40 | git config --global credential.helper store 41 | echo "https://user:${HF_TOKEN}@huggingface.co" > ~/.git-credentials 42 | 43 | [[ -d hf-repo ]] && rm -rf hf-repo 44 | git clone https://huggingface.co/${HF_REPO_TYPE}/${HF_REPO_ID} hf-repo 45 | 46 | rsync -a --delete --exclude='.git' --exclude='hf-repo' --exclude='README.md' ./ hf-repo/ 47 | 48 | cd hf-repo 49 | git add . 50 | git diff-index --quiet HEAD || \ 51 | (git commit -m "Auto-sync from ${{ inputs.ref }} at $(date -u)" && git push) -------------------------------------------------------------------------------- /graphgen/models/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from graphgen.bases import BaseTokenizer 4 | 5 | from .tiktoken_tokenizer import TiktokenTokenizer 6 | 7 | try: 8 | from transformers import AutoTokenizer 9 | 10 | _HF_AVAILABLE = True 11 | except ImportError: 12 | _HF_AVAILABLE = False 13 | 14 | 15 | def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer: 16 | import tiktoken 17 | # 1. Tiktoken 18 | if tokenizer_name in tiktoken.list_encoding_names(): 19 | return TiktokenTokenizer(model_name=tokenizer_name) 20 | 21 | # 2. HuggingFace 22 | if _HF_AVAILABLE: 23 | from .hf_tokenizer import HFTokenizer 24 | 25 | return HFTokenizer(model_name=tokenizer_name) 26 | 27 | raise ValueError( 28 | f"Unknown tokenizer {tokenizer_name} and HuggingFace not available." 29 | ) 30 | 31 | 32 | class Tokenizer(BaseTokenizer): 33 | """ 34 | Encapsulates different tokenization implementations based on the specified model name. 35 | """ 36 | 37 | def __init__(self, model_name: str = "cl100k_base"): 38 | super().__init__(model_name) 39 | if not self.model_name: 40 | raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.") 41 | self._impl = get_tokenizer_impl(self.model_name) 42 | 43 | def encode(self, text: str) -> List[int]: 44 | return self._impl.encode(text) 45 | 46 | def decode(self, token_ids: List[int]) -> str: 47 | return self._impl.decode(token_ids) 48 | 49 | def count_tokens(self, text: str) -> int: 50 | return self._impl.count_tokens(text) 51 | -------------------------------------------------------------------------------- /graphgen/operators/build_kg/build_text_kg.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List 3 | 4 | from graphgen.bases import BaseLLMWrapper 5 | from graphgen.bases.base_storage import BaseGraphStorage 6 | from graphgen.bases.datatypes import Chunk 7 | from graphgen.models import LightRAGKGBuilder 8 | from graphgen.utils import run_concurrent 9 | 10 | 11 | def build_text_kg( 12 | llm_client: BaseLLMWrapper, 13 | kg_instance: BaseGraphStorage, 14 | chunks: List[Chunk], 15 | ): 16 | """ 17 | :param llm_client: Synthesizer LLM model to extract entities and relationships 18 | :param kg_instance: Graph storage that the extracted nodes and edges are merged into 19 | :param chunks: Text chunks to extract from 20 | :return: None 21 | """ 22 | 23 | kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=3) 24 | 25 | results = run_concurrent( 26 | kg_builder.extract, 27 | chunks, 28 | desc="[2/4] Extracting entities and relationships from chunks", 29 | unit="chunk", 30 | ) 31 | 32 | nodes = defaultdict(list) 33 | edges = defaultdict(list) 34 | for n, e in results: 35 | for k, v in n.items(): 36 | nodes[k].extend(v) 37 | for k, v in e.items(): 38 | edges[tuple(sorted(k))].extend(v) 39 | 40 | run_concurrent( 41 | lambda kv: kg_builder.merge_nodes(kv, 
kg_instance=kg_instance), 42 | list(nodes.items()), 43 | desc="Inserting entities into storage", 44 | ) 45 | 46 | run_concurrent( 47 | lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance), 48 | list(edges.items()), 49 | desc="Inserting relationships into storage", 50 | ) 51 | -------------------------------------------------------------------------------- /graphgen/operators/build_kg/build_mm_kg.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List 3 | 4 | from graphgen.bases import BaseLLMWrapper 5 | from graphgen.bases.base_storage import BaseGraphStorage 6 | from graphgen.bases.datatypes import Chunk 7 | from graphgen.models import MMKGBuilder 8 | from graphgen.utils import run_concurrent 9 | 10 | 11 | def build_mm_kg( 12 | llm_client: BaseLLMWrapper, 13 | kg_instance: BaseGraphStorage, 14 | chunks: List[Chunk], 15 | ): 16 | """ 17 | Build multi-modal KG and merge into kg_instance 18 | :param llm_client: Synthesizer LLM model to extract entities and relationships 19 | :param kg_instance: Graph storage that the extracted nodes and edges are merged into 20 | :param chunks: Multi-modal chunks to extract from 21 | :return: None 22 | """ 23 | mm_builder = MMKGBuilder(llm_client=llm_client) 24 | 25 | results = run_concurrent( 26 | mm_builder.extract, 27 | chunks, 28 | desc="[2/4] Extracting entities and relationships from multi-modal chunks", 29 | unit="chunk", 30 | ) 31 | 32 | nodes = defaultdict(list) 33 | edges = defaultdict(list) 34 | for n, e in results: 35 | for k, v in n.items(): 36 | nodes[k].extend(v) 37 | for k, v in e.items(): 38 | edges[tuple(sorted(k))].extend(v) 39 | 40 | run_concurrent( 41 | lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance), 42 | list(nodes.items()), 43 | desc="Inserting entities into storage", 44 | ) 45 | 46 | run_concurrent( 47 | lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance), 48 | list(edges.items()), 49 | desc="Inserting relationships into storage", 50 | ) 51 | -------------------------------------------------------------------------------- /graphgen/models/searcher/web/google_search.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from fastapi import HTTPException 3 | 4 | from graphgen.utils import logger 5 | 6 | GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1" 7 | 8 | 9 | class GoogleSearch: 10 | def __init__(self, subscription_key: str, cx: str): 11 | """ 12 | Initialize the Google Search client with the subscription key and custom search engine ID. 13 | :param subscription_key: Your Google API subscription key. 14 | :param cx: Your custom search engine ID. 15 | """ 16 | self.subscription_key = subscription_key 17 | self.cx = cx 18 | 19 | def search(self, query: str, num_results: int = 1): 20 | """ 21 | Search with Google and return the contexts. 22 | :param query: The search query. 23 | :param num_results: The number of results to return. 24 | :return: A list of search results.
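        Usage sketch (illustrative only; assumes a valid API key and CSE ID):
            gs = GoogleSearch(subscription_key="<your-key>", cx="<your-cse-id>")
            items = gs.search("knowledge graph", num_results=3)  # raw CSE "items" dicts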
25 | """ 26 | params = { 27 | "key": self.subscription_key, 28 | "cx": self.cx, 29 | "q": query, 30 | "num": num_results, 31 | } 32 | response = requests.get(GOOGLE_SEARCH_ENDPOINT, params=params, timeout=10) 33 | if not response.ok: 34 | logger.error("Search engine error: %s", response.text) 35 | raise HTTPException(response.status_code, "Search engine error.") 36 | json_content = response.json() 37 | try: 38 | contexts = json_content["items"][:num_results] 39 | except KeyError: 40 | logger.error("Error encountered: %s", json_content) 41 | return [] 42 | return contexts 43 | -------------------------------------------------------------------------------- /graphgen/templates/extraction/schemas/legal_contract.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "object", 3 | "name": "legal contract", 4 | "description": "A legal contract for leasing property.", 5 | "properties": { 6 | "leased_space": { 7 | "type": "string", 8 | "description": "Description of the space that is being leased." 9 | }, 10 | "lessee": { 11 | "type": "string", 12 | "description": "The lessee's name (and possibly address)." 13 | }, 14 | "lessor": { 15 | "type": "string", 16 | "description": "The lessor's name (and possibly address)." 17 | }, 18 | "signing_date": { 19 | "type": "string", 20 | "description": "The date the contract was signed." 21 | }, 22 | "start_date": { 23 | "type": "string", 24 | "description": "The start date of the lease." 25 | }, 26 | "end_date": { 27 | "type": "string", 28 | "description": "The end date of the lease." 29 | }, 30 | "term_of_payment": { 31 | "type": "string", 32 | "description": "Description of the payment terms." 33 | }, 34 | "designated_use": { 35 | "type": "string", 36 | "description": "Designated use of the property being leased." 37 | }, 38 | "extension_period": { 39 | "type": "string", 40 | "description": "Description of the extension options for the lease." 41 | }, 42 | "expiration_date_of_lease": { 43 | "type": "string", 44 | "description": "The expiration date of the lease." 
45 | } 46 | }, 47 | "required": ["lessee", "lessor", "start_date", "end_date"] 48 | } 49 | -------------------------------------------------------------------------------- /.github/workflows/push-to-ms.yml: -------------------------------------------------------------------------------- 1 | name: Push demo branch to ModelScope 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | ref: 7 | required: false 8 | default: demo 9 | type: string 10 | secrets: 11 | MS_TOKEN: 12 | required: true 13 | 14 | jobs: 15 | push-ms: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v4 20 | with: 21 | ref: ${{ inputs.ref }} 22 | token: ${{ secrets.GITHUB_TOKEN }} 23 | 24 | - name: Configure Git identity 25 | run: | 26 | git config --global user.email "actions@github.com" 27 | git config --global user.name "github-actions[bot]" 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | # ModelScope official SDK (optional, install only if you need to call the platform API) 33 | pip install modelscope 34 | 35 | - name: Push to ModelScope 36 | env: 37 | MS_TOKEN: ${{ secrets.MS_TOKEN }} 38 | MS_REPO_TYPE: studios 39 | MS_REPO_ID: chenzihong/GraphGen 40 | run: | 41 | [[ -d ms-repo ]] && rm -rf ms-repo 42 | git clone https://oauth2:${MS_TOKEN}@www.modelscope.cn/${MS_REPO_TYPE}/${MS_REPO_ID}.git ms-repo 43 | 44 | rsync -a --delete --exclude='.git' --exclude='ms-repo' --exclude='README.md' ./ ms-repo/ 45 | 46 | cd ms-repo 47 | git add . 48 | git diff-index --quiet HEAD || \ 49 | (git commit -m "Auto-sync from ${{ inputs.ref }} at $(date -u)" && \ 50 | git push "https://oauth2:${MS_TOKEN}@www.modelscope.cn/${MS_REPO_TYPE}/${MS_REPO_ID}.git") 51 | -------------------------------------------------------------------------------- /tests/integration_tests/models/splitter/test_recursive_character_splitter.py: -------------------------------------------------------------------------------- 1 | from graphgen.models.splitter.recursive_character_splitter import ( 2 | ChineseRecursiveTextSplitter, 3 | RecursiveCharacterSplitter, 4 | ) 5 | 6 | 7 | def test_split_english_paragraph(): 8 | text = ( 9 | "Natural language processing (NLP) is a subfield of linguistics, computer science, " 10 | "and artificial intelligence. It focuses on the interaction between computers and " 11 | "humans through natural language. The ultimate objective of NLP is to read, decipher, " 12 | "understand, and make sense of human languages in a manner that is valuable.\n\n" 13 | "Most NLP techniques rely on machine learning." 14 | ) 15 | 16 | splitter = RecursiveCharacterSplitter( 17 | chunk_size=150, 18 | chunk_overlap=0, 19 | keep_separator=True, 20 | is_separator_regex=False, 21 | ) 22 | chunks = splitter.split_text(text) 23 | 24 | assert len(chunks) > 1 25 | for chk in chunks: 26 | assert len(chk) <= 150 27 | 28 | 29 | def test_split_chinese_with_punctuation(): 30 | text = ( 31 | "自然语言处理是人工智能的重要分支。它研究能实现人与计算机之间用自然语言" 32 | "进行有效通信的各种理论和方法!融合语言学、计算机科学、数学于一体?" 
33 | "近年来,深度学习极大推动了NLP的发展;Transformer、BERT、GPT等模型层出不穷," 34 | ",,,甚至出现了多模态大模型。\n\n" 35 | "未来,NLP 将继续向通用人工智能迈进。" 36 | ) 37 | 38 | splitter = ChineseRecursiveTextSplitter( 39 | chunk_size=60, 40 | chunk_overlap=0, 41 | keep_separator=True, 42 | is_separator_regex=True, 43 | ) 44 | chunks = splitter.split_text(text) 45 | 46 | assert len(chunks) > 1 47 | for chk in chunks: 48 | assert len(chk) <= 60 49 | assert "\n\n\n" not in chk 50 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | TOKENIZER_MODEL= 3 | 4 | # LLM 5 | # Support different backends: http_api, openai_api, ollama_api, ollama, huggingface, tgi, sglang, tensorrt 6 | 7 | # http_api / openai_api 8 | SYNTHESIZER_BACKEND=openai_api 9 | SYNTHESIZER_MODEL=gpt-4o-mini 10 | SYNTHESIZER_BASE_URL= 11 | SYNTHESIZER_API_KEY= 12 | TRAINEE_BACKEND=openai_api 13 | TRAINEE_MODEL=gpt-4o-mini 14 | TRAINEE_BASE_URL= 15 | TRAINEE_API_KEY= 16 | 17 | # azure_openai_api 18 | # SYNTHESIZER_BACKEND=azure_openai_api 19 | # The following is the same as your "Deployment name" in Azure 20 | # SYNTHESIZER_MODEL= 21 | # SYNTHESIZER_BASE_URL=https://.openai.azure.com/openai/deployments//chat/completions 22 | # SYNTHESIZER_API_KEY= 23 | # SYNTHESIZER_API_VERSION= 24 | 25 | # # ollama_api 26 | # SYNTHESIZER_BACKEND=ollama_api 27 | # SYNTHESIZER_MODEL=gemma3 28 | # SYNTHESIZER_BASE_URL=http://localhost:11434 29 | # 30 | # Note: TRAINEE with ollama_api backend is not supported yet as ollama_api does not support logprobs. 31 | 32 | # # huggingface 33 | # SYNTHESIZER_BACKEND=huggingface 34 | # SYNTHESIZER_MODEL=Qwen/Qwen2.5-0.5B-Instruct 35 | # 36 | # TRAINEE_BACKEND=huggingface 37 | # TRAINEE_MODEL=Qwen/Qwen2.5-0.5B-Instruct 38 | 39 | # # sglang 40 | # SYNTHESIZER_BACKEND=sglang 41 | # SYNTHESIZER_MODEL=Qwen/Qwen2.5-0.5B-Instruct 42 | # SYNTHESIZER_TP_SIZE=1 43 | # SYNTHESIZER_NUM_GPUS=1 44 | 45 | # TRAINEE_BACKEND=sglang 46 | # TRAINEE_MODEL=Qwen/Qwen2.5-0.5B-Instruct 47 | # SYNTHESIZER_TP_SIZE=1 48 | # SYNTHESIZER_NUM_GPUS=1 49 | 50 | # # vllm 51 | # SYNTHESIZER_BACKEND=vllm 52 | # SYNTHESIZER_MODEL=Qwen/Qwen2.5-0.5B-Instruct 53 | # SYNTHESIZER_NUM_GPUS=1 54 | 55 | # TRAINEE_BACKEND=vllm 56 | # TRAINEE_MODEL=Qwen/Qwen2.5-0.5B-Instruct 57 | # TRAINEE_NUM_GPUS=1 58 | -------------------------------------------------------------------------------- /graphgen/templates/generation/multi_hop_generation.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C0301 2 | TEMPLATE_ZH: str = """请基于以下知识子图生成多跳推理问题和答案。你将获得一个知识子图,其中包含一系列实体、关系和事实。你的任务是提出一个问题,该问题需要经过多次推理才能回答。问题的答案应该是从给定的知识子图中推断出来的。确保问题的难度适中,需要多步推理才能回答。 3 | 4 | 例如: 5 | ######## 6 | --实体-- 7 | 1. 苹果 8 | 2. 水果 9 | 3. 维生素C 10 | ######## 11 | --关系-- 12 | 1. 苹果-水果:苹果是一种水果 13 | 2. 水果-维生素C:水果中富含维生素C 14 | ######## 15 | 问题:通过吃苹果补充的什么物质,有助于维持健康? 16 | 答案:维生素C 17 | ######## 18 | 19 | ######### 20 | --实体-- 21 | {entities} 22 | ######### 23 | --关系-- 24 | {relationships} 25 | ######### 26 | 直接输出生成的问题和答案,请不要直接复制示例问题和答案,不要输出无关内容。 27 | """ 28 | 29 | TEMPLATE_EN: str = """Please generate a multi-hop reasoning question and answer based on the following knowledge subgraph. You will be provided with a knowledge subgraph that contains a series of entities, relations, and facts. Your task is to generate a question that requires multiple steps of reasoning to answer. The answer to the question should be inferred from the given knowledge subgraph. 
Ensure that the question is of moderate difficulty and requires multiple steps of reasoning to answer. 30 | 31 | For example: 32 | ######## 33 | --Entities-- 34 | 1. Apple 35 | 2. Fruit 36 | 3. Vitamin C 37 | ######## 38 | --Relations-- 39 | 1. Apple-Fruit: Apple is a type of fruit 40 | 2. Fruit-Vitamin C: Fruits are rich in Vitamin C 41 | ######## 42 | Question: What substance, obtained through eating apples, helps maintain health? 43 | Answer: Vitamin C 44 | ######## 45 | 46 | ######## 47 | --Entities-- 48 | {entities} 49 | ######## 50 | --Relations-- 51 | {relationships} 52 | ######## 53 | Output the generated question and answer directly; please do not copy the example question and answer directly, and do not provide irrelevant information. 54 | """ 55 | 56 | MULTI_HOP_GENERATION_PROMPT = {"en": TEMPLATE_EN, "zh": TEMPLATE_ZH} 57 | -------------------------------------------------------------------------------- /graphgen/operators/extract/extract_service.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pandas as pd 4 | 5 | from graphgen.bases import BaseLLMWrapper, BaseOperator 6 | from graphgen.common import init_llm 7 | from graphgen.models.extractor import SchemaGuidedExtractor 8 | from graphgen.utils import logger, run_concurrent 9 | 10 | 11 | class ExtractService(BaseOperator): 12 | def __init__(self, working_dir: str = "cache", **extract_kwargs): 13 | super().__init__(working_dir=working_dir, op_name="extract_service") 14 | self.llm_client: BaseLLMWrapper = init_llm("synthesizer") 15 | self.extract_kwargs = extract_kwargs 16 | self.method = self.extract_kwargs.get("method") 17 | if self.method == "schema_guided": 18 | schema_file = self.extract_kwargs.get("schema_path") 19 | with open(schema_file, "r", encoding="utf-8") as f: 20 | schema = json.load(f) 21 | self.extractor = SchemaGuidedExtractor(self.llm_client, schema) 22 | else: 23 | raise ValueError(f"Unsupported extraction method: {self.method}") 24 | 25 | def process(self, batch: pd.DataFrame) -> pd.DataFrame: 26 | items = batch.to_dict(orient="records") 27 | return pd.DataFrame(self.extract(items)) 28 | 29 | def extract(self, items: list[dict]) -> list[dict]: 30 | 31 | logger.info("Start extracting information from %d items", len(items)) 32 | 33 | results = run_concurrent( 34 | self.extractor.extract, 35 | items, 36 | desc="Extracting information", 37 | unit="item", 38 | ) 39 | results = self.extractor.merge_extractions(results) 40 | 41 | results = [ 42 | {"_extract_id": key, "extracted_data": value} 43 | for key, value in results.items() 44 | ] 45 | return results 46 | -------------------------------------------------------------------------------- /graphgen/models/evaluator/base_evaluator.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from tqdm.asyncio import tqdm as tqdm_async 4 | 5 | from graphgen.bases.datatypes import QAPair 6 | from graphgen.utils import create_event_loop 7 | 8 | 9 | class BaseEvaluator: 10 | def __init__(self, max_concurrent: int = 100): 11 | self.max_concurrent = max_concurrent 12 | self.results: list[float] | None = None 13 | 14 | def evaluate(self, pairs: list[QAPair]) -> list[float]: 15 | """ 16 | Evaluate a batch of QA pairs and return one score per pair.
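        Usage sketch (LengthEvaluator stands in for any concrete subclass that
        implements evaluate_single; constructor arguments may differ):
            evaluator = LengthEvaluator()
            scores = evaluator.evaluate(pairs)  # one float per QAPair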
17 | """ 18 | return create_event_loop().run_until_complete(self.async_evaluate(pairs)) 19 | 20 | async def async_evaluate(self, pairs: list[QAPair]) -> list[float]: 21 | semaphore = asyncio.Semaphore(self.max_concurrent) 22 | 23 | async def evaluate_with_semaphore(pair): 24 | async with semaphore: # 获取Semaphore 25 | return await self.evaluate_single(pair) 26 | 27 | results = [] 28 | for result in tqdm_async( 29 | asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]), 30 | total=len(pairs), 31 | ): 32 | results.append(await result) 33 | return results 34 | 35 | async def evaluate_single(self, pair: QAPair) -> float: 36 | raise NotImplementedError() 37 | 38 | def get_average_score(self, pairs: list[QAPair]) -> float: 39 | """ 40 | Get the average score of a batch of texts. 41 | """ 42 | results = self.evaluate(pairs) 43 | self.results = results 44 | return sum(self.results) / len(pairs) 45 | 46 | def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]: 47 | """ 48 | Get the min and max score of a batch of texts. 49 | """ 50 | if self.results is None: 51 | self.get_average_score(pairs) 52 | return min(self.results), max(self.results) 53 | -------------------------------------------------------------------------------- /graphgen/templates/extraction/schema_guided_extraction.py: -------------------------------------------------------------------------------- 1 | TEMPLATE_EN = """You are an expert at extracting information from text based on a given schema. 2 | Extract relevant information about {field} from a given contract document according to the provided schema. 3 | 4 | Instructions: 5 | 1. Carefully read the entire document provided at the end of this prompt. 6 | 2. Extract the relevant information. 7 | 3. Present your findings in JSON format as specified below. 8 | 9 | Important Notes: 10 | - Extract only relevant information. 11 | - Consider the context of the entire document when determining relevance. 12 | - Do not be verbose, only respond with the correct format and information. 13 | - Some docs may have multiple relevant excerpts -- include all that apply. 14 | - Some questions may have no relevant excerpts -- just return "". 15 | - Do not include additional JSON keys beyond the ones listed here. 16 | - Do not include the same key multiple times in the JSON. 17 | - Use English for your response. 18 | 19 | Expected JSON keys and explanation of what they are: 20 | {schema_explanation} 21 | 22 | Expected format: 23 | {{ 24 | "key1": "value1", 25 | "key2": "value2", 26 | ... 27 | }} 28 | 29 | {examples} 30 | 31 | Document to extract from: 32 | {text} 33 | """ 34 | 35 | TEMPLATE_ZH = """你是一个擅长根据给定的模式从文本中提取信息的专家。 36 | 根据提供的模式,从合同文件中提取与{field}相关的信息。 37 | 操作说明: 38 | 1. 仔细阅读本提示末尾提供的整份文件。 39 | 2. 提取相关信息。 40 | 3. 按照下面指定的JSON格式呈现你的发现。 41 | 42 | 重要注意事项: 43 | - 仅提取相关信息。 44 | - 在确定相关性时,考虑整份文件的上下文。 45 | - 不要冗长,只需以正确的格式和信息进行回应。 46 | - 有些文件可能有多个相关摘录——请包含所有适用的内容。 47 | - 有些问题可能没有相关摘录——只需返回""。 48 | - 不要在JSON中包含除列出的键之外的其他键。 49 | - 不要多次包含同一个键。 50 | - 使用中文回答。 51 | 52 | 预期的JSON键及其说明: 53 | {schema_explanation} 54 | 55 | 预期格式: 56 | {{ 57 | "key1": "value1", 58 | "key2": "value2", 59 | ... 
60 | }} 61 | 62 | {examples} 63 | 要提取的文件: 64 | {text} 65 | """ 66 | 67 | SCHEMA_GUIDED_EXTRACTION_PROMPT = { 68 | "en": TEMPLATE_EN, 69 | "zh": TEMPLATE_ZH, 70 | } 71 | -------------------------------------------------------------------------------- /examples/output_examples/atomic_alpaca.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "What do regulatory mechanisms control in plants?", 4 | "input": "", 5 | "output": "Regulatory mechanisms control the development and function of plant traits, such as tiller angle." 6 | }, 7 | { 8 | "instruction": "What is the primary focus of genetic studies on rice?", 9 | "input": "", 10 | "output": "The primary focus of genetic studies on rice is to improve grain yield and size." 11 | }, 12 | { 13 | "instruction": "What is the function of the cell wall in plant cells?", 14 | "input": "", 15 | "output": "The cell wall provides structural support and protection to plant cells." 16 | }, 17 | { 18 | "instruction": "What is the result of overexpressing a gene like BG1?", 19 | "input": "", 20 | "output": "Overexpressing a gene like BG1 leads to enhanced traits such as larger grain size." 21 | }, 22 | { 23 | "instruction": "云粳26号是由哪个机构育成的?", 24 | "input": "", 25 | "output": "云粳26号是由云南省农业科学院育成的。" 26 | }, 27 | { 28 | "instruction": "农业推广的定义是什么?", 29 | "input": "", 30 | "output": "农业推广是指通过政策和技术手段,促进优质农作物的种植和应用。" 31 | }, 32 | { 33 | "instruction": "What is the purpose of the Red Fluorescent Protein (RFP) in cellular studies?", 34 | "input": "", 35 | "output": "The Red Fluorescent Protein (RFP) is used to visualize protein localization in cells by emitting fluorescent signals." 36 | }, 37 | { 38 | "instruction": "隆两优1212的产量是多少?", 39 | "input": "", 40 | "output": "隆两优1212的产量为8.78 t/hm^2。" 41 | }, 42 | { 43 | "instruction": "云粳26号适宜种植的海拔范围是多少?", 44 | "input": "", 45 | "output": "云粳26号适宜种植的海拔范围是1 500∼1 800 m。" 46 | }, 47 | { 48 | "instruction": "云粳26号的食味品质中有哪些特点?", 49 | "input": "", 50 | "output": "有香味是云粳26号的食味品质之一,增加了其市场吸引力。" 51 | } 52 | ] -------------------------------------------------------------------------------- /resources/nltk_data/corpora/stopwords/english: -------------------------------------------------------------------------------- 1 | i 2 | me 3 | my 4 | myself 5 | we 6 | our 7 | ours 8 | ourselves 9 | you 10 | you're 11 | you've 12 | you'll 13 | you'd 14 | your 15 | yours 16 | yourself 17 | yourselves 18 | he 19 | him 20 | his 21 | himself 22 | she 23 | she's 24 | her 25 | hers 26 | herself 27 | it 28 | it's 29 | its 30 | itself 31 | they 32 | them 33 | their 34 | theirs 35 | themselves 36 | what 37 | which 38 | who 39 | whom 40 | this 41 | that 42 | that'll 43 | these 44 | those 45 | am 46 | is 47 | are 48 | was 49 | were 50 | be 51 | been 52 | being 53 | have 54 | has 55 | had 56 | having 57 | do 58 | does 59 | did 60 | doing 61 | a 62 | an 63 | the 64 | and 65 | but 66 | if 67 | or 68 | because 69 | as 70 | until 71 | while 72 | of 73 | at 74 | by 75 | for 76 | with 77 | about 78 | against 79 | between 80 | into 81 | through 82 | during 83 | before 84 | after 85 | above 86 | below 87 | to 88 | from 89 | up 90 | down 91 | in 92 | out 93 | on 94 | off 95 | over 96 | under 97 | again 98 | further 99 | then 100 | once 101 | here 102 | there 103 | when 104 | where 105 | why 106 | how 107 | all 108 | any 109 | both 110 | each 111 | few 112 | more 113 | most 114 | other 115 | some 116 | such 117 | no 118 | nor 119 | not 120 | only 121 | own 122 | same 123 | so 124 | than 125 | too 126 | 
very 127 | s 128 | t 129 | can 130 | will 131 | just 132 | don 133 | don't 134 | should 135 | should've 136 | now 137 | d 138 | ll 139 | m 140 | o 141 | re 142 | ve 143 | y 144 | ain 145 | aren 146 | aren't 147 | couldn 148 | couldn't 149 | didn 150 | didn't 151 | doesn 152 | doesn't 153 | hadn 154 | hadn't 155 | hasn 156 | hasn't 157 | haven 158 | haven't 159 | isn 160 | isn't 161 | ma 162 | mightn 163 | mightn't 164 | mustn 165 | mustn't 166 | needn 167 | needn't 168 | shan 169 | shan't 170 | shouldn 171 | shouldn't 172 | wasn 173 | wasn't 174 | weren 175 | weren't 176 | won 177 | won't 178 | wouldn 179 | wouldn't 180 | -------------------------------------------------------------------------------- /graphgen/models/storage/kv/json_storage.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | from graphgen.bases.base_storage import BaseKVStorage 5 | from graphgen.utils import load_json, write_json 6 | 7 | 8 | @dataclass 9 | class JsonKVStorage(BaseKVStorage): 10 | _data: dict[str, dict] = None 11 | 12 | def __post_init__(self): 13 | self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json") 14 | self._data = load_json(self._file_name) or {} 15 | print(f"Load KV {self.namespace} with {len(self._data)} data") 16 | 17 | @property 18 | def data(self): 19 | return self._data 20 | 21 | def all_keys(self) -> list[str]: 22 | return list(self._data.keys()) 23 | 24 | def index_done_callback(self): 25 | write_json(self._data, self._file_name) 26 | 27 | def get_by_id(self, id): 28 | return self._data.get(id, None) 29 | 30 | def get_by_ids(self, ids, fields=None) -> list: 31 | if fields is None: 32 | return [self._data.get(id, None) for id in ids] 33 | return [ 34 | ( 35 | {k: v for k, v in self._data[id].items() if k in fields} 36 | if self._data.get(id, None) 37 | else None 38 | ) 39 | for id in ids 40 | ] 41 | 42 | def get_all(self) -> dict[str, dict]: 43 | return self._data 44 | 45 | def filter_keys(self, data: list[str]) -> set[str]: 46 | return {s for s in data if s not in self._data} 47 | 48 | def upsert(self, data: dict): 49 | left_data = {k: v for k, v in data.items() if k not in self._data} 50 | if left_data: 51 | self._data.update(left_data) 52 | return left_data 53 | 54 | def drop(self): 55 | if self._data: 56 | self._data.clear() 57 | 58 | def reload(self): 59 | self._data = load_json(self._file_name) or {} 60 | print(f"Reload KV {self.namespace} with {len(self._data)} data") 61 | -------------------------------------------------------------------------------- /graphgen/bases/base_operator.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os 3 | from abc import ABC, abstractmethod 4 | from typing import Iterable, Union 5 | 6 | import pandas as pd 7 | import ray 8 | 9 | from graphgen.utils import CURRENT_LOGGER_VAR, set_logger 10 | 11 | 12 | class BaseOperator(ABC): 13 | def __init__(self, working_dir: str = "cache", op_name: str = None): 14 | log_dir = os.path.join(working_dir, "logs") 15 | self.op_name = op_name or self.__class__.__name__ 16 | 17 | try: 18 | ctx = ray.get_runtime_context() 19 | worker_id = ctx.get_actor_id() or ctx.get_worker_id() 20 | worker_id_short = worker_id[-6:] if worker_id else "driver" 21 | except Exception as e: 22 | print( 23 | "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:", 24 | e, 25 | ) 26 | worker_id_short = "local" 27 | 28 | # e.g. 
cache/logs/ChunkService_a1b2c3.log 29 | log_file = os.path.join(log_dir, f"{self.op_name}_{worker_id_short}.log") 30 | 31 | self.logger = set_logger( 32 | log_file=log_file, name=f"{self.op_name}.{worker_id_short}", force=True 33 | ) 34 | 35 | self.logger.info( 36 | "[%s] Operator initialized on Worker %s", self.op_name, worker_id_short 37 | ) 38 | 39 | def __call__( 40 | self, batch: pd.DataFrame 41 | ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]: 42 | logger_token = CURRENT_LOGGER_VAR.set(self.logger) 43 | try: 44 | result = self.process(batch) 45 | if inspect.isgenerator(result): 46 | yield from result 47 | else: 48 | yield result 49 | finally: 50 | CURRENT_LOGGER_VAR.reset(logger_token) 51 | 52 | @abstractmethod 53 | def process(self, batch): 54 | raise NotImplementedError("Subclasses must implement the process method.") 55 | 56 | def get_logger(self): 57 | return self.logger 58 | -------------------------------------------------------------------------------- /webui/utils/count_tokens.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | import pandas as pd 6 | 7 | # pylint: disable=wrong-import-position 8 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 9 | sys.path.append(root_dir) 10 | from graphgen.models import Tokenizer 11 | 12 | 13 | def count_tokens(file, tokenizer_name, data_frame): 14 | if not file or not os.path.exists(file): 15 | return data_frame 16 | 17 | if file.endswith(".jsonl"): 18 | with open(file, "r", encoding="utf-8") as f: 19 | data = [json.loads(line) for line in f] 20 | elif file.endswith(".json"): 21 | with open(file, "r", encoding="utf-8") as f: 22 | data = json.load(f) 23 | data = [item for sublist in data for item in sublist] 24 | elif file.endswith(".txt"): 25 | with open(file, "r", encoding="utf-8") as f: 26 | data = f.read() 27 | chunks = [data[i : i + 512] for i in range(0, len(data), 512)] 28 | data = [{"content": chunk} for chunk in chunks] 29 | elif file.endswith(".csv"): 30 | df = pd.read_csv(file) 31 | if "content" in df.columns: 32 | data = df["content"].tolist() 33 | else: 34 | data = df.iloc[:, 0].tolist() 35 | else: 36 | raise ValueError(f"Unsupported file type: {file}") 37 | 38 | tokenizer = Tokenizer(tokenizer_name) 39 | 40 | # Count tokens 41 | token_count = 0 42 | 43 | for item in data: 44 | if isinstance(item, dict): 45 | content = item.get("content", "") 46 | else: 47 | content = item 48 | token_count += len(tokenizer.encode(content)) 49 | 50 | _update_data = [[str(token_count), str(token_count * 50), "N/A"]] 51 | 52 | try: 53 | new_df = pd.DataFrame(_update_data, columns=data_frame.columns) 54 | data_frame = new_df 55 | 56 | except Exception as e: # pylint: disable=broad-except 57 | print("[ERROR] DataFrame update failed:", str(e)) 58 | 59 | return data_frame 60 | -------------------------------------------------------------------------------- /graphgen/models/generator/atomic_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from graphgen.bases import BaseGenerator 4 | from graphgen.templates import ATOMIC_GENERATION_PROMPT 5 | from graphgen.utils import compute_content_hash, detect_main_language, logger 6 | 7 | 8 | class AtomicGenerator(BaseGenerator): 9 | @staticmethod 10 | def build_prompt( 11 | batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] 12 | ) -> str: 13 | nodes, edges = batch 14 | context = "" 15 | for node in nodes: 16 | context 
+= f"- {node[0]}: {node[1]['description']}\n" 17 | for edge in edges: 18 | context += f"- {edge[0]} - {edge[1]}: {edge[2]['description']}\n" 19 | language = detect_main_language(context) 20 | 21 | prompt = ATOMIC_GENERATION_PROMPT[language].format(context=context) 22 | return prompt 23 | 24 | @staticmethod 25 | def parse_response(response: str) -> dict: 26 | """ 27 | AtomicGenerator normally generates one QA pair per response. 28 | So we just need to parse one QA pair from the response. 29 | :param response: 30 | :return: 31 | """ 32 | if "Question:" in response and "Answer:" in response: 33 | question = response.split("Question:")[1].split("Answer:")[0].strip() 34 | answer = response.split("Answer:")[1].strip() 35 | elif "问题:" in response and "答案:" in response: 36 | question = response.split("问题:")[1].split("答案:")[0].strip() 37 | answer = response.split("答案:")[1].strip() 38 | else: 39 | logger.warning("Failed to parse response: %s", response) 40 | return {} 41 | question = question.strip('"') 42 | answer = answer.strip('"') 43 | logger.debug("Question: %s", question) 44 | logger.debug("Answer: %s", answer) 45 | return { 46 | compute_content_hash(question): { 47 | "question": question, 48 | "answer": answer, 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | import os 3 | 4 | from setuptools import find_packages, setup 5 | 6 | pwd = os.path.dirname(__file__) 7 | version_file = "graphgen/_version.py" 8 | 9 | 10 | def readme(): 11 | with open(os.path.join(pwd, "README.md"), encoding="utf-8") as f: 12 | content = f.read() 13 | return content 14 | 15 | 16 | def get_version(): 17 | with open(os.path.join(pwd, version_file), "r") as f: 18 | exec(compile(f.read(), version_file, "exec")) 19 | return locals()["__version__"] 20 | 21 | 22 | def read_requirements(): 23 | lines = [] 24 | with open("requirements.txt", "r") as f: 25 | for line in f.readlines(): 26 | if line.startswith("#"): 27 | continue 28 | if "textract" in line: 29 | continue 30 | if len(line) > 0: 31 | lines.append(line) 32 | return lines 33 | 34 | 35 | install_packages = read_requirements() 36 | 37 | if __name__ == "__main__": 38 | setup( 39 | name="graphg", 40 | version=get_version(), 41 | url="https://github.com/open-sciencelab/GraphGen", 42 | description="GraphGen: Enhancing Supervised Fine-Tuning for LLMs with Knowledge-Driven Synthetic Data Generation", 43 | long_description=readme(), 44 | long_description_content_type="text/markdown", 45 | author="open-sciencelab", 46 | author_email="open-sciencelab@pjlab.org.cn", 47 | packages=find_packages(exclude=["models"]), 48 | package_data={"GraphGen": ["configs/*"]}, 49 | include_package_data=True, 50 | install_requires=install_packages, 51 | classifiers=[ 52 | "Programming Language :: Python :: 3.10", 53 | "Programming Language :: Python :: 3.11", 54 | "Programming Language :: Python :: 3.12", 55 | "Intended Audience :: Developers", 56 | "Intended Audience :: Education", 57 | "Intended Audience :: Science/Research", 58 | ], 59 | entry_points={"console_scripts": ["graphg=graphgen.generate:main"]}, 60 | ) 61 | -------------------------------------------------------------------------------- /examples/input_examples/extract_demo.txt: -------------------------------------------------------------------------------- 1 | LEASE AGREEMENT 2 | This Lease Agreement (“Agreement”) is entered into on 15 March 2024 (“Signing Date”) by 
and between: 3 | Lessor: 4 | Riverside Commercial LLC, a Delaware limited-liability company with principal offices at 1800 Riverfront Drive, Wilmington, Delaware 19801, USA. 5 | Lessee: 6 | BrightWave Technologies Inc., a California corporation having its registered address at 2500 Innovation Way, San Jose, California 95110, USA. 7 | Leased Space: 8 | Approximately 5,200 rentable square feet on the entire third (3rd) floor of the Building known as Riverfront Plaza, 1800 Riverfront Drive, Wilmington, Delaware 19801, consisting of open-plan office areas, three private offices, one conference room, a kitchenette, and associated common-area restrooms. 9 | Designated Use: 10 | General office, research & development, and ancillary administrative purposes related to software and hardware technology businesses, and for no other purpose without prior written consent of Lessor. 11 | Term: 12 | The lease shall commence on 1 May 2024 (“Start Date”) and shall expire at 11:59 p.m. on 30 April 2029 (This is “End Date” or “Expiration Date of Lease”), unless sooner terminated in accordance with the terms hereof. 13 | Term of Payment: 14 | Base rent in the amount of Twenty-Five Thousand and 00/100 United States Dollars (US $25,000.00) per calendar month, due in advance on the first (1st) day of each month without demand, deduction, or set-off. Rent shall be paid by ACH transfer to the account designated in writing by Lessor. In addition, Lessee shall pay its proportionate share of Operating Expenses and Real Estate Taxes as defined in Section 5 of this Agreement. 15 | Extension Period: 16 | Provided Lessee is not in default beyond any applicable cure period, Lessee may extend this Lease for one (1) additional period of five (5) years by giving Lessor written notice at least twelve (12) months but not more than fifteen (15) months prior to the Expiration Date of Lease (i.e., 30 April 2029). The rent during the extension period shall be adjusted to the then-prevailing market rate as determined under Section 18.2 hereof. -------------------------------------------------------------------------------- /tests/e2e_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | from pathlib import Path 5 | 6 | 7 | def run_generate_test(tmp_path: Path, config_name: str): 8 | """ 9 | Run the generate test with the given configuration file and temporary path. 10 | 11 | Args: 12 | tmp_path: pytest temporary path 13 | config_name: configuration file name (e.g. 
"atomic_config.yaml") 14 | 15 | Returns: 16 | tuple: (run_folder, json_files[0]) 17 | """ 18 | repo_root = Path(__file__).resolve().parents[2] 19 | os.chdir(repo_root) 20 | 21 | config_path = repo_root / "graphgen" / "configs" / config_name 22 | output_dir = tmp_path / "output" 23 | output_dir.mkdir(parents=True, exist_ok=True) 24 | 25 | result = subprocess.run( 26 | [ 27 | "python", 28 | "-m", 29 | "graphgen.generate", 30 | "--config_file", 31 | str(config_path), 32 | "--output_dir", 33 | str(output_dir), 34 | ], 35 | capture_output=True, 36 | text=True, 37 | check=False, 38 | ) 39 | assert result.returncode == 0, f"Script failed with error: {result.stderr}" 40 | 41 | data_root = output_dir / "data" / "graphgen" 42 | assert data_root.exists(), f"{data_root} does not exist" 43 | run_folders = sorted(data_root.iterdir(), key=lambda p: p.name, reverse=True) 44 | assert run_folders, f"No run folders found in {data_root}" 45 | run_folder = run_folders[0] 46 | 47 | config_saved = run_folder / "config.yaml" 48 | assert config_saved.exists(), f"{config_saved} not found" 49 | 50 | json_files = list(run_folder.glob("*.json")) 51 | assert json_files, f"No JSON output found in {run_folder}" 52 | 53 | log_files = list(run_folder.glob("*.log")) 54 | assert log_files, "No log file generated" 55 | 56 | with open(json_files[0], "r", encoding="utf-8") as f: 57 | data = json.load(f) 58 | assert ( 59 | isinstance(data, list) and len(data) > 0 60 | ), "JSON output is empty or not a list" 61 | 62 | return run_folder, json_files[0] 63 | 64 | -------------------------------------------------------------------------------- /graphgen/models/reader/json_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List, Union 3 | 4 | import ray 5 | import ray.data 6 | 7 | from graphgen.bases.base_reader import BaseReader 8 | 9 | 10 | class JSONReader(BaseReader): 11 | """ 12 | Reader for JSON and JSONL files. 13 | Columns: 14 | - type: The type of the document (e.g., "text", "image", etc.) 15 | - if type is "text", "content" column must be present. 16 | """ 17 | 18 | def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset: 19 | """ 20 | Read JSON file and return Ray Dataset. 21 | :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files. 22 | :return: Ray Dataset containing validated and filtered data. 23 | """ 24 | if self.modalities and len(self.modalities) >= 2: 25 | ds: ray.data.Dataset = ray.data.from_items([]) 26 | for file in input_path if isinstance(input_path, list) else [input_path]: 27 | data = [] 28 | if file.endswith(".jsonl"): 29 | with open(file, "r", encoding="utf-8") as f: 30 | for line in f: 31 | item = json.loads(line) 32 | data.append(item) 33 | else: 34 | with open(file, "r", encoding="utf-8") as f: 35 | data = json.load(f) 36 | data = self._unify_schema(data) 37 | file_ds: ray.data.Dataset = ray.data.from_items(data) 38 | ds = ds.union(file_ds) # type: ignore 39 | else: 40 | ds = ray.data.read_json(input_path) 41 | ds = ds.map_batches(self._validate_batch, batch_format="pandas") 42 | ds = ds.filter(self._should_keep_item) 43 | return ds 44 | 45 | @staticmethod 46 | def _unify_schema(data): 47 | """ 48 | Unify schema for JSON data. 
49 | """ 50 | for item in data: 51 | if "content" in item and isinstance(item["content"], dict): 52 | item["content"] = json.dumps(item["content"]) 53 | return data 54 | -------------------------------------------------------------------------------- /graphgen/models/generator/multi_hop_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from graphgen.bases import BaseGenerator 4 | from graphgen.templates import MULTI_HOP_GENERATION_PROMPT 5 | from graphgen.utils import compute_content_hash, detect_main_language, logger 6 | 7 | 8 | class MultiHopGenerator(BaseGenerator): 9 | @staticmethod 10 | def build_prompt( 11 | batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] 12 | ) -> str: 13 | nodes, edges = batch 14 | entities_str = "\n".join( 15 | [ 16 | f"{index + 1}. {node[0]}: {node[1]['description']}" 17 | for index, node in enumerate(nodes) 18 | ] 19 | ) 20 | 21 | relationships_str = "\n".join( 22 | [ 23 | f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}" 24 | for index, edge in enumerate(edges) 25 | ] 26 | ) 27 | language = detect_main_language(entities_str + relationships_str) 28 | prompt = MULTI_HOP_GENERATION_PROMPT[language].format( 29 | entities=entities_str, relationships=relationships_str 30 | ) 31 | return prompt 32 | 33 | @staticmethod 34 | def parse_response(response: str) -> dict: 35 | if "Question:" in response and "Answer:" in response: 36 | question = response.split("Question:")[1].split("Answer:")[0].strip() 37 | answer = response.split("Answer:")[1].strip() 38 | elif "问题:" in response and "答案:" in response: 39 | question = response.split("问题:")[1].split("答案:")[0].strip() 40 | answer = response.split("答案:")[1].strip() 41 | else: 42 | logger.warning("Failed to parse response: %s", response) 43 | return {} 44 | question = question.strip('"') 45 | answer = answer.strip('"') 46 | logger.debug("Question: %s", question) 47 | logger.debug("Answer: %s", answer) 48 | return { 49 | compute_content_hash(question): { 50 | "question": question, 51 | "answer": answer, 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /graphgen/templates/search_judgement.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C0301 2 | 3 | TEMPLATE: str = """-Goal- 4 | Please select the most relevant searcher result for the given entity. 5 | The name and description of the entity are provided. The searcher results are provided as a list. 6 | Please select the most relevant searcher result from the list. If none of the searcher results are relevant, please select 'None of the above'. 7 | 8 | Steps: 9 | 1. Read the name and description of the entity. 10 | 11 | 2. Read the searcher results. For each searcher result, compare it with the entity name and description to determine if it is relevant. 12 | 13 | 3. Select the most relevant searcher result from the list. If none of the searcher results are relevant, select 'None of the above'. 14 | 15 | 4. Output your selection directly, please do not provide any additional information. 
16 | 17 | ################ 18 | -Examples- 19 | ################ 20 | {input_examples} 21 | 22 | ################ 23 | -Real Data- 24 | ################ 25 | Entity_name: {entity_name} 26 | Description: {description} 27 | Search Results: 28 | {search_results} 29 | ################ 30 | Output: 31 | """ 32 | 33 | EXAMPLES = [ 34 | """Example 1: 35 | ################ 36 | Entity_name: Java 37 | Description: Java is a high-level programming language developed by Sun Microsystems. It is used to create web applications, mobile applications, and enterprise software. 38 | Search Results: 39 | 1. Java (programming language) 40 | 2. Java (island) 41 | 3. Java (software platform) 42 | 4. Java (drink) 43 | 5. Java (disambiguation) 44 | 6. None of the above 45 | ################ 46 | Output: 47 | 1 48 | ################""", 49 | """Example 2: 50 | ################ 51 | Entity_name: Apple 52 | Description: Apple Inc. is an American multinational technology company that designs, manufactures, and sells consumer electronics, computer software, and online services. 53 | Search Results: 54 | 1. Apple (fruit) 55 | 2. Apple Inc. 56 | 3. Apple (disambiguation) 57 | 4. None of the above 58 | ################ 59 | Output: 60 | 2 61 | ################""", 62 | ] 63 | 64 | SEARCH_JUDGEMENT_PROMPT = { 65 | "TEMPLATE": TEMPLATE, 66 | "EXAMPLES": EXAMPLES, 67 | } 68 | -------------------------------------------------------------------------------- /resources/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Output Formats 3 | We support generating datasets in the alpaca, sharegpt and chatml formats. 4 | 5 | ## Alpaca Format 6 | ### Supervised Fine-Tuning Dataset 7 | - [Example](../examples/output_examples/atomic_alpaca.json) 8 | In supervised fine-tuning, the `instruction` column is concatenated with the `input` column and used as the user prompt, i.e. the user prompt is `instruction\ninput`. The `output` column represents the model response. 9 | ```json 10 | [ 11 | { 12 | "instruction": "user instruction (required)", 13 | "input": "user input (optional)", 14 | "output": "model response (required)" 15 | } 16 | ] 17 | ``` 18 | 19 | ## Sharegpt Format 20 | ### Supervised Fine-Tuning Dataset 21 | - [Example](../examples/output_examples/cot_sharegpt.json) 22 | Compared to the alpaca format, the sharegpt format allows a dataset to have more roles, such as human, gpt, observation and function. They are presented in a list of objects in the `conversations` column. 23 | 24 | Note that `human` and `observation` turns must appear at odd positions, while `gpt` and `function` turns must appear at even positions; only the `gpt` and `function` turns are learned by the model. 25 | 26 | In our implementation, only `human` and `gpt` will be used. 27 | 28 | ```json 29 | [ 30 | { 31 | "conversations": [ 32 | { 33 | "from": "human", 34 | "value": "user instruction (required)" 35 | }, 36 | { 37 | "from": "gpt", 38 | "value": "model response (required)" 39 | } 40 | ] 41 | } 42 | ] 43 | ``` 44 | 45 | ## ChatML Format 46 | ### Supervised Fine-Tuning Dataset 47 | - [Example](../examples/output_examples/aggregated_chatml.json) 48 | Like the sharegpt format, the chatml format also allows a dataset to have more roles, such as user, assistant, system and tool. They are presented in a list of objects in the `messages` column. 49 | 50 | In our implementation, only `user` and `assistant` will be used.
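Below is a minimal conversion sketch (illustrative only; `alpaca_to_chatml` is our name for it, not a library API) that maps the alpaca records shown above to ChatML:

```python
def alpaca_to_chatml(records: list[dict]) -> list[dict]:
    """Convert alpaca-format SFT records into ChatML-format records."""
    converted = []
    for rec in records:
        # The user prompt is instruction + "\n" + input when input is non-empty.
        prompt = rec["instruction"]
        if rec.get("input"):
            prompt += "\n" + rec["input"]
        converted.append(
            {
                "messages": [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": rec["output"]},
                ]
            }
        )
    return converted
```

Each converted record then looks like: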
51 | 52 | ```json 53 | [ 54 | { 55 | "messages": [ 56 | { 57 | "role": "user", 58 | "content": "user instruction (required)" 59 | }, 60 | { 61 | "role": "assistant", 62 | "content": "model response (required)" 63 | } 64 | ] 65 | } 66 | ] 67 | ``` 68 | -------------------------------------------------------------------------------- /tests/integration_tests/test_engine.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from graphgen.engine import Context, Engine, op 4 | 5 | engine = Engine(max_workers=2) 6 | 7 | 8 | def test_simple_dag(capsys): 9 | """Verify the DAG A->B/C->D execution results and print order.""" 10 | ctx = Context() 11 | 12 | @op("A") 13 | def op_a(self, ctx): 14 | print("Running A") 15 | ctx.set("A", 1) 16 | 17 | @op("B", deps=["A"]) 18 | def op_b(self, ctx): 19 | print("Running B") 20 | ctx.set("B", ctx.get("A") + 1) 21 | 22 | @op("C", deps=["A"]) 23 | def op_c(self, ctx): 24 | print("Running C") 25 | ctx.set("C", ctx.get("A") + 2) 26 | 27 | @op("D", deps=["B", "C"]) 28 | def op_d(self, ctx): 29 | print("Running D") 30 | ctx.set("D", ctx.get("B") + ctx.get("C")) 31 | 32 | # Explicitly list the nodes to run; avoid relying on globals(). 33 | ops = [op_a, op_b, op_c, op_d] 34 | engine.run(ops, ctx) 35 | 36 | # Assert final results. 37 | assert ctx["A"] == 1 38 | assert ctx["B"] == 2 39 | assert ctx["C"] == 3 40 | assert ctx["D"] == 5 41 | 42 | # Assert print order: A must run before B and C; D must run after B and C. 43 | captured = capsys.readouterr().out.strip().splitlines() 44 | assert "Running A" in captured 45 | assert "Running B" in captured 46 | assert "Running C" in captured 47 | assert "Running D" in captured 48 | 49 | a_idx = next(i for i, line in enumerate(captured) if "Running A" in line) 50 | b_idx = next(i for i, line in enumerate(captured) if "Running B" in line) 51 | c_idx = next(i for i, line in enumerate(captured) if "Running C" in line) 52 | d_idx = next(i for i, line in enumerate(captured) if "Running D" in line) 53 | 54 | assert a_idx < b_idx 55 | assert a_idx < c_idx 56 | assert d_idx > b_idx 57 | assert d_idx > c_idx 58 | 59 | 60 | def test_cyclic_detection(): 61 | """A cyclic dependency should raise ValueError.""" 62 | ctx = Context() 63 | 64 | @op("X", deps=["Y"]) 65 | def op_x(self, ctx): 66 | pass 67 | 68 | @op("Y", deps=["X"]) 69 | def op_y(self, ctx): 70 | pass 71 | 72 | ops = [op_x, op_y] 73 | with pytest.raises(ValueError, match="Cyclic dependencies"): 74 | engine.run(ops, ctx) 75 | 76 | 77 | if __name__ == "__main__": 78 | pytest.main([__file__, "-v"]) 79 | -------------------------------------------------------------------------------- /graphgen/operators/build_kg/build_kg_service.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pandas as pd 4 | 5 | from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator 6 | from graphgen.bases.datatypes import Chunk 7 | from graphgen.common import init_llm, init_storage 8 | from graphgen.utils import logger 9 | 10 | from .build_mm_kg import build_mm_kg 11 | from .build_text_kg import build_text_kg 12 | 13 | 14 | class BuildKGService(BaseOperator): 15 | def __init__(self, working_dir: str = "cache", graph_backend: str = "kuzu"): 16 | super().__init__(working_dir=working_dir, op_name="build_kg_service") 17 | self.llm_client: BaseLLMWrapper = init_llm("synthesizer") 18 | self.graph_storage: BaseGraphStorage = init_storage( 19 | backend=graph_backend, 
working_dir=working_dir, namespace="graph" 20 | ) 21 | 22 | def process(self, batch: pd.DataFrame) -> pd.DataFrame: 23 | docs = batch.to_dict(orient="records") 24 | docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs] 25 | 26 | # consume the chunks and build kg 27 | self.build_kg(docs) 28 | return pd.DataFrame([{"status": "kg_building_completed"}]) 29 | 30 | def build_kg(self, chunks: List[Chunk]) -> None: 31 | """ 32 | Build knowledge graph (KG) and merge into kg_instance 33 | """ 34 | text_chunks = [chunk for chunk in chunks if chunk.type == "text"] 35 | mm_chunks = [ 36 | chunk 37 | for chunk in chunks 38 | if chunk.type in ("image", "video", "table", "formula") 39 | ] 40 | 41 | if len(text_chunks) == 0: 42 | logger.info("All text chunks are already in the storage") 43 | else: 44 | logger.info("[Text Entity and Relation Extraction] processing ...") 45 | build_text_kg( 46 | llm_client=self.llm_client, 47 | kg_instance=self.graph_storage, 48 | chunks=text_chunks, 49 | ) 50 | if len(mm_chunks) == 0: 51 | logger.info("All multi-modal chunks are already in the storage") 52 | else: 53 | logger.info("[Multi-modal Entity and Relation Extraction] processing ...") 54 | build_mm_kg( 55 | llm_client=self.llm_client, 56 | kg_instance=self.graph_storage, 57 | chunks=mm_chunks, 58 | ) 59 | 60 | self.graph_storage.index_done_callback() 61 | -------------------------------------------------------------------------------- /graphgen/bases/base_partitioner.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, List 3 | 4 | from graphgen.bases.base_storage import BaseGraphStorage 5 | from graphgen.bases.datatypes import Community 6 | 7 | 8 | class BasePartitioner(ABC): 9 | @abstractmethod 10 | def partition( 11 | self, 12 | g: BaseGraphStorage, 13 | **kwargs: Any, 14 | ) -> List[Community]: 15 | """ 16 | Graph -> Communities 17 | :param g: Graph storage instance 18 | :param kwargs: Additional parameters for partitioning 19 | :return: List of communities 20 | """ 21 | 22 | @staticmethod 23 | def community2batch( 24 | comm: Community, g: BaseGraphStorage 25 | ) -> tuple[ 26 | list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] 27 | ]: 28 | """ 29 | Convert communities to batches of nodes and edges. 30 | :param comm: Community 31 | :param g: Graph storage instance 32 | :return: List of batches, each batch is a tuple of (nodes, edges) 33 | """ 34 | nodes = comm.nodes 35 | edges = comm.edges 36 | nodes_data = [] 37 | for node in nodes: 38 | node_data = g.get_node(node) 39 | if node_data: 40 | nodes_data.append((node, node_data)) 41 | edges_data = [] 42 | for u, v in edges: 43 | edge_data = g.get_edge(u, v) 44 | if edge_data: 45 | edges_data.append((u, v, edge_data)) 46 | else: 47 | edge_data = g.get_edge(v, u) 48 | if edge_data: 49 | edges_data.append((v, u, edge_data)) 50 | return nodes_data, edges_data 51 | 52 | @staticmethod 53 | def _build_adjacency_list( 54 | nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]] 55 | ) -> tuple[dict[str, List[str]], set[tuple[str, str]]]: 56 | """ 57 | Build adjacency list and edge set from nodes and edges. 
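        Both directions of every undirected edge are recorded, e.g. for a toy
        graph (hypothetical data):
            nodes = [("A", {}), ("B", {})], edges = [("A", "B", {})]
            -> adj = {"A": ["B"], "B": ["A"]}, edge_set = {("A", "B"), ("B", "A")}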
58 | :param nodes 59 | :param edges 60 | :return: adjacency list, edge set 61 | """ 62 | adj: dict[str, List[str]] = {n[0]: [] for n in nodes} 63 | edge_set: set[tuple[str, str]] = set() 64 | for e in edges: 65 | adj[e[0]].append(e[1]) 66 | adj[e[1]].append(e[0]) 67 | edge_set.add((e[0], e[1])) 68 | edge_set.add((e[1], e[0])) 69 | return adj, edge_set 70 | -------------------------------------------------------------------------------- /graphgen/operators/generate/generate_service.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from graphgen.bases import BaseLLMWrapper, BaseOperator 4 | from graphgen.common import init_llm 5 | from graphgen.models import ( 6 | AggregatedGenerator, 7 | AtomicGenerator, 8 | CoTGenerator, 9 | MultiHopGenerator, 10 | VQAGenerator, 11 | ) 12 | from graphgen.utils import logger, run_concurrent 13 | 14 | 15 | class GenerateService(BaseOperator): 16 | """ 17 | Generate question-answer pairs based on nodes and edges. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | working_dir: str = "cache", 23 | method: str = "aggregated", 24 | data_format: str = "ChatML", 25 | ): 26 | super().__init__(working_dir=working_dir, op_name="generate_service") 27 | self.llm_client: BaseLLMWrapper = init_llm("synthesizer") 28 | 29 | self.method = method 30 | self.data_format = data_format 31 | 32 | if self.method == "atomic": 33 | self.generator = AtomicGenerator(self.llm_client) 34 | elif self.method == "aggregated": 35 | self.generator = AggregatedGenerator(self.llm_client) 36 | elif self.method == "multi_hop": 37 | self.generator = MultiHopGenerator(self.llm_client) 38 | elif self.method == "cot": 39 | self.generator = CoTGenerator(self.llm_client) 40 | elif self.method in ["vqa"]: 41 | self.generator = VQAGenerator(self.llm_client) 42 | else: 43 | raise ValueError(f"Unsupported generation mode: {method}") 44 | 45 | def process(self, batch: pd.DataFrame) -> pd.DataFrame: 46 | items = batch.to_dict(orient="records") 47 | return pd.DataFrame(self.generate(items)) 48 | 49 | def generate(self, items: list[dict]) -> list[dict]: 50 | """ 51 | Generate question-answer pairs based on nodes and edges. 52 | :param items 53 | :return: QA pairs 54 | """ 55 | logger.info("[Generation] mode: %s, batches: %d", self.method, len(items)) 56 | items = [(item["nodes"], item["edges"]) for item in items] 57 | results = run_concurrent( 58 | self.generator.generate, 59 | items, 60 | desc="[4/4]Generating QAs", 61 | unit="batch", 62 | ) 63 | 64 | results = self.generator.format_generation_results( 65 | results, output_data_format=self.data_format 66 | ) 67 | 68 | return results 69 | -------------------------------------------------------------------------------- /examples/generate/generate_aggregated_qa/aggregated_config.yaml: -------------------------------------------------------------------------------- 1 | global_params: 2 | working_dir: cache 3 | graph_backend: kuzu # graph database backend, support: kuzu, networkx 4 | kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv 5 | 6 | nodes: 7 | - id: read_files # id is unique in the pipeline, and can be referenced by other steps 8 | op_name: read 9 | type: source 10 | dependencies: [] 11 | params: 12 | input_path: 13 | - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. 
See examples/input_examples for examples 14 | 15 | - id: chunk_documents 16 | op_name: chunk 17 | type: map_batch 18 | dependencies: 19 | - read_files 20 | execution_params: 21 | replicas: 4 22 | params: 23 | chunk_size: 1024 # chunk size for text splitting 24 | chunk_overlap: 100 # chunk overlap for text splitting 25 | 26 | - id: build_kg 27 | op_name: build_kg 28 | type: map_batch 29 | dependencies: 30 | - chunk_documents 31 | execution_params: 32 | replicas: 1 33 | batch_size: 128 34 | 35 | - id: quiz 36 | op_name: quiz 37 | type: aggregate 38 | dependencies: 39 | - build_kg 40 | execution_params: 41 | replicas: 1 42 | batch_size: 128 43 | params: 44 | quiz_samples: 2 # number of quiz samples to generate 45 | concurrency_limit: 200 46 | 47 | - id: judge 48 | op_name: judge 49 | type: map_batch 50 | dependencies: 51 | - quiz 52 | execution_params: 53 | replicas: 1 54 | batch_size: 128 55 | 56 | - id: partition 57 | op_name: partition 58 | type: aggregate 59 | dependencies: 60 | - judge 61 | params: 62 | method: ece # ece is a custom partition method based on comprehension loss 63 | method_params: 64 | max_units_per_community: 20 # max nodes and edges per community 65 | min_units_per_community: 5 # min nodes and edges per community 66 | max_tokens_per_community: 10240 # max tokens per community 67 | unit_sampling: max_loss # unit sampling strategy; supported: random, max_loss, min_loss 68 | 69 | - id: generate 70 | op_name: generate 71 | type: map_batch 72 | dependencies: 73 | - partition 74 | execution_params: 75 | replicas: 1 76 | batch_size: 128 77 | params: 78 | method: aggregated # atomic, aggregated, multi_hop, cot, vqa 79 | data_format: ChatML # Alpaca, Sharegpt, ChatML 80 | -------------------------------------------------------------------------------- /examples/search/build_db/build_protein_blast_db.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Downloads the latest release of UniProt, putting it in a release-specific directory. 6 | # Creates associated BLAST databases. 7 | # We need makeblastdb on our PATH 8 | # For Ubuntu/Debian: sudo apt install ncbi-blast+ 9 | # For CentOS/RHEL/Fedora: sudo dnf install ncbi-blast+ 10 | # Or download from: https://ftp.ncbi.nlm.nih.gov/blast/executables/blast+/LATEST/ 11 | 12 | # Better to use a stable DOWNLOAD_TMP name to support resuming downloads 13 | DOWNLOAD_TMP=_downloading 14 | mkdir -p ${DOWNLOAD_TMP} 15 | cd ${DOWNLOAD_TMP} 16 | 17 | wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/RELEASE.metalink" 18 | 19 | # Extract the release name (like 2017_10 or 2017_1) 20 | # Use sed for cross-platform compatibility (works on both macOS and Linux) 21 | RELEASE=$(sed -n 's/.*<version>\([0-9]\{4\}_[0-9]\{1,2\}\)<\/version>.*/\1/p' RELEASE.metalink | head -1) 22 | 23 | wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz" 24 | wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz" 25 | wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/reldate.txt" 26 | wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/README" 27 | wget -c "ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/LICENSE" 28 | 29 | cd ..
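# Downloads complete: move everything into a release-named directory and build the databases there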
30 | 31 | mkdir ${RELEASE} 32 | mv ${DOWNLOAD_TMP}/* ${RELEASE} 33 | rmdir ${DOWNLOAD_TMP} 34 | 35 | cd ${RELEASE} 36 | 37 | gunzip uniprot_sprot.fasta.gz 38 | gunzip uniprot_trembl.fasta.gz 39 | 40 | cat uniprot_sprot.fasta uniprot_trembl.fasta >uniprot_${RELEASE}.fasta 41 | 42 | makeblastdb -in uniprot_${RELEASE}.fasta -out uniprot_${RELEASE} -dbtype prot -parse_seqids -title uniprot_${RELEASE} 43 | makeblastdb -in uniprot_sprot.fasta -out uniprot_sprot -dbtype prot -parse_seqids -title uniprot_sprot 44 | makeblastdb -in uniprot_trembl.fasta -out uniprot_trembl -dbtype prot -parse_seqids -title uniprot_trembl 45 | 46 | cd .. 47 | 48 | echo "BLAST databases created successfully!" 49 | echo "Database locations:" 50 | echo " - Combined: $(pwd)/${RELEASE}/uniprot_${RELEASE}" 51 | echo " - Swiss-Prot: $(pwd)/${RELEASE}/uniprot_sprot" 52 | echo " - TrEMBL: $(pwd)/${RELEASE}/uniprot_trembl" 53 | echo "" 54 | echo "To use these databases, set in your config:" 55 | echo " local_blast_db: $(pwd)/${RELEASE}/uniprot_sprot # or uniprot_${RELEASE} or uniprot_trembl" 56 | 57 | -------------------------------------------------------------------------------- /graphgen/models/partitioner/dfs_partitioner.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections.abc import Iterable 3 | from typing import Any 4 | 5 | from graphgen.bases import BaseGraphStorage, BasePartitioner 6 | from graphgen.bases.datatypes import Community 7 | 8 | NODE_UNIT: str = "n" 9 | EDGE_UNIT: str = "e" 10 | 11 | 12 | class DFSPartitioner(BasePartitioner): 13 | """ 14 | DFS partitioner that partitions the graph into communities of a fixed size. 15 | 1. Randomly choose a unit. 16 | 2. Random walk using DFS until the community reaches the max unit size. 17 | (In GraphGen, a unit is defined as a node or an edge.) 
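3. Yield the community and pick a fresh random seed until every unit has been used.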
18 | """ 19 | 20 | def partition( 21 | self, 22 | g: BaseGraphStorage, 23 | max_units_per_community: int = 1, 24 | **kwargs: Any, 25 | ) -> Iterable[Community]: 26 | nodes = g.get_all_nodes() 27 | edges = g.get_all_edges() 28 | 29 | adj, _ = self._build_adjacency_list(nodes, edges) 30 | 31 | used_n: set[str] = set() 32 | used_e: set[frozenset[str]] = set() 33 | 34 | units = [(NODE_UNIT, n[0]) for n in nodes] + [ 35 | (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges 36 | ] 37 | random.shuffle(units) 38 | 39 | for kind, seed in units: 40 | if (kind == NODE_UNIT and seed in used_n) or ( 41 | kind == EDGE_UNIT and seed in used_e 42 | ): 43 | continue 44 | 45 | comm_n, comm_e = [], [] 46 | stack = [(kind, seed)] 47 | cnt = 0 48 | 49 | while stack and cnt < max_units_per_community: 50 | k, it = stack.pop() 51 | if k == NODE_UNIT: 52 | if it in used_n: 53 | continue 54 | used_n.add(it) 55 | comm_n.append(it) 56 | cnt += 1 57 | for nei in adj[it]: 58 | e_key = frozenset((it, nei)) 59 | if e_key not in used_e: 60 | stack.append((EDGE_UNIT, e_key)) 61 | break 62 | else: 63 | if it in used_e: 64 | continue 65 | used_e.add(it) 66 | comm_e.append(tuple(it)) 67 | cnt += 1 68 | # push neighboring nodes 69 | for n in it: 70 | if n not in used_n: 71 | stack.append((NODE_UNIT, n)) 72 | 73 | if comm_n or comm_e: 74 | yield Community(id=seed, nodes=comm_n, edges=comm_e) 75 | -------------------------------------------------------------------------------- /webui/examples/txt_demo.txt: -------------------------------------------------------------------------------- 1 | 云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。 2 | 3 | 隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。 4 | 5 | Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity. 
6 | 7 | Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture. 8 | -------------------------------------------------------------------------------- /baselines/EntiGraph/tasks/task_abc.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from abc import abstractmethod 3 | 4 | 5 | class Question: 6 | def __init__(self, statement: Dict, answer: str, attempts: List[Dict], formatted_prompt: str = ""): 7 | self.statement = statement 8 | self.answer = answer 9 | self.attempts = attempts 10 | self.formatted_prompt = formatted_prompt 11 | 12 | @abstractmethod 13 | def prompt(self): 14 | pass 15 | 16 | @abstractmethod 17 | def iscorrect(self, attempt_index: int = 0): 18 | pass 19 | 20 | @abstractmethod 21 | def asdict(self): 22 | pass 23 | 24 | @abstractmethod 25 | def llama_parse_answer(self): 26 | pass 27 | 28 | 29 | class Document: 30 | def __init__(self, text: str, questions: List[Dict]): 31 | self.text = text 32 | self.questions = questions 33 | 34 | @property 35 | @abstractmethod 36 | def uid(self): 37 | pass 38 | 39 | @property 40 | @abstractmethod 41 | def content(self): 42 | pass 43 | 44 | @abstractmethod 45 | def question_prompts(self, add_document_context: bool, add_thought_process: bool, sep_after_question: str): 46 | pass 47 | 48 | @abstractmethod 49 | def asdict(self): 50 | pass 51 | 52 | def majority_vote(self, n_samples): 53 | for question in self.questions: 54 | question.majority_vote(n_samples) 55 | 56 | 57 | class Task: 58 | openai_system_generate_entities: str 59 | openai_system_generate_two_entity_relations: str 60 | openai_system_generate_three_entity_relations: str 61 | llama_cot_prompt: str 62 | 63 | def __init__(self, name, documents: List[Document]): 64 | self.name = name 65 | self.documents = documents 66 | 67 | @abstractmethod 68 | def load_attempts_json(self, file_path: str): 69 | pass 70 | 71 | @abstractmethod 72 | def performance_stats(self): 73 | pass 74 | 75 | def all_questions(self, add_document_context: bool, add_thought_process: bool, sep_after_question: str): 76 | prompts = [] 77 | for document in self.documents: 78 | prompts += document.question_prompts(add_document_context, add_thought_process, sep_after_question) 79 | 80 | return prompts 81 | 82 | @property 83 | def all_document_contents(self): 84 | return '\n'.join([document.content for document in self.documents]) 85 | 86 | def asdict(self): 87 | return [document.asdict() for document in self.documents] 88 | 89 | def majority_vote(self, n_samples: int = 1): 90 | for document in self.documents: 91 | document.majority_vote(n_samples) 92 | -------------------------------------------------------------------------------- /examples/input_examples/txt_demo.txt: 
-------------------------------------------------------------------------------- 1 | 云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。 2 | 3 | 隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。 4 | 5 | Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity. 6 | 7 | Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture. 8 | -------------------------------------------------------------------------------- /graphgen/bases/base_llm_wrapper.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | import re 5 | from typing import Any, List, Optional 6 | 7 | from graphgen.bases.base_tokenizer import BaseTokenizer 8 | from graphgen.bases.datatypes import Token 9 | 10 | 11 | class BaseLLMWrapper(abc.ABC): 12 | """ 13 | LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...). 
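Subclasses provide the backend by implementing generate_answer, generate_topk_per_token and generate_inputs_prob.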
14 | """ 15 | 16 | def __init__( 17 | self, 18 | *, 19 | system_prompt: str = "", 20 | temperature: float = 0.0, 21 | max_tokens: int = 4096, 22 | repetition_penalty: float = 1.05, 23 | top_p: float = 0.95, 24 | top_k: int = 50, 25 | tokenizer: Optional[BaseTokenizer] = None, 26 | **kwargs: Any, 27 | ): 28 | self.system_prompt = system_prompt 29 | self.temperature = temperature 30 | self.max_tokens = max_tokens 31 | self.repetition_penalty = repetition_penalty 32 | self.top_p = top_p 33 | self.top_k = top_k 34 | self.tokenizer = tokenizer 35 | 36 | for k, v in kwargs.items(): 37 | setattr(self, k, v) 38 | 39 | @abc.abstractmethod 40 | async def generate_answer( 41 | self, text: str, history: Optional[List[str]] = None, **extra: Any 42 | ) -> str: 43 | """Generate answer from the model.""" 44 | raise NotImplementedError 45 | 46 | @abc.abstractmethod 47 | async def generate_topk_per_token( 48 | self, text: str, history: Optional[List[str]] = None, **extra: Any 49 | ) -> List[Token]: 50 | """Generate top-k tokens for the next token prediction.""" 51 | raise NotImplementedError 52 | 53 | @abc.abstractmethod 54 | async def generate_inputs_prob( 55 | self, text: str, history: Optional[List[str]] = None, **extra: Any 56 | ) -> List[Token]: 57 | """Generate probabilities for each token in the input.""" 58 | raise NotImplementedError 59 | 60 | @staticmethod 61 | def filter_think_tags(text: str, think_tag: str = "think") -> str: 62 | """ 63 | Remove tags from the text. 64 | - If the text contains and , it removes everything between them and the tags themselves. 65 | - If the text contains only , it removes content before the tag. 66 | """ 67 | paired_pattern = re.compile(rf"<{think_tag}>.*?", re.DOTALL) 68 | filtered = paired_pattern.sub("", text) 69 | 70 | orphan_pattern = re.compile(rf"^.*?", re.DOTALL) 71 | filtered = orphan_pattern.sub("", filtered) 72 | 73 | filtered = filtered.strip() 74 | return filtered if filtered else text.strip() 75 | -------------------------------------------------------------------------------- /webui/examples/csv_demo.csv: -------------------------------------------------------------------------------- 1 | type,content 2 | text,云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。 3 | text,隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。 4 | text,"Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. 
Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity." 5 | text,"Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture." 6 | -------------------------------------------------------------------------------- /graphgen/models/storage/kv/rocksdb_storage.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, List, Set 4 | 5 | # rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it 6 | # pylint: disable=no-name-in-module 7 | from rocksdict import Rdict 8 | 9 | from graphgen.bases.base_storage import BaseKVStorage 10 | 11 | 12 | @dataclass 13 | class RocksDBKVStorage(BaseKVStorage): 14 | _db: Rdict = None 15 | _db_path: str = None 16 | 17 | def __post_init__(self): 18 | self._db_path = os.path.join(self.working_dir, f"{self.namespace}.db") 19 | self._db = Rdict(self._db_path) 20 | print( 21 | f"RocksDBKVStorage initialized for namespace '{self.namespace}' at '{self._db_path}'" 22 | ) 23 | 24 | @property 25 | def data(self): 26 | return self._db 27 | 28 | def all_keys(self) -> List[str]: 29 | return list(self._db.keys()) 30 | 31 | def index_done_callback(self): 32 | self._db.flush() 33 | print(f"RocksDB flushed for {self.namespace}") 34 | 35 | def get_by_id(self, id: str) -> Any: 36 | return self._db.get(id, None) 37 | 38 | def get_by_ids(self, ids: List[str], fields: List[str] = None) -> List[Any]: 39 | result = [] 40 | for index in ids: 41 | item = self._db.get(index, None) 42 | if item is None: 43 | result.append(None) 44 | continue 45 | 46 | if fields is None: 47 | result.append(item) 48 | else: 49 | result.append({k: v for k, v in item.items() if k in fields}) 50 | return result 51 | 52 | def get_all(self) -> Dict[str, Dict]: 53 | return dict(self._db) 54 | 55 | def filter_keys(self, data: List[str]) -> Set[str]: 56 | return {s for s in data if s not in self._db} 57 | 58 | def upsert(self, data: Dict[str, Any]): 59 | left_data = {} 60 | for k, v in data.items(): 61 | if k not in self._db: 62 | left_data[k] = v 63 | 
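# Insert-only semantics: keys already present are skipped, so existing values are never overwritten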
64 | if left_data: 65 | for k, v in left_data.items(): 66 | self._db[k] = v 67 | # if left_data is very large, it is recommended to use self._db.write_batch() for optimization 68 | 69 | return left_data 70 | 71 | def drop(self): 72 | self._db.close() 73 | Rdict.destroy(self._db_path) 74 | self._db = Rdict(self._db_path) 75 | print(f"Dropped RocksDB {self.namespace}") 76 | 77 | def close(self): 78 | if self._db: 79 | self._db.close() 80 | 81 | def reload(self): 82 | """For databases that need reloading, RocksDB auto-manages this.""" 83 | -------------------------------------------------------------------------------- /examples/input_examples/csv_demo.csv: -------------------------------------------------------------------------------- 1 | type,content 2 | text,云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。 3 | text,隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。 4 | text,"Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity." 5 | text,"Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture." 
6 | -------------------------------------------------------------------------------- /graphgen/models/partitioner/bfs_partitioner.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import deque 3 | from typing import Any, Iterable, List 4 | 5 | from graphgen.bases import BaseGraphStorage, BasePartitioner 6 | from graphgen.bases.datatypes import Community 7 | 8 | NODE_UNIT: str = "n" 9 | EDGE_UNIT: str = "e" 10 | 11 | 12 | class BFSPartitioner(BasePartitioner): 13 | """ 14 | BFS partitioner that partitions the graph into communities of a fixed size. 15 | 1. Randomly choose a unit. 16 | 2. Expand the community using BFS until the max unit size is reached. 17 | (A unit is a node or an edge.) 18 | """ 19 | 20 | def partition( 21 | self, 22 | g: BaseGraphStorage, 23 | max_units_per_community: int = 1, 24 | **kwargs: Any, 25 | ) -> Iterable[Community]: 26 | nodes = g.get_all_nodes() 27 | edges = g.get_all_edges() 28 | 29 | adj, _ = self._build_adjacency_list(nodes, edges) 30 | 31 | used_n: set[str] = set() 32 | used_e: set[frozenset[str]] = set() 33 | 34 | units = [(NODE_UNIT, n[0]) for n in nodes] + [ 35 | (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges 36 | ] 37 | random.shuffle(units) 38 | 39 | for kind, seed in units: 40 | if (kind == NODE_UNIT and seed in used_n) or ( 41 | kind == EDGE_UNIT and seed in used_e 42 | ): 43 | continue 44 | 45 | comm_n: List[str] = [] 46 | comm_e: List[tuple[str, str]] = [] 47 | queue: deque[tuple[str, Any]] = deque([(kind, seed)]) 48 | cnt = 0 49 | 50 | while queue and cnt < max_units_per_community: 51 | k, it = queue.popleft() 52 | if k == NODE_UNIT: 53 | if it in used_n: 54 | continue 55 | used_n.add(it) 56 | comm_n.append(it) 57 | cnt += 1 58 | for nei in adj[it]: 59 | e_key = frozenset((it, nei)) 60 | if e_key not in used_e: 61 | queue.append((EDGE_UNIT, e_key)) 62 | else: 63 | if it in used_e: 64 | continue 65 | used_e.add(it) 66 | 67 | u, v = it 68 | comm_e.append((u, v)) 69 | cnt += 1 70 | # push nodes that are not visited 71 | for n in it: 72 | if n not in used_n: 73 | queue.append((NODE_UNIT, n)) 74 | 75 | if comm_n or comm_e: 76 | yield Community(id=seed, nodes=comm_n, edges=comm_e) 77 | -------------------------------------------------------------------------------- /graphgen/models/evaluator/mtld_evaluator.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | from graphgen.bases.datatypes import QAPair 4 | from graphgen.models.evaluator.base_evaluator import BaseEvaluator 5 | from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language 6 | 7 | nltk_helper = NLTKHelper() 8 | 9 | 10 | class MTLDEvaluator(BaseEvaluator): 11 | """ 12 | Metric for measuring the lexical diversity of a text. 13 | """ 14 | 15 | def __init__(self, max_concurrent: int = 100): 16 | super().__init__(max_concurrent) 17 | self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english")) 18 | self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese")) 19 | 20 | async def evaluate_single(self, pair: QAPair) -> float: 21 | loop = create_event_loop() 22 | return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer) 23 | 24 | def _calculate_mtld_score(self, text: str, threshold=0.72) -> float: 25 | """ 26 | Compute MTLD (mean of the forward and backward passes) 27 | 28 | min is 1.0 29 | higher is better 30 | """ 31 | if not text or not text.strip(): 32 | return 0.0 33 | 34 | lang = detect_main_language(text) 35 | tokens = nltk_helper.word_tokenize(text, lang) 36 |
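# Drop stopwords and non-alphanumeric tokens so the type-token ratio reflects content words only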
37 | stopwords = self.stopwords_zh if lang == "zh" else self.stopwords_en 38 | filtered_tokens = [word for word in tokens if word not in stopwords] 39 | filtered_tokens = [word for word in filtered_tokens if word.isalnum()] 40 | 41 | if not filtered_tokens: 42 | return 0.0 43 | 44 | # Forward MTLD 45 | forward_factors = self._compute_factors(filtered_tokens, threshold) 46 | 47 | # Backward MTLD 48 | backward_factors = self._compute_factors(filtered_tokens[::-1], threshold) 49 | 50 | # Average of the two directions 51 | return (forward_factors + backward_factors) / 2 52 | 53 | @staticmethod 54 | def _compute_factors(tokens: list, threshold: float) -> float: 55 | factors = 0 56 | current_segment = [] 57 | unique_words = set() 58 | 59 | for token in tokens: 60 | current_segment.append(token) 61 | unique_words.add(token) 62 | ttr = len(unique_words) / len(current_segment) 63 | 64 | if ttr <= threshold: 65 | factors += 1 66 | current_segment = [] 67 | unique_words = set() 68 | 69 | # Handle the final incomplete segment 70 | if current_segment: 71 | ttr = len(unique_words) / len(current_segment) 72 | if ttr <= threshold: 73 | factors += 1 74 | else: 75 | factors += 1 - (ttr - threshold) / (1 - threshold) 76 | 77 | return len(tokens) / factors if factors > 0 else len(tokens) 78 | -------------------------------------------------------------------------------- /webui/examples/jsonl_demo.jsonl: -------------------------------------------------------------------------------- 1 | {"type": "text", "content": "云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。"} 2 | {"type": "text", "content": "隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。"} 3 | {"type": "text", "content": "Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield.
Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity."} 4 | {"type": "text", "content": "Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture."} 5 | -------------------------------------------------------------------------------- /examples/input_examples/jsonl_demo.jsonl: -------------------------------------------------------------------------------- 1 | {"type": "text", "content": "云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。"} 2 | {"type": "text", "content": "隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。"} 3 | {"type": "text", "content": "Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity."} 4 | {"type": "text", "content": "Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). 
Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture."} 5 | -------------------------------------------------------------------------------- /baselines/EntiGraph/tasks/baseline_task.py: -------------------------------------------------------------------------------- 1 | # Rewrite from https://github.com/ZitongYang/Synthetic_Continued_Pretraining/blob/main/tasks/quality.py 2 | 3 | import json 4 | from hashlib import md5 5 | 6 | from baselines.EntiGraph.tasks.task_abc import Document, Task 7 | from baselines.EntiGraph.entigraph_utils.prompt_utils import ( 8 | OPENAI_API_SYSTEM_QUALITY_GENERATE_ENTITIES, 9 | OPENAI_API_SYSTEM_QUALITY_GENERATE_TWO_ENTITY_RELATIONS, 10 | OPENAI_API_SYSTEM_QUALITY_GENERATE_THREE_ENTITY_RELATIONS, 11 | QUALITY_FEW_SHOT_COT_PROMPT, OPENAI_API_SYSTEM_QUALITY_QA_SFT) 12 | 13 | class BaselineTask(Task): 14 | openai_system_generate_entities = OPENAI_API_SYSTEM_QUALITY_GENERATE_ENTITIES 15 | openai_system_generate_two_entity_relations = OPENAI_API_SYSTEM_QUALITY_GENERATE_TWO_ENTITY_RELATIONS 16 | openai_system_generate_three_entity_relations = OPENAI_API_SYSTEM_QUALITY_GENERATE_THREE_ENTITY_RELATIONS 17 | openai_system_quality_qa_sft = OPENAI_API_SYSTEM_QUALITY_QA_SFT 18 | llama_cot_prompt = QUALITY_FEW_SHOT_COT_PROMPT 19 | 20 | def __init__(self, input_file: str, data_type: str): 21 | self._data = self._load_split(input_file, data_type) 22 | self._create_documents() 23 | self._dedup() 24 | 25 | @staticmethod 26 | def _load_split(input_file: str, data_type: str): 27 | if data_type == 'raw': 28 | with open(input_file, "r", encoding='utf-8') as f: 29 | data = [json.loads(line) for line in f] 30 | data = [[chunk] for chunk in data] 31 | elif data_type == 'chunked': 32 | with open(input_file, "r", encoding='utf-8') as f: 33 | data = json.load(f) 34 | 35 | documents = [] 36 | for doc in data: 37 | for chunk in doc: 38 | documents.append(chunk) 39 | return documents 40 | 41 | def _create_documents(self): 42 | documents = [] 43 | for adict in self._data: 44 | document = Document(text=adict['content'], questions=[]) 45 | documents.append(document) 46 | super().__init__('baseline', documents) 47 | 48 | def _dedup(self): 49 | deduped_documents = {} 50 | for document in self.documents: 51 | key = compute_content_hash(document.text) 52 | if key not in deduped_documents: 53 | deduped_documents[key] = document 54 | 55 | self.documents = list(deduped_documents.values()) 56 | 57 | 58 | def performance_stats(self): 59 | pass 60 | 61 | def load_attempts_json(self, file_path: str): 62 | pass 63 | 64 | 65 | def compute_content_hash(content, prefix: str = ""): 66 | return prefix + md5(content.encode()).hexdigest() 67 | -------------------------------------------------------------------------------- /webui/examples/json_demo.json: -------------------------------------------------------------------------------- 1 | [ 2 | {"type": "text", "content":
"云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。"}, 3 | {"type": "text", "content": "隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。"}, 4 | {"type": "text", "content": "Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity."}, 5 | {"type": "text", "content": "Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. 
Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture."} 6 | ] 7 | -------------------------------------------------------------------------------- /webui/examples/vqa_demo.json: -------------------------------------------------------------------------------- 1 | [ 2 | {"type": "text", "content": "云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。"}, 3 | {"type": "text", "content": "隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。"}, 4 | {"type": "text", "content": "Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity."}, 5 | {"type": "text", "content": "Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. 
Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture."} 6 | ] 7 | -------------------------------------------------------------------------------- /examples/input_examples/json_demo.json: -------------------------------------------------------------------------------- 1 | [ 2 | {"type": "text", "content": "云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。"}, 3 | {"type": "text", "content": "隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。"}, 4 | {"type": "text", "content": "Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity."}, 5 | {"type": "text", "content": "Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. 
Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture."} 6 | ] 7 | -------------------------------------------------------------------------------- /tests/integration_tests/models/partitioner/test_bfs_partitioner.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import pytest 4 | 5 | from graphgen.bases.datatypes import Community 6 | from graphgen.models import BFSPartitioner, NetworkXStorage 7 | 8 | 9 | @pytest.mark.asyncio 10 | async def test_empty_graph(): 11 | with tempfile.TemporaryDirectory() as tmpdir: 12 | storage = NetworkXStorage(working_dir=tmpdir, namespace="empty") 13 | partitioner = BFSPartitioner() 14 | communities = await partitioner.partition(storage, max_units_per_community=5) 15 | assert communities == [] 16 | 17 | 18 | @pytest.mark.asyncio 19 | async def test_single_node(): 20 | nodes = [("A", {"desc": "alone"})] 21 | edges = [] 22 | with tempfile.TemporaryDirectory() as tmpdir: 23 | storage = NetworkXStorage(working_dir=tmpdir, namespace="single_node") 24 | 25 | for nid, ndata in nodes: 26 | await storage.upsert_node(nid, ndata) 27 | for src, tgt, edata in edges: 28 | await storage.upsert_edge(src, tgt, edata) 29 | 30 | partitioner = BFSPartitioner() 31 | communities: list[Community] = await partitioner.partition( 32 | storage, max_units_per_community=5 33 | ) 34 | assert len(communities) == 1 35 | assert communities[0].nodes == ["A"] 36 | assert communities[0].edges == [] 37 | 38 | 39 | @pytest.mark.asyncio 40 | async def test_small_graph(): 41 | """ 42 | 0 - 1 - 2 43 | | | | 44 | 3 - 4 - 5 45 | 6 nodes & 7 edges, max_units=4 => at least 3 communities 46 | """ 47 | nodes = [(str(i), {"desc": f"node{i}"}) for i in range(6)] 48 | edges = [ 49 | ("0", "1", {"desc": "e01"}), 50 | ("1", "2", {"desc": "e12"}), 51 | ("0", "3", {"desc": "e03"}), 52 | ("1", "4", {"desc": "e14"}), 53 | ("2", "5", {"desc": "e25"}), 54 | ("3", "4", {"desc": "e34"}), 55 | ("4", "5", {"desc": "e45"}), 56 | ] 57 | 58 | with tempfile.TemporaryDirectory() as tmpdir: 59 | storage = NetworkXStorage(working_dir=tmpdir, namespace="small_graph") 60 | 61 | for nid, ndata in nodes: 62 | await storage.upsert_node(nid, ndata) 63 | for src, tgt, edata in edges: 64 | await storage.upsert_edge(src, tgt, edata) 65 | 66 | partitioner = BFSPartitioner() 67 | communities: list[Community] = await partitioner.partition( 68 | storage, max_units_per_community=4 69 | ) 70 | 71 | assert len(communities) <= 5 72 | 73 | all_nodes = set() 74 | all_edges = set() 75 | for c in communities: 76 | assert len(c.nodes) + len(c.edges) <= 4 77 | all_nodes.update(c.nodes) 78 | all_edges.update(c.edges) 79 | 80 | assert all_nodes == {str(i) for i in range(6)} 81 | assert len(all_edges) == 7 82 | -------------------------------------------------------------------------------- /graphgen/models/generator/quiz_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from graphgen.bases import BaseGenerator 4 | from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT 5 | from graphgen.utils import detect_main_language, logger 6 | 7 | 8 | class QuizGenerator(BaseGenerator): 9 | """ 10 | Quiz Generator rephrases given descriptions to create quiz questions. 
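A TEMPLATE rephrasing preserves the original meaning, while an ANTI_TEMPLATE rephrasing inverts it.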
11 | """ 12 | 13 | @staticmethod 14 | def build_prompt( 15 | batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] 16 | ) -> str: 17 | """ 18 | Build prompt for rephrasing the description. 19 | :param batch: A tuple containing (nodes, edges) where nodes/edges 20 | contain description information 21 | :return: Prompt string 22 | """ 23 | # Extract description from batch 24 | # For quiz generator, we expect a special format where 25 | # the description is passed as the first node's description 26 | nodes, edges = batch 27 | if nodes: 28 | description = nodes[0][1].get("description", "") 29 | template_type = nodes[0][1].get("template_type", "TEMPLATE") 30 | elif edges: 31 | description = edges[0][2].get("description", "") 32 | template_type = edges[0][2].get("template_type", "TEMPLATE") 33 | else: 34 | raise ValueError("Batch must contain at least one node or edge with description") 35 | 36 | return QuizGenerator.build_prompt_for_description(description, template_type) 37 | 38 | @staticmethod 39 | def build_prompt_for_description(description: str, template_type: str = "TEMPLATE") -> str: 40 | """ 41 | Build prompt for rephrasing a single description. 42 | :param description: The description to rephrase 43 | :param template_type: Either "TEMPLATE" (same meaning) or "ANTI_TEMPLATE" (opposite meaning) 44 | :return: Prompt string 45 | """ 46 | language = detect_main_language(description) 47 | prompt = DESCRIPTION_REPHRASING_PROMPT[language][template_type].format( 48 | input_sentence=description 49 | ) 50 | return prompt 51 | 52 | @staticmethod 53 | def parse_rephrased_text(response: str) -> str: 54 | """ 55 | Parse the rephrased text from the response. 56 | :param response: 57 | :return: 58 | """ 59 | rephrased_text = response.strip().strip('"') 60 | logger.debug("Rephrased Text: %s", rephrased_text) 61 | return rephrased_text 62 | 63 | @staticmethod 64 | def parse_response(response: str) -> Any: 65 | """ 66 | Parse the LLM response. For quiz generator, this returns the rephrased text. 67 | :param response: LLM response 68 | :return: Rephrased text 69 | """ 70 | return QuizGenerator.parse_rephrased_text(response) 71 | -------------------------------------------------------------------------------- /.github/workflows/sync-demo.yml: -------------------------------------------------------------------------------- 1 | name: Sync Demo Branch 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | 9 | jobs: 10 | sync-demo: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout main branch 15 | uses: actions/checkout@v4 16 | with: 17 | ref: main 18 | token: ${{ secrets.GITHUB_TOKEN }} 19 | 20 | - name: Create demo branch if it doesn't exist 21 | run: | 22 | if ! git ls-remote --heads origin demo | grep -q demo; then 23 | echo "Creating demo branch..." 24 | git checkout -b demo 25 | git push origin demo 26 | else 27 | echo "Demo branch already exists" 28 | fi 29 | 30 | - name: Checkout demo branch 31 | uses: actions/checkout@v4 32 | with: 33 | ref: demo 34 | token: ${{ secrets.GITHUB_TOKEN }} 35 | path: demo 36 | 37 | - name: Clean demo directory 38 | run: | 39 | cd demo 40 | find . 
-mindepth 1 -path './.git' -prune -o -exec rm -rf {} + 2>/dev/null || true 41 | 42 | - name: Copy files using config 43 | run: | 44 | yq eval '.sync[] | .source + ":" + .dest' .github/sync-config.yml | while IFS=: read -r src dst; do 45 | src=$(echo "$src" | xargs) 46 | dst=$(echo "$dst" | xargs) 47 | 48 | [ -z "$src" ] && continue 49 | 50 | if [ -e "$src" ]; then 51 | target_path="demo/$dst" 52 | 53 | # Handle directories (dest ends with / or source is a directory) 54 | if [[ "$dst" == */ ]] || [ -d "$src" ]; then 55 | mkdir -p "$target_path" 56 | # Copy the directory's contents, not the directory itself 57 | cp -r "$src"/* "$target_path" 58 | echo "Copied $src/* → $target_path" 59 | else 60 | mkdir -p "$(dirname "$target_path")" 61 | cp "$src" "$target_path" 62 | echo "Copied $src → $target_path" 63 | fi 64 | else 65 | echo "Source not found: $src" 66 | fi 67 | done 68 | 69 | - name: Commit and push changes 70 | run: | 71 | cd demo 72 | git config --global user.email "actions@github.com" 73 | git config --global user.name "github-actions[bot]" 74 | 75 | # Check whether anything changed 76 | if [[ -n $(git status --porcelain) ]]; then 77 | git add . 78 | git commit -m "Auto-sync demo branch with main branch" 79 | git push origin demo 80 | echo "Changes pushed to demo branch" 81 | else 82 | echo "No changes to sync" 83 | fi 84 | 85 | push-hf: 86 | needs: sync-demo 87 | uses: ./.github/workflows/push-to-hf.yml 88 | secrets: 89 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 90 | push-ms: 91 | needs: sync-demo 92 | uses: ./.github/workflows/push-to-ms.yml 93 | secrets: 94 | MS_TOKEN: ${{ secrets.MS_TOKEN }} 95 | with: 96 | ref: demo 97 | --------------------------------------------------------------------------------
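Usage sketch (illustrative only, not a file in the repository): the listings above show BasePartitioner.community2batch and the BFS/DFS partitioners but no end-to-end wiring, so the following minimal Python sketch shows how a Community is turned into a (nodes, edges) batch. StubGraph and its sample payloads are invented purely for self-containment; the real BaseGraphStorage backends (NetworkX, Kuzu) expose a richer interface, and Community is assumed to be constructible with id/nodes/edges as in the partitioners above.

# Illustrative usage sketch -- assumes graphgen is importable.
from graphgen.bases import BasePartitioner
from graphgen.bases.datatypes import Community


class StubGraph:
    """Hypothetical stand-in exposing only the two lookups community2batch performs."""

    def __init__(self, nodes: dict, edges: dict):
        self._nodes = nodes  # {node_id: node_data}
        self._edges = edges  # {(u, v): edge_data}

    def get_node(self, node_id):
        return self._nodes.get(node_id)

    def get_edge(self, u, v):
        return self._edges.get((u, v))


g = StubGraph(
    nodes={"BG1": {"description": "auxin-induced rice gene"},
           "grain size": {"description": "yield-determining trait"}},
    edges={("BG1", "grain size"): {"description": "BG1 positively regulates grain size"}},
)
comm = Community(id="BG1", nodes=["BG1", "grain size"], edges=[("BG1", "grain size")])

# community2batch resolves each node id and edge pair against the storage,
# retrying the reversed direction for edges stored the other way around.
nodes_data, edges_data = BasePartitioner.community2batch(comm, g)
# nodes_data -> [("BG1", {...}), ("grain size", {...})]
# edges_data -> [("BG1", "grain size", {...})]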
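Likewise, a quick sanity check of BaseLLMWrapper.filter_think_tags (a staticmethod, so it can be exercised without instantiating a backend); the input strings are invented for illustration.

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper

# Paired tags: the whole <think>...</think> span is stripped.
assert BaseLLMWrapper.filter_think_tags(
    "<think>internal chain of thought</think>The answer is 42."
) == "The answer is 42."

# Orphan closing tag (e.g. a truncated stream): everything up to </think> is stripped.
assert BaseLLMWrapper.filter_think_tags(
    "truncated reasoning</think>The answer is 42."
) == "The answer is 42."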