├── .python-version
├── src
│   └── swe_care
│       ├── __init__.py
│       ├── collect
│       │   ├── __init__.py
│       │   ├── get_top_repos.py
│       │   └── __main__.py
│       ├── harness
│       │   ├── __init__.py
│       │   ├── evaluators
│       │   │   ├── utils.py
│       │   │   ├── __init__.py
│       │   │   └── repo_level.py
│       │   ├── __main__.py
│       │   └── code_review_eval.py
│       ├── schema
│       │   ├── __init__.py
│       │   ├── inference.py
│       │   ├── evaluation.py
│       │   ├── collect.py
│       │   └── dataset.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── github_graphql
│       │   │   ├── GetLabels.graphql
│       │   │   ├── GetIssueLabels.graphql
│       │   │   ├── GetThreadComments.graphql
│       │   │   ├── GetIssueComments.graphql
│       │   │   ├── GetClosingIssues.graphql
│       │   │   ├── GetCommits.graphql
│       │   │   ├── GetReviewThreads.graphql
│       │   │   ├── GetReviewComments.graphql
│       │   │   ├── GetReviews.graphql
│       │   │   ├── __init__.py
│       │   │   ├── GetSpecificPullRequest.graphql
│       │   │   └── GetMergedPullRequests.graphql
│       │   ├── prompt_loader.py
│       │   ├── read_file.py
│       │   ├── patch.py
│       │   ├── file_source_retrieval.py
│       │   ├── load.py
│       │   ├── llm_models
│       │   │   ├── __init__.py
│       │   │   └── clients.py
│       │   ├── estimate.py
│       │   └── github.py
│       ├── inference
│       │   ├── __init__.py
│       │   └── __main__.py
│       └── templates
│           ├── rm_sample.yaml
│           ├── questionnaire
│           │   ├── llm_respondent.yaml
│           │   └── questionnaire.md.j2
│           ├── classify_review_effort.yaml
│           ├── classify_problem_domain.yaml
│           ├── repo_level_evaluation.yaml
│           ├── estimate_difficulty.yaml
│           ├── classify_relevant_review_comment.yaml
│           ├── code_review_text_prompt.yaml
│           └── code_review_llm_evaluation.yaml
├── .pre-commit-config.yaml
├── LEGAL.md
├── AGENTS.md
├── pyproject.toml
├── .gitignore
├── CLAUDE.md
├── scripts
│   └── README.md
├── LICENSE
└── docs
    └── questionnaire-demo.md
/.python-version:
--------------------------------------------------------------------------------
1 | 3.10
2 |
--------------------------------------------------------------------------------
/src/swe_care/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/swe_care/collect/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/swe_care/harness/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/swe_care/schema/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/swe_care/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/swe_care/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | # Ruff version.
4 | rev: v0.12.0
5 | hooks:
6 | # Run the linter.
7 | - id: ruff-check
8 | args: [--fix]
9 | # Run the formatter.
10 | - id: ruff-format
11 |
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetLabels.graphql:
--------------------------------------------------------------------------------
1 | query GetLabels($prId: ID!, $cursor: String) {
2 | node(id: $prId) {
3 | ... on PullRequest {
4 | labels(first: 100, after: $cursor) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | name
12 | }
13 | }
14 | }
15 | }
16 | }
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetIssueLabels.graphql:
--------------------------------------------------------------------------------
1 | query GetIssueLabels($issueId: ID!, $cursor: String) {
2 | node(id: $issueId) {
3 | ... on Issue {
4 | labels(first: 100, after: $cursor) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | name
12 | }
13 | }
14 | }
15 | }
16 | }
--------------------------------------------------------------------------------
/LEGAL.md:
--------------------------------------------------------------------------------
1 | Legal Disclaimer
2 |
3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
4 |
5 | 法律免责声明
6 |
7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetThreadComments.graphql:
--------------------------------------------------------------------------------
1 | query GetThreadComments($threadId: ID!, $cursor: String) {
2 | node(id: $threadId) {
3 | ... on PullRequestReviewThread {
4 | comments(first: 100, after: $cursor) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | id
12 | }
13 | }
14 | }
15 | }
16 | }
--------------------------------------------------------------------------------
/src/swe_care/schema/inference.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from dataclasses_json import dataclass_json
4 |
5 | from swe_care.schema.dataset import CodeReviewTaskInstance
6 |
7 |
8 | @dataclass_json
9 | @dataclass
10 | class CodeReviewInferenceInstance(CodeReviewTaskInstance):
11 | """Schema for code review inference instances."""
12 |
13 | text: str
14 | """The input text including instructions, the "Oracle"/RAG retrieved file, and an example of the patch format for output."""
15 |
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetIssueComments.graphql:
--------------------------------------------------------------------------------
1 | query GetIssueComments($issueId: ID!, $cursor: String) {
2 | node(id: $issueId) {
3 | ... on Issue {
4 | comments(first: 100, after: $cursor) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | id
12 | author {
13 | login
14 | }
15 | body
16 | createdAt
17 | updatedAt
18 | }
19 | }
20 | }
21 | }
22 | }
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetClosingIssues.graphql:
--------------------------------------------------------------------------------
1 | query GetClosingIssues($prId: ID!, $cursor: String) {
2 | node(id: $prId) {
3 | ... on PullRequest {
4 | closingIssuesReferences(first: 100, after: $cursor) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | id
12 | number
13 | url
14 | title
15 | body
16 | state
17 | labels(first: 10) {
18 | totalCount
19 | pageInfo {
20 | hasNextPage
21 | endCursor
22 | }
23 | nodes {
24 | name
25 | }
26 | }
27 | }
28 | }
29 | }
30 | }
31 | }
--------------------------------------------------------------------------------
/src/swe_care/templates/rm_sample.yaml:
--------------------------------------------------------------------------------
1 | text: |-
2 | {% if relevant_files -%}
3 |
4 | {% for file_path, file_content in relevant_files.items() -%}
5 | [start of {{ file_path }}]
6 | {% if add_line_numbers -%}
7 | {% if file_content -%}
8 | {% for line in file_content.split('\n') -%}
9 | {{ "{:4d}".format(loop.index) }} {{ line }}
10 | {% endfor -%}
11 | {% endif -%}
12 | {% else -%}
13 | {{ file_content }}
14 | {% endif -%}
15 | [end of {{ file_path }}]
16 | {% endfor -%}
17 |
18 | {% endif -%}
19 | {% if diff_hunk -%}
20 |
21 | {{ diff_hunk }}
22 |
23 | {% endif -%}
24 | {{ path }}
25 | {{ line }}
26 |
27 | {{ review_comment }}
28 |
29 |
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetCommits.graphql:
--------------------------------------------------------------------------------
1 | query GetCommits($prId: ID!, $cursor: String) {
2 | node(id: $prId) {
3 | ... on PullRequest {
4 | commits(first: 100, after: $cursor) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | commit {
12 | oid
13 | message
14 | changedFilesIfAvailable
15 | authoredDate
16 | author {
17 | user {
18 | login
19 | }
20 | }
21 | committedDate
22 | committer {
23 | user {
24 | login
25 | }
26 | }
27 | parents(first: 2) {
28 | nodes {
29 | oid
30 | }
31 | }
32 | }
33 | }
34 | }
35 | }
36 | }
37 | }
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetReviewThreads.graphql:
--------------------------------------------------------------------------------
1 | query GetReviewThreads($prId: ID!, $cursor: String) {
2 | node(id: $prId) {
3 | ... on PullRequest {
4 | reviewThreads(first: 100, after: $cursor) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | id
12 | isResolved
13 | isOutdated
14 | isCollapsed
15 | path
16 | startLine
17 | originalLine
18 | diffSide
19 | startDiffSide
20 | resolvedBy {
21 | login
22 | }
23 | comments(first: 10) {
24 | totalCount
25 | pageInfo {
26 | hasNextPage
27 | endCursor
28 | }
29 | nodes {
30 | id
31 | }
32 | }
33 | }
34 | }
35 | }
36 | }
37 | }
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetReviewComments.graphql:
--------------------------------------------------------------------------------
1 | query GetReviewComments($reviewId: ID!, $cursor: String) {
2 | node(id: $reviewId) {
3 | ... on PullRequestReview {
4 | comments(first: 100, after: $cursor) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | id
12 | author {
13 | login
14 | }
15 | body
16 | createdAt
17 | updatedAt
18 | path
19 | diffHunk
20 | line
21 | startLine
22 | originalLine
23 | originalStartLine
24 | replyTo {
25 | id
26 | }
27 | isMinimized,
28 | minimizedReason,
29 | commit {
30 | oid
31 | }
32 | originalCommit {
33 | oid
34 | }
35 | }
36 | }
37 | }
38 | }
39 | }
--------------------------------------------------------------------------------
/src/swe_care/schema/evaluation.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any
3 |
4 | from dataclasses_json import dataclass_json
5 |
6 |
7 | @dataclass_json
8 | @dataclass
9 | class CodeReviewPrediction:
10 | """Schema for code review prediction instances."""
11 |
12 | instance_id: str
13 | """The instance ID of the code review task"""
14 | review_text: str
15 | """The prediction of the code review"""
16 | review_trajectory: list[str] | None = None
17 | """The trajectory of the code review, including the intermediate steps"""
18 |
19 |
20 | @dataclass_json
21 | @dataclass
22 | class EvaluatorResult:
23 | """Schema for evaluator result instances."""
24 |
25 | evaluator: str
26 | """The type of the evaluator"""
27 | evaluation: dict[str, Any]
28 | """The evaluation of the code review"""
29 |
30 |
31 | @dataclass_json
32 | @dataclass
33 | class CodeReviewEvaluationResult:
34 | """Schema for code review evaluation instances."""
35 |
36 | instance_id: str
37 | """The instance ID of the code review task"""
38 | score: float
39 | """The score of the code review"""
40 | evaluations: list[EvaluatorResult]
41 | """The evaluation of the code review"""
42 |
--------------------------------------------------------------------------------
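
A small usage sketch of the schema above, showing how the `dataclasses_json` decorator serializes a result; the field values are made up for illustration only.

```python
from swe_care.schema.evaluation import CodeReviewEvaluationResult, EvaluatorResult

# Hypothetical evaluation produced by one evaluator; values are illustrative placeholders.
result = CodeReviewEvaluationResult(
    instance_id="octocat__hello-world-42",
    score=0.75,
    evaluations=[
        EvaluatorResult(
            evaluator="llm_evaluation",
            evaluation={"function": {"correctness": 0.8, "relevance": 0.7}},
        )
    ],
)

# dataclass_json adds to_json()/from_json() to the decorated classes.
serialized = result.to_json()
restored = CodeReviewEvaluationResult.from_json(serialized)
assert restored.instance_id == result.instance_id
```
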
/src/swe_care/templates/questionnaire/llm_respondent.yaml:
--------------------------------------------------------------------------------
1 | system: |-
2 | You are an expert code reviewer with deep expertise in software engineering, code quality assessment, and security best practices. You will evaluate a code review task by completing a structured questionnaire.
3 |
4 | Your role is to:
5 | 1. Thoroughly analyze the provided code review context
6 | 2. Make objective, evidence-based assessments
7 | 3. Complete the questionnaire with precision and detail
8 | 4. Focus on technical accuracy rather than speculation
9 |
10 | Key principles:
11 | - Be thorough but concise in your explanations
12 | - Reference specific files, functions, and line numbers when discussing code
13 | - Consider both the immediate changes and their broader impact
14 | - Evaluate security, performance, maintainability, and correctness aspects
15 | - Be critical but fair in your assessments
16 | - For questions with multiple choice options, select the most appropriate answer
17 | - For text fields, provide clear and concise explanations with specific references to the code when relevant
18 |
19 | user: |-
20 | Here is the context for the code review task:
21 |
22 | {{ context }}
23 |
24 | Now, please complete {{ section_title }} of the questionnaire.
25 |
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetReviews.graphql:
--------------------------------------------------------------------------------
1 | query GetReviews($prId: ID!, $cursor: String) {
2 | node(id: $prId) {
3 | ... on PullRequest {
4 | reviews(first: 100, after: $cursor) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | id
12 | author {
13 | login
14 | }
15 | state
16 | body
17 | submittedAt
18 | updatedAt
19 | commit {
20 | oid
21 | }
22 | comments(first: 100) {
23 | totalCount
24 | pageInfo {
25 | hasNextPage
26 | endCursor
27 | }
28 | nodes {
29 | id
30 | author {
31 | login
32 | }
33 | body
34 | createdAt
35 | updatedAt
36 | path
37 | diffHunk
38 | line
39 | startLine
40 | originalLine
41 | originalStartLine
42 | replyTo {
43 | id
44 | }
45 | isMinimized,
46 | minimizedReason,
47 | commit {
48 | oid
49 | }
50 | originalCommit {
51 | oid
52 | }
53 | }
54 | }
55 | }
56 | }
57 | }
58 | }
59 | }
--------------------------------------------------------------------------------
/src/swe_care/templates/classify_review_effort.yaml:
--------------------------------------------------------------------------------
1 | system: |-
2 | You are an experienced software engineer responsible for estimating code review effort. Your task is to estimate how much effort would be required to review a code change on a scale of 1 to 5, where:
3 |
4 | 1 = Very Low Effort: Simple changes like typo fixes, minor documentation updates, or trivial formatting changes
5 | 2 = Low Effort: Small bug fixes, minor feature additions, or straightforward code changes affecting a few lines
6 | 3 = Medium Effort: Moderate complexity changes involving multiple files, standard feature implementations, or routine refactoring
7 | 4 = High Effort: Complex changes affecting multiple components, significant new features, or architectural modifications requiring careful review
8 | 5 = Very High Effort: Major architectural changes, complex algorithms, security-critical modifications, or changes requiring domain expertise
9 |
10 | Consider factors like:
11 | - Size and scope of the change
12 | - Complexity of the code modifications
13 | - Number of files affected
14 | - Potential impact on system behavior
15 | - Risk level of the changes
16 |
17 | Please respond with ONLY a single number from 1 to 5.
18 |
19 | user: |-
20 | Please estimate the review effort (1-5) for the following code change:
21 |
22 | **Pull Request Title:**
23 | {{ title }}
24 |
25 | **Pull Request Description:**
26 | {{ body }}
27 |
28 | **Commit Message:**
29 | {{ commit_message }}
30 |
31 | **Code Changes (Patch):**
32 | {{ patch }}
--------------------------------------------------------------------------------
/src/swe_care/templates/classify_problem_domain.yaml:
--------------------------------------------------------------------------------
1 | system: |-
2 | You are an expert software engineer responsible for classifying software development tasks. Your task is to analyze a problem statement and classify it into exactly one of the following categories:
3 |
4 | 1. Bug Fixes: Resolving functional errors, crashes, incorrect outputs
5 | 2. New Feature Additions: Adding new functionality or features to the application
6 | 3. Code Refactoring / Architectural Improvement: Improving code structure, readability, maintainability without changing external behavior
7 | 4. Documentation Updates: Changes related to code comments or external documentation
8 | 5. Test Suite / CI Enhancements: Improving test coverage, test quality, or continuous integration processes
9 | 6. Performance Optimizations: Improving application speed, response time, or resource usage efficiency
10 | 7. Security Patches / Vulnerability Fixes: Fixing code defects that could lead to security issues
11 | 8. Dependency Updates & Env Compatibility: Updating third-party library dependencies or ensuring compatibility across different environments
12 | 9. Code Style, Linting, Formatting Fixes: Ensuring code complies with team coding standards and consistency
13 |
14 | Please respond with ONLY the category name exactly as listed above (e.g., "Bug Fixes", "New Feature Additions", etc.).
15 |
16 | user: |-
17 | Please classify the following problem statement into one of the predefined categories:
18 |
19 | Problem Statement:
20 | {{ problem_statement }}
21 |
22 | Category:
--------------------------------------------------------------------------------
/src/swe_care/templates/repo_level_evaluation.yaml:
--------------------------------------------------------------------------------
1 | system: |-
2 | [ROLE] You are a senior code reviewer evaluating comment quality.
3 |
4 | [OBJECTIVE] Classify comments as high-quality (1) or low-quality (0) based on impact.
5 |
6 | [CONTEXT]
7 | - Problem statement, diff, and file context are provided.
8 | - Focus on whether the comment references changed code and suggests actionable fixes.
9 | - Think like a developer: would you feel compelled to modify your code in response? If yes, it's likely positive.
10 |
11 | [CRITERIA for label=1]
12 | ✓ Identifies concrete issue (bug, design, etc.) in the diff
13 | ✓ Provides specific, actionable fix suggestion
14 | ✓ References line numbers or code snippets
15 | ✓ Helps solve the PR's goal
16 |
17 | [CRITERIA for label=0]
18 | ✗ Generic ("LGTM"), vague, or social
19 | ✗ Off-topic or nitpicking without impact
20 | ✗ Incorrect or unhelpful
21 | ✗ Asks question without identifying flaw
22 |
23 | [OUTPUT]
24 | Return ONLY valid JSON:
25 | ```json
26 | {
27 | "label": <1 for positive/useful, 0 for negative/not-useful>,
28 | "reason": ""
29 | }
30 | ```
31 |
32 | user: |-
33 | Please analyze the following code review comment and classify it as positive (useful/high-quality) or negative (not useful/low-quality).
34 |
35 |
36 | {{ problem_statement }}
37 |
38 |
39 |
40 | {{ patch_to_review }}
41 |
42 |
43 |
44 | {{ formatted_review }}
45 |
46 |
--------------------------------------------------------------------------------
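
Since the system prompt above asks for JSON wrapped in a fenced block, downstream code has to strip the fence before parsing. The following is a minimal post-processing sketch under that assumption; the project's actual evaluator (`harness/evaluators/repo_level.py`) is not shown here, so this is not its implementation.

```python
import json
import re


def parse_label_response(raw_response: str) -> dict:
    """Extract the JSON object from an LLM reply that may wrap it in ```json fences."""
    match = re.search(r"\{.*\}", raw_response, re.DOTALL)
    if match is None:
        raise ValueError("No JSON object found in model response")
    return json.loads(match.group(0))


reply = '```json\n{"label": 1, "reason": "Points at a concrete off-by-one in the diff."}\n```'
print(parse_label_response(reply))  # {'label': 1, 'reason': '...'}
```
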
/src/swe_care/utils/github_graphql/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | # --- GraphQL Queries Directory ---
4 | GRAPHQL_QUERIES_DIR = Path(__file__).parent
5 |
6 | # --- GraphQL Queries Mapping ---
7 | GRAPHQL_QUERIES = {
8 | # Main query for fetching PRs with first page of nested data
9 | "merged_pull_requests": (
10 | GRAPHQL_QUERIES_DIR / "GetMergedPullRequests.graphql"
11 | ).read_text(),
12 | # Query for fetching additional pages of labels
13 | "labels": (GRAPHQL_QUERIES_DIR / "GetLabels.graphql").read_text(),
14 | # Query for fetching additional pages of commits
15 | "commits": (GRAPHQL_QUERIES_DIR / "GetCommits.graphql").read_text(),
16 | # Query for fetching additional pages of reviews
17 | "reviews": (GRAPHQL_QUERIES_DIR / "GetReviews.graphql").read_text(),
18 | # Query for fetching additional pages of review comments
19 | "review_comments": (GRAPHQL_QUERIES_DIR / "GetReviewComments.graphql").read_text(),
20 | # Query for fetching additional pages of closing issues
21 | "closing_issues": (GRAPHQL_QUERIES_DIR / "GetClosingIssues.graphql").read_text(),
22 | # Query for fetching additional pages of issue labels
23 | "issue_labels": (GRAPHQL_QUERIES_DIR / "GetIssueLabels.graphql").read_text(),
24 | # Query for fetching additional pages of issue comments
25 | "issue_comments": (GRAPHQL_QUERIES_DIR / "GetIssueComments.graphql").read_text(),
26 | # Query for fetching additional pages of review threads
27 | "review_threads": (GRAPHQL_QUERIES_DIR / "GetReviewThreads.graphql").read_text(),
28 | # Query for fetching additional pages of thread comments
29 | "thread_comments": (GRAPHQL_QUERIES_DIR / "GetThreadComments.graphql").read_text(),
30 | # Query for fetching additional pages of specific PR
31 | "specific_pr": (GRAPHQL_QUERIES_DIR / "GetSpecificPullRequest.graphql").read_text(),
32 | }
33 |
--------------------------------------------------------------------------------
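
To illustrate how these paginated queries are typically consumed, here is a minimal sketch that posts the `labels` query to the GitHub GraphQL endpoint with `requests` and follows `pageInfo.endCursor` until `hasNextPage` is false. It deliberately bypasses the project's own `GitHubAPI` helper (`utils/github.py` is not shown in this dump), so the endpoint URL, token handling, and error handling are assumptions.

```python
import os

import requests

from swe_care.utils.github_graphql import GRAPHQL_QUERIES

GITHUB_GRAPHQL_URL = "https://api.github.com/graphql"  # assumed endpoint


def fetch_all_pr_labels(pr_node_id: str, token: str) -> list[str]:
    """Collect every label name on a PR by paging through GetLabels.graphql."""
    labels: list[str] = []
    cursor = None
    while True:
        response = requests.post(
            GITHUB_GRAPHQL_URL,
            headers={"Authorization": f"Bearer {token}"},
            json={
                "query": GRAPHQL_QUERIES["labels"],
                "variables": {"prId": pr_node_id, "cursor": cursor},
            },
            timeout=30,
        )
        response.raise_for_status()
        connection = response.json()["data"]["node"]["labels"]
        labels.extend(node["name"] for node in connection["nodes"])
        if not connection["pageInfo"]["hasNextPage"]:
            return labels
        cursor = connection["pageInfo"]["endCursor"]


if __name__ == "__main__":
    # "PR_..." is a placeholder GraphQL node ID.
    print(fetch_all_pr_labels("PR_placeholder_node_id", os.environ["GITHUB_TOKEN"]))
```
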
/src/swe_care/templates/estimate_difficulty.yaml:
--------------------------------------------------------------------------------
1 | system: |-
2 | You are an experienced software engineer responsible for estimating implementation difficulty. Your task is to estimate how difficult it would be to implement a pull request from scratch on a scale of 1 to 3, where:
3 |
4 | 1 = Low Difficulty: Simple implementations like typo fixes, minor configuration changes, straightforward bug fixes with clear solutions, basic feature additions using existing patterns, or routine maintenance tasks.
5 |
6 | 2 = Medium Difficulty: Moderate implementations requiring some problem-solving, such as bug fixes requiring investigation, feature additions involving multiple components, refactoring existing code, integration with existing APIs, or changes requiring understanding of business logic.
7 |
8 | 3 = High Difficulty: Complex implementations requiring significant technical expertise, such as architectural changes, performance optimizations, complex algorithms, new integrations with external systems, security-related fixes, or features requiring deep domain knowledge.
9 |
10 | Consider factors like:
11 | - Technical complexity of the problem being solved
12 | - Amount of new code vs. modifications to existing code
13 | - Number of systems/components involved
14 | - Domain knowledge required
15 | - Algorithm complexity
16 | - Architectural impact
17 | - Dependencies and integrations needed
18 | - Research or investigation required
19 |
20 | Please respond with ONLY a single number from 1 to 3.
21 |
22 | user: |-
23 | Please estimate the implementation difficulty level (1, 2, 3) for the following pull request:
24 |
25 | **Pull Request Title:**
26 | {{ title }}
27 |
28 | **Pull Request Description:**
29 | {{ body }}
30 |
31 | **Commit Message:**
32 | {{ commit_message }}
33 |
34 | **Changed Files:**
35 | {{ changed_files }}
36 |
37 | **Code Changes (Patch):**
38 | {{ patch }}
39 |
40 | Implementation Difficulty (1, 2, 3):
--------------------------------------------------------------------------------
/src/swe_care/templates/classify_relevant_review_comment.yaml:
--------------------------------------------------------------------------------
1 | system: |-
2 | You are an expert code reviewer who determines whether review comments are likely to result in code changes. Your task is to classify each review comment as either "relevant" (likely to lead to code changes) or "irrelevant" (unlikely to lead to code changes).
3 |
4 | A comment is RELEVANT if it:
5 | - Requests specific code changes (e.g., "make this method static", "change variable name to be more descriptive")
6 | - Points out bugs, errors, or potential issues that need fixing
7 | - Suggests improvements to logic, algorithms, or data structures
8 | - Identifies missing functionality or edge cases that need to be handled
9 | - Requests changes to method signatures, parameters, or return types
10 | - Points out code that violates best practices or design patterns
11 |
12 | A comment is IRRELEVANT if it:
13 | - Is a simple approval or praise (e.g., "looks good", "LGTM", "nice", "thanks")
14 | - Only requests formatting changes with no impact on code logic (e.g., "fix indentation", "add spaces")
15 | - Requests adding tests (these don't change the code under review)
16 | - Asks for clarification or explanation without suggesting changes (e.g., "please explain", "what does this do?")
17 | - References previous comments that cannot be identified (e.g., "same as before", "see above")
18 | - Only requests adding comments or documentation (e.g., "add Javadoc", "document this method")
19 | - Is too vague to result in specific code changes
20 |
21 | Important: Focus on whether the comment would likely result in changes to the actual code logic, not just cosmetic changes or additions to test files.
22 |
23 | Respond with ONLY "relevant" or "irrelevant".
24 |
25 | user: |-
26 | Please classify the following code review comment:
27 |
28 | **Comment:**
29 | {{ comment_text }}
30 | {% if diff_hunk %}
31 |
32 | **Context (diff hunk):**
33 | ```diff
34 | {{ diff_hunk }}
35 | ```
36 | {% endif %}
37 | {% if path %}
38 | **File path:** {{ path }}
39 | {% if line %}**Line:** {{ line }}{% endif %}
40 | {% endif %}
41 |
42 | **Classification:**
43 |
--------------------------------------------------------------------------------
/src/swe_care/templates/code_review_text_prompt.yaml:
--------------------------------------------------------------------------------
1 | text: |-
2 | {% if files -%}
3 | You will be provided with 1) a partial code base, 2) an issue statement explaining a problem to resolve, and 3) a patch addressing the issue.
4 | {% else -%}
5 | You will be provided with 1) an issue statement explaining a problem to resolve, and 2) a patch addressing the issue.
6 | {% endif -%}
7 |
8 | {{ problem_statement }}
9 |
10 | {% if files -%}
11 |
12 | {% for file_path, file_content in files.items() -%}
13 | [start of {{ file_path }}]
14 | {% if add_line_numbers -%}
15 | {% if file_content -%}
16 | {% for line in file_content.split('\n') -%}
17 | {{ "{:4d}".format(loop.index) }} {{ line }}
18 | {% endfor -%}
19 | {% endif -%}
20 | {% else -%}
21 | {{ file_content }}
22 | {% endif -%}
23 | [end of {{ file_path }}]
24 | {% endfor -%}
25 |
26 | {% endif -%}
27 |
28 | {{ patch }}
29 |
30 | I need you to comprehensively review the above patch with respect to the given issue from multiple dimensions such as functional implementation, code quality, and defects. Generate a report in the following format.
31 |
32 | ## Function
33 | Briefly describe the main purpose and implemented functionality of this patch. Specifically, does this patch address the issue?
34 |
35 | ## Complexity
36 | Is the patch more complex than it should be? Check it at every level - lines, functions, classes.
37 |
38 | ## Style
39 | Does the patch follow the programming conventions of the original code? Does it pick good names for newly introduced variables, functions, classes, and files?
40 |
41 | ## Documentation
42 | Does the patch provide clear and necessary comments? If the patch changes how users build, test, interact with, or release code, does it also update associated documentation?
43 |
44 | ## Defects
45 | If there are any defects in the patch, point out their locations (file path and line number) and provide improvement suggestions in the following format:
46 | <defect>
47 | file_path: file path of defect 1 in the patch
48 | line: line number of defect 1 in the patch
49 | suggestion: suggestion for defect 1
50 | </defect>
51 | <defect>
52 | file_path: file path of defect 2 in the patch
53 | line: line number of defect 2 in the patch
54 | suggestion: suggestion for defect 2
55 | </defect>
56 | ...
57 | Note:
58 | - If there is no defect, output "None" in this section.
59 |
60 |
--------------------------------------------------------------------------------
/src/swe_care/templates/questionnaire/questionnaire.md.j2:
--------------------------------------------------------------------------------
1 | {#-
2 | This template assumes a Jinja2 rendering context variable named `instance`
3 | whose schema conforms to `CodeReviewTaskInstance` defined in `src/swe_care/schema/dataset.py`.
4 | -#}
5 |
6 | ### Code Review Annotation Questionnaire
7 |
8 | **Instance ID**: `{{ instance.instance_id }}`
9 | **Repository**: `{{ instance.repo }}`
10 | **Pull Request**: [`#{{ instance.pull_number }}`](https://github.com/{{ instance.repo }}/pull/{{ instance.pull_number }})
11 | **Language**: `{{ instance.language }}`
12 | **Created at**: `{{ instance.created_at }}`
13 | **Base commit**: `{{ instance.base_commit }}`
14 | **Commit to review**: `{{ instance.commit_to_review.head_commit }}`
15 | **Commit message**:
16 | > {{ instance.commit_to_review.head_commit_message | default('', true) | trim | replace('\n', '\n> ') }}
17 |
18 | ---
19 |
20 | ### Context
21 |
22 | #### Problem Statement
23 |
24 | > {{ instance.problem_statement | default('', true) | trim | replace('\n', '\n> ') }}
25 |
26 | #### Hints (pre-PR comments or linked issues context)
27 |
28 | > {{ instance.hints_text | default('', true) | trim | replace('\n', '\n> ') }}
29 |
30 | #### Resolved Issues
31 |
32 | {% if instance.resolved_issues %}
33 | {% for ri in instance.resolved_issues %}
34 |
35 | ##### Issue #{{ ri.number }}
36 |
37 | - Title:
38 |
39 | > {{ ri.title | default('', true) | trim | replace('\n', '\n> ') }}
40 |
41 | - Body:
42 |
43 | > {{ ri.body | default('', true) | trim | replace('\n', '\n> ') }}
44 | {% endfor %}
45 | {% else %}
46 |
47 | - None listed
48 | {% endif %}
49 |
50 | #### Patch to Review (diff)
51 |
52 | ```diff
53 | {{ instance.commit_to_review.patch_to_review }}
54 | ```
55 |
56 | #### Reference Review Comments (from the PR)
57 |
58 | Count: {{ instance.reference_review_comments | length }}
59 |
60 | {% for c in instance.reference_review_comments %}
61 | ---
62 |
63 | ##### Comment {{ loop.index }}
64 |
65 | - Review comment:
66 |
67 | > {{ c.text | default('', true) | trim | replace('\n', '\n> ') }}
68 |
69 | - Diff hunk the review comment applied to:
70 |
71 | ```diff
72 | {{ c.diff_hunk | default('', true) }}
73 | ```
74 |
75 | - Path: `{{ c.path }}`
76 | - Line: {{ c.line if c.line is not none else 'n/a' }} | Start line: {{ c.start_line if c.start_line is not none else 'n/a' }}
77 | - Original line: {{ c.original_line if c.original_line is not none else 'n/a' }} | Original start line: {{ c.original_start_line if c.original_start_line is not none else 'n/a' }}
78 | {% endfor %}
79 | {% if schema_md %}
80 | ---
81 |
82 | {{ schema_md }}
83 | {%- endif -%}
84 |
--------------------------------------------------------------------------------
/AGENTS.md:
--------------------------------------------------------------------------------
1 | # Repository Guidelines
2 |
3 | ## Project Structure & Module Organization
4 | - `src/swe_care/`: source package
5 | - `collect/` (GitHub data + dataset build CLIs)
6 | - `inference/` (text generation + API inference)
7 | - `harness/` (evaluation runners and evaluators)
8 | - `schema/`, `utils/`, `templates/` (data models, helpers, prompts)
9 | - `scripts/`: orchestration utilities (e.g., `run_eval_pipeline.py`)
10 | - `docs/`: documentation and demos
11 | - `results/`: local outputs created by commands (not tracked)
12 |
13 | ## Build, Test, and Development Commands
14 | - Install deps: `pip install uv && uv sync` (or `pip install -e .`)
15 | - Pre-commit hooks: `pre-commit install` • Run all: `pre-commit run --all-files`
16 | - Lint/format: `ruff check .` • `ruff format .`
17 | - Quick pipeline:
18 | `python scripts/run_eval_pipeline.py --dataset-file results/dataset/code_review_task_instances.jsonl --output-dir results/pipeline_output --model gpt-4o --model-provider openai --file-source oracle`
19 | - Module CLIs: `python -m swe_care.collect ...` | `python -m swe_care.inference ...` | `python -m swe_care.harness ...`
20 |
21 | ## Coding Style & Naming Conventions
22 | - Python ≥ 3.10 with type hints; prefer dataclasses for schemas.
23 | - Ruff (Black-like): 4-space indent, line length 88, double quotes (`pyproject.toml`).
24 | - snake_case for modules/functions/options; PascalCase for classes.
25 | - Keep functions cohesive; shared helpers in `utils/`; prompts in `templates/`.
26 | - Output naming examples: `___graphql_prs_data.jsonl`, `___rm_samples.jsonl`.
27 |
28 | ## Testing Guidelines
29 | - No traditional unit tests; validate via small, reproducible runs that write to `results/`.
30 | - Verify interfaces with help: `python -m swe_care.collect -h`.
31 | - Aim for determinism (fixed params, stable sorting). Include sample outputs in PRs when logic changes.
32 |
33 | ## Commit & Pull Request Guidelines
34 | - Conventional Commits: `feat:`, `fix:`, `docs:`, `refactor:`, `chore:`, `test:`; add scope (e.g., `inference:`).
35 | - PRs include: purpose, summary, example commands, sample outputs (paths under `results/`), and linked issues (e.g., `Closes #123`).
36 | - Keep PRs focused; document new flags/env vars in README and `--help`. Ensure pre-commit passes.
37 |
38 | ## Security & Configuration Tips
39 | - Never commit secrets. Use env vars: `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`; optional `OPENAI_BASE_URL`, `ANTHROPIC_BASE_URL`.
40 | - GitHub access via `--tokens`; watch rate limits.
41 | - Retrieval with Pyserini may require Java 21; set `--retrieval-output-dir` for temporary work.
42 |
--------------------------------------------------------------------------------
/src/swe_care/utils/prompt_loader.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for loading and rendering YAML-based prompt templates.
3 | """
4 |
5 | from pathlib import Path
6 | from typing import Any
7 |
8 | import yaml
9 | from jinja2 import Environment, FileSystemLoader
10 |
11 | # Initialize Jinja2 environment
12 | _template_dir = Path(__file__).parent.parent / "templates"
13 | _jinja_env = Environment(loader=FileSystemLoader(_template_dir))
14 |
15 |
16 | def load_prompt(template_name: str, **context: Any) -> str | tuple[str, str]:
17 | """
18 | Load a YAML-based prompt template and render it with the provided context.
19 |
20 | Args:
21 | template_name: Name of the YAML template file (without extension) in the templates directory
22 | **context: Context variables to pass to the Jinja2 template
23 |
24 | Returns:
25 | - str: Rendered text if the template has only a 'text' key
26 | - tuple[str, str]: (system_prompt, user_prompt) if the template has 'system' and 'user' keys
27 | - str: Rendered user_prompt if the template has only a 'user' key
28 |
29 | Raises:
30 | FileNotFoundError: If the template file doesn't exist
31 | ValueError: If the template structure is invalid
32 | """
33 | template_path = _template_dir / f"{template_name}.yaml"
34 |
35 | if not template_path.exists():
36 | raise FileNotFoundError(f"Template file not found: {template_path}")
37 |
38 | # Load YAML content
39 | with open(template_path) as f:
40 | template_data = yaml.safe_load(f)
41 |
42 | if not isinstance(template_data, dict):
43 | raise ValueError(f"Invalid template structure in {template_name}.yaml")
44 |
45 | # Handle different template structures
46 | if "text" in template_data:
47 | # Simple text template
48 | template = _jinja_env.from_string(template_data["text"])
49 | return template.render(**context)
50 |
51 | elif "user" in template_data:
52 | # User prompt with optional system prompt
53 | user_template = _jinja_env.from_string(template_data["user"])
54 | user_prompt = user_template.render(**context)
55 |
56 | if "system" in template_data:
57 | # Both system and user prompts
58 | system_template = _jinja_env.from_string(template_data["system"])
59 | system_prompt = system_template.render(**context)
60 | return (system_prompt, user_prompt)
61 | else:
62 | # Only user prompt
63 | return user_prompt
64 |
65 | else:
66 | raise ValueError(
67 | f"Template {template_name}.yaml must contain either 'text' or 'user' key"
68 | )
69 |
--------------------------------------------------------------------------------
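
As a usage sketch of `load_prompt` (the template name and context key come from `templates/classify_problem_domain.yaml` above; the model call itself is out of scope here):

```python
from swe_care.utils.prompt_loader import load_prompt

# classify_problem_domain.yaml defines both "system" and "user" keys,
# so load_prompt returns a (system_prompt, user_prompt) tuple.
system_prompt, user_prompt = load_prompt(
    "classify_problem_domain",
    problem_statement="Fix crash when the config file is missing.",
)

print(system_prompt.splitlines()[0])
print(user_prompt)
```
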
/src/swe_care/utils/read_file.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Tuple
3 |
4 | import chardet
5 |
6 | default_encoding = "utf-8"
7 | common_encodings = (
8 | "utf-8",
9 | "utf-16",
10 | "latin-1",
11 | "ascii",
12 | "windows-1252",
13 | "cp1251",
14 | "cp1253",
15 | "cp1254",
16 | "cp1255",
17 | "cp1256",
18 | "shift_jis",
19 | "big5",
20 | "gb2312",
21 | )
22 |
23 |
24 | def detect_file_encoding(file_path: Path) -> str:
25 | """Function to detect encoding"""
26 | # Read the file as binary data
27 | raw_data = file_path.read_bytes()
28 | # Detect encoding
29 | detected = chardet.detect(raw_data)
30 | encoding = detected["encoding"]
31 | return encoding
32 |
33 |
34 | def read_file_with_encodings(file_path: Path, encodings: Tuple[str, ...]) -> Tuple[str, str]:
35 | """Attempt to read a file using various encodings, return content if successful"""
36 | for encoding in encodings:
37 | try:
38 | content = file_path.read_text(encoding=encoding)
39 | return content, encoding
40 | except (UnicodeDecodeError, TypeError, ValueError, UnicodeError):
41 | continue
42 | raise ValueError(
43 | f"Could not read file with any of the provided encodings: {encodings}"
44 | )
45 |
46 |
47 | def read_file_with_limit(content: str, max_lines_to_read: int | None = None) -> str:
48 | """Helper function to return a limited number of lines from the content."""
49 | if max_lines_to_read is None:
50 | return content
51 | else:
52 | return "\n".join(content.splitlines()[:max_lines_to_read])
53 |
54 |
55 | def read_file_to_string(file_path: Path | str, *, max_lines_to_read: int | None = None) -> str:
56 | """Function to detect encoding and read file to string with an optional line limit."""
57 | if isinstance(file_path, str):
58 | file_path = Path(file_path)
59 |
60 | try:
61 | content, _ = read_file_with_encodings(file_path, (default_encoding,))
62 | return read_file_with_limit(content, max_lines_to_read)
63 | except ValueError:
64 | pass
65 |
66 | try:
67 | detected_encoding = detect_file_encoding(file_path)
68 | # Read the file with the detected encoding
69 | content, _ = read_file_with_encodings(file_path, (detected_encoding,))
70 | return read_file_with_limit(content, max_lines_to_read)
71 | except ValueError:
72 | pass
73 |
74 | try:
75 | content, _ = read_file_with_encodings(file_path, common_encodings)
76 | return read_file_with_limit(content, max_lines_to_read)
77 | except ValueError:
78 | pass
79 |
80 | raise ValueError(f"Could not read file: {file_path}")
81 |
--------------------------------------------------------------------------------
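
A brief usage sketch of the helper above; the path and line limit are placeholders.

```python
from pathlib import Path

from swe_care.utils.read_file import read_file_to_string

# Tries UTF-8 first, then a detected encoding, then a list of common fallbacks;
# raises ValueError if every attempt fails.
content = read_file_to_string(Path("src/swe_care/__init__.py"), max_lines_to_read=50)
print(len(content.splitlines()))
```
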
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "swe-care"
3 | version = "0.2.0"
4 | description = "Add your description here"
5 | readme = "README.md"
6 | authors = []
7 | requires-python = ">=3.10"
8 | dependencies = [
9 | "anthropic>=0.55.0",
10 | "dataclasses-json>=0.6.7",
11 | "jinja2>=3.1.0",
12 | "loguru>=0.7.2",
13 | "openai>=1.90.0",
14 | "requests[socks]>=2.31.0",
15 | "tenacity>=9.0.0",
16 | "tiktoken>=0.9.0",
17 | "tqdm>=4.66.0",
18 | "unidiff>=0.7.5",
19 | "nltk>=3.9.1",
20 | "rank_bm25>=0.2.2",
21 | "gitpython>=3.1.45",
22 | "jedi>=0.19.2",
23 | "pyserini>=1.2.0",
24 | "httpx[socks]>=0.28.1",
25 | "chardet>=5.2.0",
26 | "jinja2>=3.1.6",
27 | "tree-sitter>=0.25.1",
28 | "tree-sitter-python>=0.25.0",
29 | "datasets>=4.2.0",
30 | ]
31 |
32 | [project.scripts]
33 | swe-care = "swe_care.collect:main"
34 |
35 | [build-system]
36 | requires = ["hatchling"]
37 | build-backend = "hatchling.build"
38 |
39 | [dependency-groups]
40 | dev = [
41 | "pre-commit",
42 | "ruff",
43 | ]
44 |
45 | [tool.ruff]
46 | # Exclude a variety of commonly ignored directories.
47 | exclude = [
48 | ".bzr",
49 | ".direnv",
50 | ".eggs",
51 | ".git",
52 | ".git-rewrite",
53 | ".hg",
54 | ".ipynb_checkpoints",
55 | ".mypy_cache",
56 | ".nox",
57 | ".pants.d",
58 | ".pyenv",
59 | ".pytest_cache",
60 | ".pytype",
61 | ".ruff_cache",
62 | ".svn",
63 | ".tox",
64 | ".venv",
65 | ".vscode",
66 | "__pypackages__",
67 | "_build",
68 | "buck-out",
69 | "build",
70 | "dist",
71 | "node_modules",
72 | "site-packages",
73 | "venv",
74 | ]
75 |
76 | # Same as Black.
77 | line-length = 88
78 | indent-width = 4
79 |
80 | # Assume Python 3.10
81 | target-version = "py310"
82 |
83 | [tool.ruff.lint]
84 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
85 | # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
86 | # McCabe complexity (`C901`) by default.
87 | select = ["E4", "E7", "E9", "F"]
88 | ignore = []
89 |
90 | # Allow fix for all enabled rules (when `--fix`) is provided.
91 | fixable = ["ALL"]
92 | unfixable = []
93 |
94 | # Allow unused variables when underscore-prefixed.
95 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
96 |
97 | [tool.ruff.format]
98 | # Like Black, use double quotes for strings.
99 | quote-style = "double"
100 |
101 | # Like Black, indent with spaces, rather than tabs.
102 | indent-style = "space"
103 |
104 | # Like Black, respect magic trailing commas.
105 | skip-magic-trailing-comma = false
106 |
107 | # Like Black, automatically detect the appropriate line ending.
108 | line-ending = "auto"
109 |
110 | [[tool.uv.index]]
111 | name = "default"
112 | url = "https://pypi.tuna.tsinghua.edu.cn/simple"
113 |
--------------------------------------------------------------------------------
/src/swe_care/templates/code_review_llm_evaluation.yaml:
--------------------------------------------------------------------------------
1 | system: |-
2 | You are a code review evaluator. Your task is to evaluate the quality of the code review. You need to evaluate the code reviews based on the quality attributes of a standard code review, which are shown below:
3 |
4 | - Functionality: An evaluation of whether the main purpose of the patch, its functionality, and any potential functional or security defects have been described.
5 | - Quality: An evaluation of the accuracy of code quality descriptions, including patch complexity (line-level, function-level, class-level, file-level), code readability, optimization status, and maintainability, etc.
6 | - Style: An evaluation of whether the patch follows the programming conventions of the original code, e.g. the naming of variables and functions.
7 | - Documentation: An evaluation of whether the patch provides clear and necessary comments, as well as documentation.
8 |
9 | For each field, you should analyze:
10 |
11 | - Correctness: whether the review is technically correct and contains no factual errors with regard to the provided issue, code base, and patch.
12 | - Relevance: whether the review is targeted at the issue and the code patch.
13 | - Clarity: whether the review is clear and without redundant information.
14 | - Consistency: whether the review is logically consistent with the issue, code base, patch, and other fields in the review.
15 | - Language: whether the review uses professional language and contains no grammatical errors. Whether it facilitates knowledge transfer, is expressed in a kind way, and provides positive feedback.
16 |
17 | Give a score between 0 and 1 (inclusive) to each of these five dimensions, and output your final evaluation in nested json format:
18 | ```
19 | {
20 | "function": {"correctness": score, "relevance": score, "clarity": score, "consistency": score, "language": score},
21 | "quality": {"correctness": score, "relevance": score, "clarity": score, "consistency": score, "language": score},
22 | "style": {"correctness": score, "relevance": score, "clarity": score, "consistency": score, "language": score},
23 | "documentation": {"correctness": score, "relevance": score, "clarity": score, "consistency": score, "language": score}
24 | }
25 | ```
26 |
27 | If you cannot identify a certain field from the review, give a 0 score to all dimensions in this field.
28 |
29 | user: |-
30 | Please evaluate the following code review based on the provided context and review text.
31 |
32 |
33 | Title: {{ input.title }}
34 | Repository: {{ input.repo }}
35 |
36 |
37 |
38 | {{ input.problem_statement }}
39 |
40 |
41 |
42 | Commit: {{ input.commit_to_review.head_commit }}
43 | Commit Message: {{ input.commit_to_review.head_commit_message }}
44 |
45 |
46 |
47 | {{ input.commit_to_review.patch_to_review }}
48 |
49 |
50 |
51 | {{ prediction.review_text }}
52 |
53 |
--------------------------------------------------------------------------------
/src/swe_care/collect/get_top_repos.py:
--------------------------------------------------------------------------------
1 | """
2 | Fetch top repositories for a given language and save to JSONL file.
3 | """
4 |
5 | import json
6 | from pathlib import Path
7 | from typing import Optional
8 |
9 | from loguru import logger
10 |
11 | from swe_care.utils.github import GitHubAPI
12 |
13 |
14 | def get_top_repos(
15 | language: str,
16 | top_n: int,
17 | output_dir: Path | str,
18 | tokens: Optional[list[str]] = None,
19 | ) -> None:
20 | """
21 | Fetch top repositories for a given language and save to JSONL file.
22 |
23 | Args:
24 | language: Programming language to search for
25 | top_n: Number of top repositories to fetch
26 | output_dir: Directory to save the output file
27 | tokens: Optional list of GitHub tokens for API requests
28 | """
29 | if isinstance(output_dir, str):
30 | output_dir = Path(output_dir)
31 | output_dir.mkdir(parents=True, exist_ok=True)
32 | output_file = output_dir / f"repos_top_{top_n}_{language}.jsonl"
33 |
34 | # Create GitHub API instance
35 | github_api = GitHubAPI(tokens=tokens)
36 |
37 | repos_collected = 0
38 | page = 1
39 |
40 | with open(output_file, "w") as f:
41 | while repos_collected < top_n:
42 | remaining_results = top_n - repos_collected
43 | per_page = min(100, remaining_results) # GitHub API max is 100 per page
44 |
45 | params = {
46 | "q": f"language:{language}",
47 | "sort": "stars",
48 | "order": "desc",
49 | "per_page": per_page,
50 | "page": page,
51 | }
52 |
53 | logger.info(f"Fetching page {page} with {per_page} results...")
54 |
55 | try:
56 | response = github_api.call_api("search/repositories", params=params)
57 | data = response.json()
58 | items = data.get("items", [])
59 |
60 | if not items:
61 | logger.info(
62 | f"No more repositories found. Collected {repos_collected} repositories."
63 | )
64 | break
65 |
66 | for repo in items:
67 | if repos_collected >= top_n:
68 | break
69 |
70 | repo_data = {
71 | "name": repo["full_name"],
72 | "stars": repo["stargazers_count"],
73 | "url": repo["html_url"],
74 | "description": repo.get("description", ""),
75 | "owner": repo["owner"]["login"],
76 | "language": repo.get("language", ""),
77 | }
78 |
79 | f.write(json.dumps(repo_data) + "\n")
80 | repos_collected += 1
81 |
82 | logger.info(f"Collected {repos_collected}/{top_n} repositories")
83 |
84 | page += 1
85 |
86 | except Exception as e:
87 | logger.error(f"Error fetching repositories: {e}")
88 | break
89 |
90 | logger.success(
91 | f"Successfully saved {repos_collected} repositories to {output_file}"
92 | )
93 |
--------------------------------------------------------------------------------
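
A minimal invocation sketch of the function above; the token handling and output directory are illustrative.

```python
import os

from swe_care.collect.get_top_repos import get_top_repos

# Writes results/repos/repos_top_50_python.jsonl, one repository per line.
get_top_repos(
    language="python",
    top_n=50,
    output_dir="results/repos",
    tokens=[os.environ["GITHUB_TOKEN"]] if "GITHUB_TOKEN" in os.environ else None,
)
```
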
/src/swe_care/harness/evaluators/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for evaluators.
3 | """
4 |
5 | import re
6 | from typing import Optional
7 |
8 | import unidiff
9 |
10 | from swe_care.schema.dataset import (
11 | CodeReviewTaskInstance,
12 | ReferenceReviewComment,
13 | )
14 |
15 |
16 | def _parse_diff_hunks(diff_text: str) -> list[dict]:
17 | """Parse diff hunks from a patch text."""
18 | patch = unidiff.PatchSet(diff_text)
19 | results = []
20 |
21 | for patched_file in patch:
22 | old_path = patched_file.path
23 | for hunk in patched_file:
24 | hunk_info = {
25 | "file": old_path,
26 | "old_start": hunk.source_start,
27 | "old_lines": hunk.source_length,
28 | "old_end": hunk.source_start + hunk.source_length - 1,
29 | "new_start": hunk.target_start,
30 | "new_lines": hunk.target_length,
31 | "new_end": hunk.target_start + hunk.target_length - 1,
32 | "hunk": str(hunk),
33 | }
34 | results.append(hunk_info)
35 | return results
36 |
37 |
38 | def extract_defects_from_review(
39 | review_text: str, input: Optional[CodeReviewTaskInstance] = None
40 | ) -> list[ReferenceReviewComment]:
41 | """Extract defects from review text that follows the REVIEW_PROMPT format.
42 |
43 | Args:
44 |         review_text: The code review text containing defects in <defect> tags
45 | input: Optional input task instance for extracting diff hunks
46 |
47 | Returns:
48 | List of ReferenceReviewComment objects extracted from the review text
49 | """
50 | defects = []
51 |
52 |     # Pattern to match <defect> blocks
53 |     defect_pattern = r"<defect>\s*(.*?)\s*</defect>"
54 | defect_matches = re.findall(defect_pattern, review_text, re.DOTALL)
55 |
56 | for defect_content in defect_matches:
57 | # Extract file_path, line, and suggestion from defect content
58 | file_path_match = re.search(r"file_path:\s*(.+)", defect_content)
59 | line_match = re.search(r"line:\s*(\d+)", defect_content)
60 | suggestion_match = re.search(r"suggestion:\s*(.+)", defect_content, re.DOTALL)
61 |
62 | if file_path_match and suggestion_match:
63 | file_path = file_path_match.group(1).strip()
64 | line_num = int(line_match.group(1)) if line_match else None
65 | suggestion = suggestion_match.group(1).strip()
66 |
67 | # Extract diff hunk based on patch_to_review, file path line number
68 | # if the line number falls within a hunk, use that hunk.
69 | patch = (
70 | input.commit_to_review.patch_to_review
71 | if input and input.commit_to_review
72 | else None
73 | )
74 | diff_hunks = _parse_diff_hunks(patch) if patch else []
75 | diff_hunk = None
76 |
77 | for hunk in diff_hunks:
78 | if (
79 | hunk["file"] == file_path
80 | and line_num is not None
81 | and hunk["new_start"] <= line_num <= hunk["new_end"]
82 | ):
83 | diff_hunk = hunk["hunk"]
84 | break
85 |
86 | defect = ReferenceReviewComment(
87 | text=suggestion,
88 | path=file_path,
89 | diff_hunk=diff_hunk,
90 | line=line_num,
91 | start_line=None,
92 | original_line=None,
93 | original_start_line=None,
94 | )
95 | defects.append(defect)
96 |
97 | return defects
98 |
--------------------------------------------------------------------------------
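
A sketch of how the parser above behaves on a review written in the format described by `code_review_text_prompt.yaml`; the `<defect>` wrapper matches the `defect_pattern` used in the function, and since no task instance is passed, `diff_hunk` stays `None`.

```python
from swe_care.harness.evaluators.utils import extract_defects_from_review

# Illustrative review text; the file path, line, and suggestion are made up.
review_text = """\
## Defects
<defect>
file_path: src/swe_care/utils/read_file.py
line: 30
suggestion: Handle the case where chardet returns None for the encoding.
</defect>
"""

defects = extract_defects_from_review(review_text)
for defect in defects:
    print(defect.path, defect.line, defect.text)
```
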
/src/swe_care/schema/collect.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Literal
3 |
4 | from dataclasses_json import dataclass_json
5 |
6 | from swe_care.schema.dataset import ReferenceReviewComment
7 |
8 |
9 | @dataclass_json
10 | @dataclass
11 | class ReviewCommentLabels:
12 | """Labels for a review comment."""
13 |
14 | referenced_line_changed_in_merged_commit: bool
15 | """Whether the referenced line was changed in the merged commit. If True, the review comment was more likely to address real issues that got fixed."""
16 | is_resolved: bool
17 | """Whether the review thread was resolved"""
18 | is_outdated: bool
19 | """Whether the review thread is outdated"""
20 | is_collapsed: bool
21 | """Whether the review thread is collapsed"""
22 | marked_as_dismissed: bool
23 | """Whether the review comment was marked as dismissed (minimized for reasons other than being resolved)"""
24 |
25 |
26 | @dataclass_json
27 | @dataclass
28 | class LabeledReviewComment(ReferenceReviewComment):
29 | """Schema for labeled review comment instances."""
30 |
31 | labels: ReviewCommentLabels
32 | """Labels for the review comment"""
33 |
34 |
35 | @dataclass_json
36 | @dataclass
37 | class CommitClassificationResult:
38 | """Schema combining commit evaluation and labeled review comments."""
39 |
40 | commit_sha: str
41 | """The commit SHA"""
42 | labeled_review_comments: list[LabeledReviewComment]
43 | """List of labeled review comments for this commit"""
44 | total_score: float
45 | """Total evaluation score for the commit"""
46 | rule_results: dict[str, bool | float]
47 | """Results from evaluation rules"""
48 | patch: str
49 | """Patch content between base commit and this commit"""
50 |
51 |
52 | @dataclass_json
53 | @dataclass
54 | class PRClassification:
55 | """Schema for PR with combined commit classification data."""
56 |
57 | repo_owner: str
58 | """Repository owner"""
59 | repo_name: str
60 | """Repository name"""
61 | pr_number: int
62 | """Pull request number"""
63 | url: str
64 | """Pull request URL"""
65 | commits: list[CommitClassificationResult]
66 | """List of commits with classification data (evaluation + labeled review comments)"""
67 |
68 |
69 | @dataclass_json
70 | @dataclass
71 | class RewardModelTrainingSampleMetadata:
72 | """Metadata for reward model training sample instances."""
73 |
74 | repo: str
75 | """Repository in format 'owner/name'"""
76 | pr_number: int
77 | """Pull request number"""
78 | url: str
79 | """Pull request URL"""
80 | commit_to_review: str
81 | """Commit SHA being reviewed"""
82 | file_source: Literal[
83 | "none",
84 | "base_changed_files",
85 | "reviewed_file",
86 | "retrieved_base_changed_files",
87 | "retrieved_all_files",
88 | ]
89 | """Source for file content ('none', 'base_changed_files', 'reviewed_file', 'retrieved_base_changed_files', or 'retrieved_all_files')"""
90 |
91 |
92 | @dataclass_json
93 | @dataclass
94 | class RewardModelTrainingSample:
95 | """Schema for reward model training sample instances."""
96 |
97 | problem_statement: str
98 | """The problem statement extracted from closing issues or PR description"""
99 | patch_to_review: str
100 | """The patch content to be reviewed"""
101 | pos_review: list[str]
102 | """List of positive review comments (referenced_line_changed_in_merged_commit=True and is_resolved=True)"""
103 | neg_review: list[str]
104 | """List of negative review comments (all others)"""
105 | metadata: RewardModelTrainingSampleMetadata
106 | """Metadata about the sample"""
107 |
--------------------------------------------------------------------------------
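
To make the nesting of these dataclasses concrete, here is a small construction sketch; every field value below is an illustrative placeholder.

```python
from swe_care.schema.collect import (
    RewardModelTrainingSample,
    RewardModelTrainingSampleMetadata,
)

sample = RewardModelTrainingSample(
    problem_statement="Crash when parsing empty config files.",
    patch_to_review=(
        "--- a/config.py\n+++ b/config.py\n@@ -1 +1 @@\n-pass\n+raise ValueError\n"
    ),
    pos_review=["Guard against an empty file before parsing."],
    neg_review=["LGTM"],
    metadata=RewardModelTrainingSampleMetadata(
        repo="octocat/hello-world",
        pr_number=42,
        url="https://github.com/octocat/hello-world/pull/42",
        commit_to_review="deadbeef",
        file_source="none",
    ),
)

# dataclass_json also provides to_dict() on the decorated classes.
print(sample.to_dict()["metadata"]["file_source"])
```
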
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | # nuclear option because steven uses PyCharm.
161 | .idea/
162 |
163 | # some project specific stuff
164 | results/
165 | logs/
166 | dev/
167 |
168 | # data files
169 | *.jsonl
170 | *.json
171 | !src/**/*.json
172 |
173 | # Jupyter Notebooks
174 | *.ipynb
175 |
176 | .DS_Store
177 |
178 | .serena/
179 |
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
1 | # CLAUDE.md
2 |
3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4 |
5 | ## Project Overview
6 |
7 | SWE-CARE is a comprehensive benchmark for evaluating Large Language Models (LLMs) on software engineering tasks, with a focus on code analysis, review, and issue-resolving capabilities. The project currently supports Python and Java.
8 |
9 | The benchmark features two main task types:
10 |
11 | 1. **Issue Resolving**: Generate code patches to fix GitHub issues
12 | 2. **Code Review**: Generate comprehensive code review reports for code diffs
13 |
14 | ## Commands
15 |
16 | ### Development Setup
17 |
18 | ```bash
19 | # Install dependencies using uv (recommended)
20 | pip install uv
21 | uv sync
22 |
23 | # Or using pip
24 | pip install -e .
25 |
26 | # Install pre-commit hooks (for development)
27 | pre-commit install
28 | ```
29 |
30 | ### Linting
31 |
32 | ```bash
33 | # Run ruff linter (configured in pyproject.toml)
34 | ruff check .
35 |
36 | # Run ruff formatter
37 | ruff format .
38 |
39 | # Pre-commit runs both automatically
40 | pre-commit run --all-files
41 | ```
42 |
43 | ### Running Tests
44 |
45 | Note: This project doesn't have traditional unit tests. Instead, it focuses on data collection, inference, and evaluation scripts.
46 |
47 | ## High-Level Architecture
48 |
49 | ### Core Modules
50 |
51 | 1. **`src/swe_care/collect/`** - Data collection pipeline
52 | - `get_top_repos.py` - Find most starred repos by language
53 | - `get_graphql_prs_data.py` - Fetch PR data via GitHub GraphQL API
54 | - `classify_prs_data.py` - Analyze commits and label review comments
55 | - `build_code_review_dataset.py` - Build final dataset with LLM-classified metadata
56 | - `convert_to_rm_samples.py` - Convert to reward model training samples
57 |
58 | 2. **`src/swe_care/inference/`** - LLM inference pipeline
59 | - `create_code_review_text.py` - Generate text datasets with different context strategies
60 | - `run_api.py` - Run LLM inference on code review tasks
61 |
62 | 3. **`src/swe_care/harness/`** - Evaluation framework
63 | - `code_review_eval.py` - Evaluate model predictions using rule-based or LLM-based evaluators
64 |
65 | 4. **`src/swe_care/schema/`** - Data models
66 | - `dataset.py` - Core task instance schemas (IssueResolvingTaskInstance, CodeReviewTaskInstance)
67 | - `collect.py` - GitHub PR data schemas
68 | - `inference.py` - Inference input/output schemas
69 | - `evaluation.py` - Evaluation result schemas
70 |
71 | 5. **`src/swe_care/utils/`** - Utility functions
72 | - `github.py` - GitHub API interactions
73 | - `llm_models/clients.py` - LLM API clients (OpenAI, Anthropic, etc.)
74 | - `bm25_retrieval.py` - BM25-based file retrieval
75 | - `patch.py` - Patch file manipulation
76 |
77 | ### Key Patterns
78 |
79 | - **Modular CLI**: Each module (`collect`, `inference`, `harness`) has its own `__main__.py` with subcommands
80 | - **Schema-driven**: All data structures use dataclasses with JSON serialization (see the sketch below)
81 | - **Parallel Processing**: Most operations support `--jobs` for concurrent execution
82 | - **GitHub API Token Management**: Supports multiple tokens for rate limit handling
83 |
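As a quick, illustrative sketch of the schema-driven pattern (the JSONL path below is hypothetical; only the `from_json`/`to_json` helpers generated by `dataclass_json` are used):

```python
from swe_care.schema.dataset import CodeReviewTaskInstance

# Hypothetical path to a local dataset dump; any JSONL file with one
# CodeReviewTaskInstance per line works the same way.
with open("results/dataset/code_review_task_instances.jsonl") as f:
    first_line = f.readline()

# dataclass_json generates from_json()/to_json() on every schema class.
instance = CodeReviewTaskInstance.from_json(first_line)
print(instance.instance_id, instance.repo, instance.pull_number)

# Serialize back to a JSON line, e.g. for writing a filtered dataset.
json_line = instance.to_json()
```
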
84 | ### Data Flow
85 |
86 | 1. **Collection**: GitHub repos → PR data → Classified PRs → Code review dataset
87 | 2. **Inference**: Dataset → Text generation → LLM predictions
88 | 3. **Evaluation**: Predictions + Dataset → Evaluation results
89 |
90 | ## Important Considerations
91 |
92 | - **GitHub API Rate Limits**: Always provide GitHub tokens via `--tokens` parameter
93 | - **LLM API Keys**: Set environment variables (OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.)
94 | - **Large Files**: Be careful with retrieval operations on large repositories
95 | - **Parallel Jobs**: Adjust `--jobs` based on API rate limits and system resources
96 |
97 | ## Environment Variables
98 |
99 | - `OPENAI_API_KEY` - OpenAI API key for GPT models
100 | - `ANTHROPIC_API_KEY` - Anthropic API key for Claude models
101 | - `OPENAI_BASE_URL` - Custom OpenAI-compatible API endpoint
102 | - `ANTHROPIC_BASE_URL` - Custom Anthropic-compatible API endpoint
103 |
--------------------------------------------------------------------------------
/src/swe_care/harness/evaluators/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Optional
3 |
4 | from loguru import logger
5 |
6 |
7 | class Evaluator(ABC):
8 | """Grade, tag, or otherwise evaluate predictions relative to their inputs
9 | and/or reference labels."""
10 |
11 | @property
12 | def evaluation_name(self) -> str:
13 | """The name of the evaluation."""
14 | return self.__class__.__name__
15 |
16 | @property
17 | def requires_reference(self) -> bool:
18 | """Whether this evaluator requires a reference label."""
19 | return False
20 |
21 | @property
22 | def requires_input(self) -> bool:
23 | """Whether this evaluator requires an input."""
24 | return False
25 |
26 | @property
27 | def _skip_input_warning(self) -> str:
28 | """Warning to show when input is ignored."""
29 | return f"Ignoring input in {self.__class__.__name__}, as it is not expected."
30 |
31 | @property
32 | def _skip_reference_warning(self) -> str:
33 | """Warning to show when reference is ignored."""
34 | return (
35 | f"Ignoring reference in {self.__class__.__name__}, as it is not expected."
36 | )
37 |
38 | def _check_evaluation_args(
39 | self,
40 | reference: Optional[Any] = None,
41 | input: Optional[Any] = None,
42 | ) -> None:
43 | """Check if the evaluation arguments are valid.
44 |
45 | Args:
46 | reference (Optional[Any], optional): The reference label.
47 | input (Optional[Any], optional): The input.
48 | Raises:
49 | ValueError: If the evaluator requires an input but none is provided,
50 | or if the evaluator requires a reference label but none is provided.
51 | """
52 | if self.requires_input and input is None:
53 | raise ValueError(f"{self.__class__.__name__} requires an input.")
54 | elif input is not None and not self.requires_input:
55 | logger.warning(self._skip_input_warning)
56 | if self.requires_reference and reference is None:
57 | raise ValueError(f"{self.__class__.__name__} requires a reference.")
58 | elif reference is not None and not self.requires_reference:
59 | logger.warning(self._skip_reference_warning)
60 |
61 | @abstractmethod
62 | def _evaluate(
63 | self,
64 | *,
65 | prediction: Any,
66 | reference: Optional[Any] = None,
67 | input: Optional[Any] = None,
68 | **kwargs: Any,
69 | ) -> dict:
70 | """Evaluate Chain or LLM output, based on optional input and label.
71 |
72 | Args:
73 | prediction (Any): The LLM or chain prediction to evaluate.
74 | reference (Optional[Any], optional): The reference label to evaluate against.
75 | input (Optional[Any], optional): The input to consider during evaluation.
76 | kwargs: Additional keyword arguments, including callbacks, tags, etc.
77 | Returns:
78 | dict: The evaluation results containing the score or value.
79 | It is recommended that the dictionary contain the following keys:
80 | - score: the score of the evaluation, if applicable.
81 | - value: the string value of the evaluation, if applicable.
82 | - reasoning: the reasoning for the evaluation, if applicable.
83 | """ # noqa: E501
84 |
85 | def evaluate(
86 | self,
87 | *,
88 | prediction: Any,
89 | reference: Optional[Any] = None,
90 | input: Optional[Any] = None,
91 | **kwargs: Any,
92 | ) -> dict:
93 | """Evaluate Chain or LLM output, based on optional input and label.
94 |
95 | Args:
96 | prediction (Any): The LLM or chain prediction to evaluate.
97 | reference (Optional[Any], optional): The reference label to evaluate against.
98 | input (Optional[Any], optional): The input to consider during evaluation.
99 | kwargs: Additional keyword arguments, including callbacks, tags, etc.
100 | Returns:
101 | dict: The evaluation results containing the score or value.
102 | """ # noqa: E501
103 | self._check_evaluation_args(reference=reference, input=input)
104 | return self._evaluate(
105 | prediction=prediction, reference=reference, input=input, **kwargs
106 | )
107 |
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetSpecificPullRequest.graphql:
--------------------------------------------------------------------------------
1 | query GetSpecificPullRequest($owner: String!, $name: String!, $prNumber: Int!) {
2 | repository(owner: $owner, name: $name) {
3 | nameWithOwner
4 | pullRequest(
5 | number: $prNumber
6 | ) {
7 | id
8 | title
9 | body
10 | number
11 | url
12 | author {
13 | login
14 | }
15 | createdAt
16 | mergedAt
17 | mergedBy {
18 | login
19 | }
20 | baseRefOid
21 | baseRefName
22 | headRefOid
23 | headRefName
24 | changedFiles
25 | labels(first: 10) {
26 | totalCount
27 | pageInfo {
28 | hasNextPage
29 | endCursor
30 | }
31 | nodes {
32 | name
33 | }
34 | }
35 | commits(first: 10) {
36 | totalCount
37 | pageInfo {
38 | hasNextPage
39 | endCursor
40 | }
41 | nodes {
42 | commit {
43 | oid
44 | message
45 | changedFilesIfAvailable
46 | authoredDate
47 | author {
48 | user {
49 | login
50 | }
51 | }
52 | committedDate
53 | committer {
54 | user {
55 | login
56 | }
57 | }
58 | parents(first: 2) {
59 | nodes {
60 | oid
61 | }
62 | }
63 | }
64 | }
65 | }
66 | reviews(first: 10) {
67 | totalCount
68 | pageInfo {
69 | hasNextPage
70 | endCursor
71 | }
72 | nodes {
73 | id
74 | author {
75 | login
76 | }
77 | state
78 | body
79 | submittedAt
80 | updatedAt
81 | commit {
82 | oid
83 | }
84 | comments(first: 10) {
85 | totalCount
86 | pageInfo {
87 | hasNextPage
88 | endCursor
89 | }
90 | nodes {
91 | id
92 | author {
93 | login
94 | }
95 | body
96 | createdAt
97 | updatedAt
98 | path
99 | diffHunk
100 | line
101 | startLine
102 | originalLine
103 | originalStartLine
104 | replyTo {
105 | id
106 | }
107 | isMinimized,
108 | minimizedReason,
109 | commit {
110 | oid
111 | }
112 | originalCommit {
113 | oid
114 | }
115 | }
116 | }
117 | }
118 | }
119 | reviewThreads(first: 10) {
120 | totalCount
121 | pageInfo {
122 | hasNextPage
123 | endCursor
124 | }
125 | nodes {
126 | id
127 | isResolved
128 | isOutdated
129 | isCollapsed
130 | path
131 | startLine
132 | originalLine
133 | diffSide
134 | startDiffSide
135 | resolvedBy {
136 | login
137 | }
138 | comments(first: 10) {
139 | totalCount
140 | pageInfo {
141 | hasNextPage
142 | endCursor
143 | }
144 | nodes {
145 | id
146 | }
147 | }
148 | }
149 | }
150 | closingIssuesReferences(first: 10) {
151 | totalCount
152 | pageInfo {
153 | hasNextPage
154 | endCursor
155 | }
156 | nodes {
157 | id
158 | number
159 | url
160 | title
161 | body
162 | state
163 | labels(first: 10) {
164 | totalCount
165 | pageInfo {
166 | hasNextPage
167 | endCursor
168 | }
169 | nodes {
170 | name
171 | }
172 | }
173 | comments(first: 10) {
174 | totalCount
175 | pageInfo {
176 | hasNextPage
177 | endCursor
178 | }
179 | nodes {
180 | id
181 | author {
182 | login
183 | }
184 | body
185 | createdAt
186 | updatedAt
187 | }
188 | }
189 | }
190 | }
191 | }
192 | }
193 | }
--------------------------------------------------------------------------------
/src/swe_care/utils/github_graphql/GetMergedPullRequests.graphql:
--------------------------------------------------------------------------------
1 | query GetMergedPullRequests($owner: String!, $name: String!, $prCursor: String, $maxNumber: Int!) {
2 | repository(owner: $owner, name: $name) {
3 | nameWithOwner
4 | pullRequests(states: [MERGED], first: $maxNumber, after: $prCursor, orderBy: {field: CREATED_AT, direction: DESC}) {
5 | totalCount
6 | pageInfo {
7 | hasNextPage
8 | endCursor
9 | }
10 | nodes {
11 | id
12 | title
13 | body
14 | number
15 | url
16 | author {
17 | login
18 | }
19 | createdAt
20 | mergedAt
21 | mergedBy {
22 | login
23 | }
24 | baseRefOid
25 | baseRefName
26 | headRefOid
27 | headRefName
28 | changedFiles
29 |
30 | labels(first: 10) {
31 | totalCount
32 | pageInfo {
33 | hasNextPage
34 | endCursor
35 | }
36 | nodes {
37 | name
38 | }
39 | }
40 |
41 | commits(first: 10) {
42 | totalCount
43 | pageInfo {
44 | hasNextPage
45 | endCursor
46 | }
47 | nodes {
48 | commit {
49 | oid
50 | message
51 | changedFilesIfAvailable
52 | authoredDate
53 | author {
54 | user {
55 | login
56 | }
57 | }
58 | committedDate
59 | committer {
60 | user {
61 | login
62 | }
63 | }
64 | parents(first: 2) {
65 | nodes {
66 | oid
67 | }
68 | }
69 | }
70 | }
71 | }
72 |
73 | reviews(first: 10) {
74 | totalCount
75 | pageInfo {
76 | hasNextPage
77 | endCursor
78 | }
79 | nodes {
80 | id
81 | author {
82 | login
83 | }
84 | state
85 | body
86 | submittedAt
87 | updatedAt
88 | commit {
89 | oid
90 | }
91 | comments(first: 10) {
92 | totalCount
93 | pageInfo {
94 | hasNextPage
95 | endCursor
96 | }
97 | nodes {
98 | id
99 | author {
100 | login
101 | }
102 | body
103 | createdAt
104 | updatedAt
105 | path
106 | diffHunk
107 | line
108 | startLine
109 | originalLine
110 | originalStartLine
111 | replyTo {
112 | id
113 | }
114 | isMinimized,
115 | minimizedReason,
116 | commit {
117 | oid
118 | }
119 | originalCommit {
120 | oid
121 | }
122 | }
123 | }
124 | }
125 | }
126 |
127 | reviewThreads(first: 10) {
128 | totalCount
129 | pageInfo {
130 | hasNextPage
131 | endCursor
132 | }
133 | nodes {
134 | id
135 | isResolved
136 | isOutdated
137 | isCollapsed
138 | path
139 | startLine
140 | originalLine
141 | diffSide
142 | startDiffSide
143 | resolvedBy {
144 | login
145 | }
146 | comments(first: 10){
147 | totalCount
148 | pageInfo {
149 | hasNextPage
150 | endCursor
151 | }
152 | nodes {
153 | id
154 | }
155 | }
156 | }
157 | }
158 |
159 | closingIssuesReferences(first: 10) {
160 | totalCount
161 | pageInfo {
162 | hasNextPage
163 | endCursor
164 | }
165 | nodes {
166 | id
167 | number
168 | url
169 | title
170 | body
171 | state
172 | labels(first: 10) {
173 | totalCount
174 | pageInfo {
175 | hasNextPage
176 | endCursor
177 | }
178 | nodes {
179 | name
180 | }
181 | }
182 | comments(first: 10) {
183 | totalCount
184 | pageInfo {
185 | hasNextPage
186 | endCursor
187 | }
188 | nodes {
189 | id
190 | author {
191 | login
192 | }
193 | body
194 | createdAt
195 | updatedAt
196 | }
197 | }
198 | }
199 | }
200 | }
201 | }
202 | }
203 | }
--------------------------------------------------------------------------------
/src/swe_care/utils/patch.py:
--------------------------------------------------------------------------------
1 | import unidiff
2 | from loguru import logger
3 |
4 |
5 | def get_changed_file_paths(patch_content: str) -> list[str]:
6 | """
7 | Extract file paths that are changed in a patch.
8 |
9 | Args:
10 | patch_content: The patch content as a string
11 |
12 | Returns:
13 | A list of file paths that are modified in the patch
14 | """
15 | try:
16 | patch_set = unidiff.PatchSet(patch_content)
17 | changed_files = []
18 |
19 | for patched_file in patch_set:
20 | file_path = patched_file.source_file
21 | if file_path.startswith("a/"):
22 | file_path = file_path[2:] # Remove 'a/' prefix
23 |
24 | if file_path == "/dev/null" or patched_file.is_added_file:
25 | logger.debug(f"Skipping newly created file: {patched_file.path}")
26 | continue
27 |
28 | changed_files.append(file_path)
29 |
30 | return changed_files
31 | except Exception:
32 | # If parsing fails, return empty list
33 | return []
34 |
35 |
36 | def get_changed_lines_in_file(patch_content: str, file_path: str) -> set[int]:
37 | """
38 | Extract line numbers that were changed in a specific file from a patch.
39 |
40 | Args:
41 | patch_content: The patch content as a string
42 | file_path: The file path to analyze
43 |
44 | Returns:
45 | A set of line numbers that were modified in the file
46 | """
47 | try:
48 | patch_set = unidiff.PatchSet(patch_content)
49 | changed_lines = set()
50 |
51 | for patched_file in patch_set:
52 | # Check if this is the file we're interested in
53 | source_file = patched_file.source_file
54 | target_file = patched_file.target_file
55 |
56 | # Remove 'a/' and 'b/' prefixes
57 | if source_file.startswith("a/"):
58 | source_file = source_file[2:]
59 | if target_file.startswith("b/"):
60 | target_file = target_file[2:]
61 |
62 | if source_file == file_path or target_file == file_path:
63 | # Analyze hunks to find changed lines
64 | for hunk in patched_file:
65 | for line in hunk:
66 | if line.is_added or line.is_removed:
67 | # For removed lines, use source line number
68 | # For added lines, use target line number
69 | if line.is_removed and line.source_line_no:
70 | changed_lines.add(line.source_line_no)
71 | elif line.is_added and line.target_line_no:
72 | changed_lines.add(line.target_line_no)
73 |
74 | return changed_lines
75 | except Exception as e:
76 | logger.warning(f"Failed to parse patch for file {file_path}: {e}")
77 | return set()
78 |
79 |
80 | def is_line_changed_in_patch(
81 | patch_content: str, file_path: str, line_number: int
82 | ) -> bool:
83 | """
84 | Check if a specific line was changed in the patch.
85 |
86 | Args:
87 | patch_content: The patch content as a string
88 | file_path: The file path to check
89 | line_number: The line number to check
90 |
91 | Returns:
92 | True if the line was changed, False otherwise
93 | """
94 | changed_lines = get_changed_lines_in_file(patch_content, file_path)
95 | return line_number in changed_lines
96 |
97 |
98 | def extract_new_file_content_from_patch(
99 | patch_content: str, file_path: str
100 | ) -> str | None:
101 | """
102 | Extract the full content of a newly created file from a patch.
103 |
104 | Args:
105 | patch_content: The patch content as a string
106 | file_path: The file path to extract content for
107 |
108 | Returns:
109 | The full content of the newly created file if found, None otherwise
110 | """
111 | try:
112 | patch_set = unidiff.PatchSet(patch_content)
113 |
114 | for patched_file in patch_set:
115 | # Check if this is the file we're interested in
116 | source_file = patched_file.source_file
117 | target_file = patched_file.target_file
118 |
119 | # Remove 'a/' and 'b/' prefixes
120 | if source_file.startswith("a/"):
121 | source_file = source_file[2:]
122 | if target_file.startswith("b/"):
123 | target_file = target_file[2:]
124 |
125 | # Check if this is a newly created file (source is /dev/null)
126 | if target_file == file_path and (
127 | patched_file.source_file == "/dev/null" or patched_file.is_added_file
128 | ):
129 | # Extract all added lines to reconstruct the file content
130 | file_lines = []
131 | for hunk in patched_file:
132 | for line in hunk:
133 | if line.is_added:
134 | # Remove the '+' prefix and keep the content
135 | content = line.value
136 | # line.value includes the newline character at the end for most lines
137 | file_lines.append(content)
138 |
139 | # Join all lines to form the complete file content
140 | return "".join(file_lines)
141 |
142 | return None
143 | except Exception as e:
144 | logger.warning(
145 | f"Failed to extract new file content for {file_path} from patch: {e}"
146 | )
147 | return None
148 |
--------------------------------------------------------------------------------
/src/swe_care/schema/dataset.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any
3 |
4 | from dataclasses_json import dataclass_json
5 |
6 |
7 | @dataclass_json
8 | @dataclass
9 | class ResolvedIssue:
10 | """Schema for resolved issue instances."""
11 |
12 | number: int
13 | title: str
14 | body: str
15 |
16 |
17 | @dataclass_json
18 | @dataclass
19 | class IssueResolvingTaskMetadata:
20 | """Schema for metadata instances."""
21 |
22 | problem_domain: str | None
23 | """The problem domain"""
24 | difficulty: str | None
25 | """The task difficulty of the pull request"""
26 |
27 |
28 | @dataclass_json
29 | @dataclass
30 | class IssueResolvingTaskInstance:
31 | """Schema for Issue Resolving task instances."""
32 |
33 | instance_id: str
34 | repo: str
35 | language: str
36 | pull_number: int
37 | title: str
38 | body: str
39 | created_at: str
40 | problem_statement: str
41 | hints_text: str
42 | resolved_issues: list[ResolvedIssue]
43 | base_commit: str
44 | patch: str
45 | test_patch: str
46 | env_setup_config: dict[str, Any]
47 | FAIL_TO_PASS: list[str]
48 | PASS_TO_PASS: list[str]
49 | version: str
50 | metadata: IssueResolvingTaskMetadata
51 |
52 |
53 | @dataclass_json
54 | @dataclass
55 | class CodeReviewTaskMetadata:
56 | """Schema for code review metadata instances."""
57 |
58 | problem_domain: str | None
59 | """The problem domain"""
60 | difficulty: str | None
61 | """The task difficulty of the pull request"""
62 | estimated_review_effort: int | None
63 | """The estimated review effort of the pull request between 1 and 5"""
64 |
65 |
66 | @dataclass_json
67 | @dataclass
68 | class ReferenceReviewComment:
69 | """Schema for reference review comment instances."""
70 |
71 | text: str
72 | """The text of the review comment"""
73 | path: str
74 | """The path of the review comment"""
75 | diff_hunk: str | None
76 | """The diff hunk of the review comment"""
77 | line: int | None
78 | """The line number of the review comment"""
79 | start_line: int | None
80 | """The start line number of the review comment"""
81 | original_line: int | None
82 | """The original line number of the review comment"""
83 | original_start_line: int | None
84 | """The original start line number of the review comment"""
85 |
86 |
87 | @dataclass_json
88 | @dataclass
89 | class CommitToReview:
90 | """Schema for commit to review instances."""
91 |
92 | head_commit: str
93 | head_commit_message: str
94 | patch_to_review: str
95 |
96 |
97 | @dataclass_json
98 | @dataclass
99 | class CodeReviewTaskInstance:
100 | """Schema for Code Review task instances."""
101 |
102 | instance_id: str
103 | """A formatted instance identifier, usually as repo_owner__repo_name-PR-number@commit_sha_short"""
104 | repo: str
105 | """The repository owner/name identifier from GitHub"""
106 | language: str
107 | """The main programming language of the repo"""
108 | pull_number: int
109 | """Number of PR this task instance is from"""
110 | title: str
111 | """The PR title"""
112 | body: str
113 | """The PR body"""
114 | created_at: str
115 | """The creation date of the pull request"""
116 | problem_statement: str
117 | """The issue(s) title and body"""
118 | hints_text: str
119 | """Comments made on the issue(s) prior to the creation of the commit to review of the solution PR."""
120 | resolved_issues: list[ResolvedIssue]
121 | """The resolved issues of the pull request"""
122 | base_commit: str
123 | """The commit hash of the repository representing the HEAD of the repository before the solution PR is applied."""
124 | commit_to_review: CommitToReview
125 | """The commit to review of the pull request"""
126 | reference_review_comments: list[ReferenceReviewComment]
127 | """The reference review comments of the commit to review"""
128 | merged_commit: str
129 | """The merged commit of the pull request"""
130 | merged_patch: str
131 | """The gold patch, the patch generated by the PR, that resolved the issue"""
132 | metadata: CodeReviewTaskMetadata
133 |
134 | @classmethod
135 | def generate_instance_id(cls, repo: str, pull_number: int, commit: str) -> str:
136 | """Generate an instance ID from repo, pull number, and commit.
137 |
138 | Args:
139 | repo: The repository owner/name identifier from GitHub
140 | pull_number: Number of PR this task instance is from
141 | commit: The commit hash (will be shortened to first 7 characters)
142 |
143 | Returns:
144 | A formatted instance identifier as repo_owner__repo_name-PR-number@commit_sha_short
145 | """
146 | repo_formatted = repo.replace("/", "__")
147 | commit_short = commit[:7]
148 | return f"{repo_formatted}-{pull_number}@{commit_short}"
149 |
150 | @classmethod
151 | def parse_instance_id(cls, instance_id: str) -> tuple[str, int, str]:
152 | """Parse an instance ID into repo, pull number, and commit.
153 |
154 | Args:
155 | instance_id: The instance ID to parse
156 |
157 | Returns:
158 | A tuple of repo, pull number, and commit
159 |
160 | Raises:
161 | ValueError: If the instance ID format is invalid
162 | """
163 | if "@" not in instance_id:
164 | raise ValueError(
165 | f"Invalid instance ID format: {instance_id}. Expected format: repo_owner__repo_name-PR-number@commit_sha_short"
166 | )
167 |
168 | repo_pr_part, commit = instance_id.split("@", 1)
169 |
170 | if "-" not in repo_pr_part:
171 | raise ValueError(
172 | f"Invalid instance ID format: {instance_id}. Expected format: repo_owner__repo_name-PR-number@commit_sha_short"
173 | )
174 |
175 | repo_formatted, pull_number_str = repo_pr_part.rsplit("-", 1)
176 |
177 | try:
178 | pull_number = int(pull_number_str)
179 | except ValueError:
180 | raise ValueError(f"Invalid pull number in instance ID: {instance_id}")
181 |
182 | repo = repo_formatted.replace("__", "/")
183 |
184 | return repo, pull_number, commit
185 |
--------------------------------------------------------------------------------
/src/swe_care/utils/file_source_retrieval.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for retrieving relevant files based on different strategies.
3 | """
4 |
5 | import string
6 | from pathlib import Path
7 | from typing import Literal, Optional
8 |
9 | import nltk
10 | from loguru import logger
11 | from nltk.tokenize import word_tokenize
12 | from rank_bm25 import BM25Okapi
13 |
14 | from swe_care.schema.collect import LabeledReviewComment
15 | from swe_care.schema.dataset import ReferenceReviewComment
16 | from swe_care.utils.extract_prs_data import (
17 | fetch_repo_file_content,
18 | fetch_repo_files_content_by_retrieval,
19 | )
20 | from swe_care.utils.patch import get_changed_file_paths
21 |
22 | try:
23 | nltk.download("punkt_tab")
24 | except Exception:
25 | logger.error(
26 | "Failed to download punkt_tab, maybe you need to use VPN to download it?"
27 | )
28 | raise
29 |
30 |
31 | def get_relevant_files(
32 | review_comment: ReferenceReviewComment | LabeledReviewComment,
33 | file_source: Literal[
34 | "none",
35 | "base_changed_files",
36 | "reviewed_file",
37 | "retrieved_base_changed_files",
38 | "retrieved_all_files",
39 | ],
40 | repo: str,
41 | base_commit: str,
42 | patch_to_review: str,
43 | tokens: Optional[list[str]] = None,
44 | retrieval_max_files: int = 5,
45 | retrieval_output_dir: Optional[Path | str] = None,
46 | changed_files: Optional[dict[str, str]] = None,
47 | ) -> dict[str, str]:
48 | """Get relevant files based on file_source strategy.
49 |
50 | Args:
51 | review_comment: The review comment (can be ReferenceReviewComment or LabeledReviewComment)
52 | file_source: Source for file content
53 | repo: Repository in format 'owner/name'
54 | base_commit: Base commit SHA
55 | patch_to_review: Patch content to review
56 | tokens: GitHub API tokens
57 | retrieval_max_files: Maximum number of files for retrieval
58 | retrieval_output_dir: Output directory for retrieval operations
59 | changed_files: Pre-computed changed files (optional, to avoid re-fetching)
60 |
61 | Returns:
62 | Dictionary mapping file paths to their contents
63 | """
64 | if file_source == "none":
65 | return {}
66 |
67 | if file_source == "reviewed_file" and review_comment.path:
68 | try:
69 | content = fetch_repo_file_content(
70 | repo, base_commit, review_comment.path, tokens, patch_to_review
71 | )
72 | return {review_comment.path: content}
73 | except Exception as e:
74 | logger.warning(f"Failed to fetch content for {review_comment.path}: {e}")
75 | return {}
76 |
77 | elif file_source == "base_changed_files":
78 | if changed_files is not None:
79 | return changed_files
80 | else:
81 | return get_changed_files_in_patch(
82 | repo, base_commit, patch_to_review, tokens
83 | )
84 |
85 | elif file_source == "retrieved_base_changed_files":
86 | if not review_comment.diff_hunk:
87 | return {}
88 |
89 | # Get changed files first
90 | if changed_files is None:
91 | changed_files = get_changed_files_in_patch(
92 | repo, base_commit, patch_to_review, tokens
93 | )
94 |
95 | if not changed_files:
96 | return {}
97 |
98 | # Use BM25 to retrieve relevant files
99 | def preprocess(text):
100 | text = text.lower()
101 | text = text.translate(str.maketrans("", "", string.punctuation))
102 | return word_tokenize(text)
103 |
104 | try:
105 | query_tokens = preprocess(review_comment.diff_hunk)
106 | bm25 = BM25Okapi(
107 | [preprocess(content) for content in changed_files.values()]
108 | )
109 | doc_scores = bm25.get_scores(query_tokens)
110 |
111 | top_indices = sorted(
112 | range(len(doc_scores)),
113 | key=lambda i: doc_scores[i],
114 | reverse=True,
115 | )[: min(retrieval_max_files, len(changed_files))]
116 |
117 | file_paths = list(changed_files.keys())
118 | relevant_files = {}
119 | for idx in top_indices:
120 | file_path = file_paths[idx]
121 | relevant_files[file_path] = changed_files[file_path]
122 | return relevant_files
123 | except Exception as e:
124 | logger.warning(f"Failed to retrieve content for diff hunk: {e}")
125 | return {}
126 |
127 | elif file_source == "retrieved_all_files":
128 | if not review_comment.diff_hunk:
129 | return {}
130 |
131 | if retrieval_output_dir is None:
132 | raise ValueError(
133 | "retrieval_output_dir is required when file_source is 'retrieved_all_files'"
134 | )
135 |
136 | # Convert Path to str if needed
137 | if isinstance(retrieval_output_dir, Path):
138 | retrieval_output_dir = str(retrieval_output_dir)
139 |
140 | return fetch_repo_files_content_by_retrieval(
141 | repo=repo,
142 | commit=base_commit,
143 | query=review_comment.diff_hunk,
144 | retrieval_output_dir=retrieval_output_dir,
145 | tokens=tokens,
146 | max_files=retrieval_max_files,
147 | )
148 |
149 | return {}
150 |
151 |
152 | def get_changed_files_in_patch(
153 | repo: str,
154 | base_commit: str,
155 | patch_to_review: str,
156 | tokens: list[str] | None = None,
157 | ) -> dict[str, str]:
158 | """
159 | Get file path and file content that are changed in the patch.
160 | """
161 | changed_files = {}
162 |
163 | logger.debug(f"Getting changed file paths from {base_commit} to commit_to_review")
164 | changed_file_paths = get_changed_file_paths(patch_to_review)
165 | logger.debug(f"Changed file paths: {changed_file_paths}")
166 |
167 | # Fetch file contents
168 | for file_path in changed_file_paths:
169 | try:
170 | logger.debug(f"Fetching content for {file_path}")
171 | content = fetch_repo_file_content(
172 | repo, base_commit, file_path, tokens, patch_to_review
173 | )
174 | changed_files[file_path] = content
175 | except Exception as e:
176 | logger.warning(f"Failed to fetch content for {file_path}: {e}")
177 | changed_files[file_path] = ""
178 |
179 | # Filter out files without content and return only the files we fetched
180 | result = {
181 | path: content for path, content in changed_files.items() if content is not None
182 | }
183 |
184 | logger.info(f"Retrieved {len(result)} changed files")
185 | return result
186 |
--------------------------------------------------------------------------------
/src/swe_care/utils/load.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from datasets import load_dataset
4 | from loguru import logger
5 |
6 | from swe_care.schema.dataset import CodeReviewTaskInstance
7 | from swe_care.schema.evaluation import CodeReviewEvaluationResult, CodeReviewPrediction
8 | from swe_care.schema.inference import CodeReviewInferenceInstance
9 |
10 |
11 | def load_code_review_dataset(
12 | dataset_name_or_path: Path | str = "inclusionAI/SWE-CARE",
13 | ) -> list[CodeReviewTaskInstance]:
14 | """Load the code review dataset instances from a JSONL file or Hugging Face dataset.
15 |
16 | Args:
17 | dataset_name_or_path: Either a file path to a local JSONL file, or a Hugging Face dataset name.
18 | Defaults to "inclusionAI/SWE-CARE".
19 |
20 | Returns:
21 | List of CodeReviewTaskInstance objects
22 |
23 | Raises:
24 | FileNotFoundError: If a file path is provided but the file doesn't exist
25 | Exception: If there's an error parsing the file or loading from Hugging Face
26 | """
27 | # Check if dataset_name_or_path is a file path
28 | if isinstance(dataset_name_or_path, str):
29 | path = Path(dataset_name_or_path)
30 | else:
31 | path = dataset_name_or_path
32 |
33 | # If it's an existing file, load from file
34 | if path.exists():
35 | logger.info("Loading dataset instances from file...")
36 |
37 | dataset_instances: list[CodeReviewTaskInstance] = []
38 |
39 | with open(path, "r") as f:
40 | for line_num, line in enumerate(f, 1):
41 | try:
42 | instance = CodeReviewTaskInstance.from_json(line.strip())
43 | dataset_instances.append(instance)
44 | except Exception as e:
45 | logger.error(f"Error processing line {line_num}: {e}")
46 | raise e
47 |
48 | logger.success(f"Loaded {len(dataset_instances)} dataset instances from file")
49 | return dataset_instances
50 |
51 | # Otherwise, load from Hugging Face
52 | else:
53 | logger.info(f"Loading dataset from Hugging Face: {dataset_name_or_path}")
54 |
55 | # Load the test split from Hugging Face
56 | dataset = load_dataset(
57 | str(dataset_name_or_path), split="test", revision="0.2.0"
58 | )
59 |
60 | # Convert Hugging Face dataset to list of CodeReviewTaskInstance
61 | dataset_instances: list[CodeReviewTaskInstance] = []
62 |
63 | logger.info("Processing test split")
64 | for idx, item in enumerate(dataset):
65 | try:
66 | # Convert the Hugging Face dataset item to CodeReviewTaskInstance
67 | # Using from_dict method provided by dataclass_json
68 | instance = CodeReviewTaskInstance.from_dict(item)
69 | dataset_instances.append(instance)
70 | except Exception as e:
71 | logger.error(f"Error converting item {idx} in test split: {e}")
72 | raise e
73 |
74 | logger.success(
75 | f"Loaded {len(dataset_instances)} dataset instances from Hugging Face test split"
76 | )
77 | return dataset_instances
78 |
79 |
80 | def load_code_review_predictions(
81 | predictions_path: Path | str,
82 | ) -> list[CodeReviewPrediction]:
83 | """Load the code review predictions from the JSONL file."""
84 | if isinstance(predictions_path, str):
85 | predictions_path = Path(predictions_path)
86 | logger.info("Loading predictions...")
87 |
88 | if not predictions_path.exists():
89 | raise FileNotFoundError(f"Predictions file not found: {predictions_path}")
90 |
91 | predictions: list[CodeReviewPrediction] = []
92 |
93 | with open(predictions_path, "r") as f:
94 | for line_num, line in enumerate(f, 1):
95 | try:
96 | prediction = CodeReviewPrediction.from_json(line.strip())
97 | predictions.append(prediction)
98 | except Exception as e:
99 | logger.error(f"Error processing line {line_num}: {e}")
100 | raise e
101 | logger.success(f"Loaded {len(predictions)} predictions")
102 | return predictions
103 |
104 |
105 | def load_code_review_text(
106 | dataset_file: Path | str,
107 | ) -> list[CodeReviewInferenceInstance]:
108 | """Load the code review text dataset instances from the JSONL file.
109 |
110 | Args:
111 | dataset_file: Path to the input JSONL file containing CodeReviewInferenceInstance objects
112 |
113 | Returns:
114 | List of CodeReviewInferenceInstance objects
115 |
116 | Raises:
117 | FileNotFoundError: If the dataset file doesn't exist
118 | Exception: If there's an error parsing the file
119 | """
120 | if isinstance(dataset_file, str):
121 | dataset_file = Path(dataset_file)
122 | logger.info("Loading inference text dataset instances...")
123 |
124 | if not dataset_file.exists():
125 | raise FileNotFoundError(f"Dataset file not found: {dataset_file}")
126 |
127 | dataset_instances: list[CodeReviewInferenceInstance] = []
128 |
129 | with open(dataset_file, "r") as f:
130 | for line_num, line in enumerate(f, 1):
131 | try:
132 | instance = CodeReviewInferenceInstance.from_json(line.strip())
133 | dataset_instances.append(instance)
134 | except Exception as e:
135 | logger.error(f"Error processing line {line_num}: {e}")
136 | raise e
137 |
138 | logger.success(f"Loaded {len(dataset_instances)} inference text dataset instances")
139 | return dataset_instances
140 |
141 |
142 | def load_code_review_eval_result(
143 | eval_result_file: Path | str,
144 | ) -> list[CodeReviewEvaluationResult]:
145 | """Load the code review evaluation results from the JSONL file.
146 |
147 | Args:
148 | eval_result_file: Path to the evaluation result JSONL file
149 |
150 | Returns:
151 | List of CodeReviewEvaluationResult objects
152 |
153 | Raises:
154 | FileNotFoundError: If the evaluation result file doesn't exist
155 | Exception: If there's an error parsing the file
156 | """
157 | if isinstance(eval_result_file, str):
158 | eval_result_file = Path(eval_result_file)
159 | logger.info(f"Loading evaluation results from {eval_result_file}...")
160 |
161 | if not eval_result_file.exists():
162 | raise FileNotFoundError(f"Evaluation result file not found: {eval_result_file}")
163 |
164 | eval_results: list[CodeReviewEvaluationResult] = []
165 |
166 | with open(eval_result_file, "r") as f:
167 | for line_num, line in enumerate(f, 1):
168 | try:
169 | result = CodeReviewEvaluationResult.from_json(line.strip())
170 | eval_results.append(result)
171 | except Exception as e:
172 | logger.error(f"Error processing line {line_num}: {e}")
173 | raise e
174 |
175 | logger.success(f"Loaded {len(eval_results)} evaluation results")
176 | return eval_results
177 |
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # SWE-CARE Evaluation Pipeline Scripts
2 |
3 | This directory contains scripts to automate the SWE-CARE evaluation pipeline and analyze results.
4 |
5 | ## `run_eval_pipeline.py`
6 |
7 | A bootstrap script that runs the complete evaluation pipeline:
8 |
9 | 1. **Generate text datasets** from collected SWE-CARE data
10 | 2. **Run LLM inference** on code review tasks
11 | 3. **Evaluate predictions** using LLM evaluator (fixed to OpenAI o3)
12 |
13 | ### Prerequisites
14 |
15 | Set up the required environment variables:
16 |
17 | ```bash
18 | # Required
19 | export OPENAI_API_KEY="your-openai-api-key"
20 | export LLM_EVALUATOR_OPENAI_API_KEY="your-o3-evaluation-api-key"
21 |
22 | # Optional
23 | export ANTHROPIC_API_KEY="your-anthropic-api-key"
24 | export OPENAI_BASE_URL="https://your-custom-openai-endpoint"
25 | export LLM_EVALUATOR_OPENAI_BASE_URL="https://your-custom-o3-endpoint"
26 | ```
27 |
28 | ### Usage
29 |
30 | Basic usage with no file context:
31 |
32 | ```bash
33 | # Using default Hugging Face dataset
34 | python scripts/run_eval_pipeline.py \
35 | --output-dir results/pipeline_output \
36 | --model gpt-4o \
37 | --model-provider openai \
38 | --file-source none
39 |
40 | # Using local dataset file
41 | python scripts/run_eval_pipeline.py \
42 | --dataset-name-or-path results/dataset/code_review_task_instances.jsonl \
43 | --output-dir results/pipeline_output \
44 | --model gpt-4o \
45 | --model-provider openai \
46 | --file-source none
47 | ```
48 |
49 | With oracle file source and custom model args:
50 |
51 | ```bash
52 | python scripts/run_eval_pipeline.py \
53 | --dataset-name-or-path results/dataset/code_review_task_instances.jsonl \
54 | --output-dir results/pipeline_output \
55 | --model claude-3-5-sonnet-20241022 \
56 | --model-provider anthropic \
57 | --model-args "temperature=0.5,max_tokens=4096" \
58 | --file-source oracle \
59 | --github-tokens "token1" "token2"
60 | ```
61 |
62 | With BM25 retrieval:
63 |
64 | ```bash
65 | python scripts/run_eval_pipeline.py \
66 | --dataset-name-or-path results/dataset/code_review_task_instances.jsonl \
67 | --output-dir results/pipeline_output \
68 | --model "models/gemini-2.5-pro" \
69 | --model-provider openai \
70 | --file-source bm25 \
71 | --k 10 \
72 | --retrieval-output-dir results/retrieval_output
73 | ```
74 |
75 | Using Tree-sitter skeletons for Python files (works with any file-source):
76 |
77 | ```bash
78 | python scripts/run_eval_pipeline.py \
79 | --dataset-name-or-path results/dataset/code_review_task_instances.jsonl \
80 | --output-dir results/pipeline_output \
81 | --model gpt-4o \
82 | --model-provider openai \
83 | --file-source oracle \
84 | --use-skeleton
85 | ```
86 |
87 | ### Arguments
88 |
89 | **Required:**
90 |
91 | - `--output-dir`: Directory to save all pipeline outputs
92 | - `--model`: Model name to use for inference
93 | - `--model-provider`: Model provider (openai, anthropic, deepseek, qwen, moonshot, gemini)
94 |
95 | **Optional:**
96 |
97 | - `--dataset-name-or-path`: Path to the input SWE-CARE dataset file or Hugging Face dataset name (default: inclusionAI/SWE-CARE)
98 | - `--model-args`: Comma-separated model arguments (e.g., 'temperature=0.7,top_p=0.9')
99 | - `--file-source`: Source strategy for files (none, oracle, bm25, all)
100 | - `--k`: Maximum number of files to use (required for bm25/all)
101 | - `--retrieval-output-dir`: Output directory for retrieval operations (required for bm25/all)
102 | - `--github-tokens`: GitHub API token(s) for fetching data
103 | - `--jobs`: Number of parallel jobs (default: 2)
104 | - `--skip-existing`: Skip instances that already have predictions
105 |
106 | ### Output Structure
107 |
108 | The script creates the following directory structure:
109 |
110 | ```
111 |
112 | <output-dir>/
113 | ├── pipeline_config_YYYYMMDD_HHMMSS.json   # Complete pipeline configuration (timestamped)
114 | ├── pipeline_YYYYMMDD_HHMMSS.log           # Detailed execution log (timestamped)
115 | ├── code_review_text/                      # Generated text datasets
116 | │   ├── <dataset>__<file-source>[__skeleton].jsonl
117 | │   └── <dataset>__<file-source>__k<k>[__skeleton].jsonl   # For bm25/all with k parameter
118 | ├── predictions/                           # Model predictions organized by model
119 | │   └── <model>/                           # Model-specific subdirectory
120 | │       └── <dataset>__<file-source>.jsonl
121 | └── evaluation/                            # Evaluation results organized by model
122 |     └── <model>/                           # Model-specific subdirectory
123 |         └── <dataset>__<file-source>__report_YYYYMMDD_HHMMSS.jsonl
124 |
125 | ```
126 |
127 | ### Notes
128 |
129 | - The script automatically handles model names with slashes (e.g., `models/gemini-2.5-pro`)
130 | - Model predictions and evaluation results are organized into subdirectories by model name
131 | - LLM evaluation is fixed to use OpenAI o3 model with `temperature=1`
132 | - Use separate API keys for inference and evaluation via environment variables
133 | - All intermediate results are saved for debugging and analysis
134 | - The pipeline configuration and logs are timestamped for reproducibility
135 | - Evaluation report detection ensures only reports generated by the current run are used
136 | - Timestamps follow the format YYYYMMDD_HHMMSS for easy sorting and identification
137 |
138 | ## `eval_report.py`
139 |
140 | A comprehensive analysis script that generates detailed evaluation reports from pipeline results:
141 |
142 | 1. **Collects evaluation results** from multiple models and settings
143 | 2. **Aggregates performance metrics** across different dimensions
144 | 3. **Handles missing instances** by assigning score 0 for fair comparison (see the sketch after this list)
145 | 4. **Generates rankings** of model-setting configurations
146 |
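To make the missing-instance convention concrete, here is a minimal, illustrative sketch (made-up instance IDs and scores; not the script's actual implementation):

```python
# All instance IDs in the dataset vs. scores actually produced by one model/setting.
all_instance_ids = ["owner__repo-1@abc1234", "owner__repo-2@def5678", "owner__repo-3@0a1b2c3"]
scores = {"owner__repo-1@abc1234": 0.8, "owner__repo-2@def5678": 0.6}  # third instance missing

# Missing instances count as 0 so every configuration is averaged over the same denominator.
average = sum(scores.get(i, 0.0) for i in all_instance_ids) / len(all_instance_ids)
print(round(average, 3))  # 0.467
```
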
147 | ### Prerequisites
148 |
149 | Ensure you have run the evaluation pipeline first using `run_eval_pipeline.py`.
150 |
151 | ### Usage
152 |
153 | Basic usage:
154 |
155 | ```bash
156 | # Using local dataset file
157 | python scripts/eval_report.py \
158 | --dataset-name-or-path results/dataset/code_review_task_instances.jsonl \
159 | --eval-output-dir results/pipeline_output/evaluation \
160 | --report-output-file results/evaluation_report.json
161 |
162 | # Using default Hugging Face dataset
163 | python scripts/eval_report.py \
164 | --eval-output-dir results/pipeline_output/evaluation \
165 | --report-output-file results/evaluation_report.json
166 | ```
167 |
168 | ### Arguments
169 |
170 | **Required:**
171 |
172 | - `--eval-output-dir`: Directory containing evaluation results (organized by model)
173 | - `--report-output-file`: Path for the output JSON report file
174 |
175 | **Optional:**
176 |
177 | - `--dataset-name-or-path`: Path to the dataset file or Hugging Face dataset name (default: inclusionAI/SWE-CARE)
178 |
179 | ### Notes
180 |
181 | - The script expects evaluation results to be organized in subdirectories by model name
182 | - Filename pattern must match: `{dataset}__{file_source}__report_YYYYMMDD_HHMMSS.jsonl`
183 | - For bm25 settings: `{dataset}__bm25__k{k}__report_YYYYMMDD_HHMMSS.jsonl`
184 | - All scores are averaged including zeros for missing instances to ensure fair comparison
185 |
--------------------------------------------------------------------------------
/src/swe_care/harness/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from pathlib import Path
4 | from typing import Any
5 |
6 | from loguru import logger
7 |
8 | import swe_care.harness.code_review_eval
9 | from swe_care.harness.code_review_eval import EvaluatorType, code_review_eval
10 | from swe_care.utils.llm_models import (
11 | get_available_models_and_providers,
12 | parse_model_args,
13 | )
14 |
15 |
16 | def parse_evaluator_args(evaluator_args_str: str | None) -> dict[str, dict[str, Any]]:
17 | """Parse evaluator args string into a dictionary.
18 |
19 | Args:
20 | evaluator_args_str: String in format 'evaluator1:arg1=value1,arg2=value2;evaluator2:arg1=value1'
21 |
22 | Returns:
23 | Dictionary mapping evaluator types to their kwargs
24 | """
25 | if not evaluator_args_str:
26 | return {}
27 |
28 | result = {}
29 |
30 | # Split by semicolon to get each evaluator's args
31 | evaluator_parts = evaluator_args_str.split(";")
32 |
33 | for part in evaluator_parts:
34 | if ":" not in part:
35 | continue
36 |
37 | evaluator_type, args_part = part.split(":", 1)
38 | evaluator_type = evaluator_type.strip()
39 |
40 | # Parse the args part using the existing parse_model_args function
41 | kwargs = parse_model_args(args_part)
42 | result[evaluator_type] = kwargs
43 |
44 | return result
45 |
46 |
47 | # Mapping of subcommands to their function names
48 | SUBCOMMAND_MAP = {
49 | "code_review_eval": {
50 | "function": code_review_eval,
51 | "help": swe_care.harness.code_review_eval.__doc__,
52 | },
53 | }
54 |
55 |
56 | def create_global_parser():
57 | """Create a parser with global arguments that can be used as a parent parser."""
58 | global_parser = argparse.ArgumentParser(add_help=False)
59 | global_parser.add_argument(
60 | "--output-dir",
61 | type=Path,
62 | required=True,
63 | help="Path to output directory",
64 | )
65 | return global_parser
66 |
67 |
68 | def get_args():
69 | # Parse command line manually to handle flexible argument order
70 | args = sys.argv[1:]
71 |
72 | # Find the subcommand
73 | subcommands = list(SUBCOMMAND_MAP.keys())
74 | subcommand = None
75 | subcommand_index = None
76 |
77 | for i, arg in enumerate(args):
78 | if arg in subcommands:
79 | subcommand = arg
80 | subcommand_index = i
81 | break
82 |
83 | # Create global parser
84 | global_parser = create_global_parser()
85 |
86 | if subcommand is None:
87 | # No subcommand found, use normal argparse
88 | parser = argparse.ArgumentParser(
89 | prog="swe_care.harness",
90 | description="Evaluation tools for SWE-CARE",
91 | parents=[global_parser],
92 | )
93 |
94 | subparsers = parser.add_subparsers(dest="command", help="Available commands")
95 | for cmd, info in SUBCOMMAND_MAP.items():
96 | subparsers.add_parser(cmd, help=info["help"])
97 |
98 | return parser.parse_args(args)
99 |
100 | # Create the appropriate subcommand parser with global parser as parent
101 | match subcommand:
102 | case "code_review_eval":
103 | sub_parser = argparse.ArgumentParser(
104 | prog=f"swe_care.harness {subcommand}",
105 | parents=[global_parser],
106 | description=SUBCOMMAND_MAP[subcommand]["help"],
107 | )
108 | sub_parser.add_argument(
109 | "--dataset-name-or-path",
110 | type=str,
111 | required=False,
112 | default="inclusionAI/SWE-CARE",
113 | help="Path to the dataset file or Hugging Face dataset name (default: inclusionAI/SWE-CARE)",
114 | )
115 | sub_parser.add_argument(
116 | "--predictions-path",
117 | type=Path,
118 | required=True,
119 | help="Path to the predictions file or directory containing predictions",
120 | )
121 | sub_parser.add_argument(
122 | "--evaluator",
123 | type=EvaluatorType,
124 | nargs="+",
125 | required=True,
126 | choices=[e.value for e in EvaluatorType],
127 | help="Evaluator type to use",
128 | )
129 |
130 | available_providers, available_models = get_available_models_and_providers()
131 |
132 | sub_parser.add_argument(
133 | "--model",
134 | type=str,
135 | required=False,
136 | help=f"Model name to use for LLM evaluation. Available models: {', '.join(available_models)}",
137 | )
138 | sub_parser.add_argument(
139 | "--model-provider",
140 | type=str,
141 | required=False,
142 | choices=available_providers,
143 | help=f"Model provider for LLM evaluation. Available providers: {', '.join(available_providers)}",
144 | )
145 | sub_parser.add_argument(
146 | "--model-args",
147 | type=str,
148 | required=False,
149 | default=None,
150 | help="Comma-separated model arguments for LLM evaluation (e.g., 'temperature=0.7,top_p=0.9')",
151 | )
152 | sub_parser.add_argument(
153 | "--evaluator-args",
154 | type=str,
155 | required=False,
156 | default=None,
157 | help="Evaluator-specific arguments in format 'evaluator1:arg1=value1,arg2=value2;evaluator2:arg1=value1'",
158 | )
159 | sub_parser.add_argument(
160 | "--jobs",
161 | type=int,
162 | default=2,
163 | help="Number of parallel jobs to run (default: 2)",
164 | )
165 |
166 | # Parse all arguments with the subcommand parser
167 | # This will include both global and subcommand-specific arguments
168 | # Remove the subcommand itself from args
169 | args_without_subcommand = args[:subcommand_index] + args[subcommand_index + 1 :]
170 | final_namespace = sub_parser.parse_args(args_without_subcommand)
171 | final_namespace.command = subcommand
172 |
173 | return final_namespace
174 |
175 |
176 | def main():
177 | args = get_args()
178 |
179 | if args.command in SUBCOMMAND_MAP:
180 | # Get the function from the mapping
181 | cmd_info = SUBCOMMAND_MAP[args.command]
182 | function = cmd_info["function"]
183 |
184 | # Prepare common arguments
185 | common_kwargs = {"output_dir": args.output_dir}
186 |
187 | # Add specific arguments based on subcommand
188 | match args.command:
189 | case "code_review_eval":
190 | # Parse evaluator args
191 | evaluator_kwargs = parse_evaluator_args(args.evaluator_args)
192 |
193 | function(
194 | dataset_name_or_path=args.dataset_name_or_path,
195 | predictions_path=args.predictions_path,
196 | evaluator_types=args.evaluator,
197 | model=args.model,
198 | model_provider=args.model_provider,
199 | model_args=args.model_args,
200 | evaluator_kwargs=evaluator_kwargs,
201 | jobs=args.jobs,
202 | **common_kwargs,
203 | )
204 | else:
205 | logger.info("Please specify a command. Use --help for available commands.")
206 |
207 |
208 | if __name__ == "__main__":
209 | main()
210 |
--------------------------------------------------------------------------------
/src/swe_care/utils/llm_models/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from loguru import logger
4 |
5 | from swe_care.utils.llm_models.clients import (
6 | AnthropicClient,
7 | BaseModelClient,
8 | DeepSeekClient,
9 | GeminiClient,
10 | MoonshotClient,
11 | OpenAIClient,
12 | QwenClient,
13 | )
14 |
15 | # Map of available LLM clients; context-window sizes are taken from https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
16 | LLM_CLIENT_MAP = {
17 | "openai": {
18 | "client_class": OpenAIClient,
19 | "models": [
20 | {"name": "gpt-4o", "max_input_tokens": 128000},
21 | {"name": "gpt-4o-mini", "max_input_tokens": 128000},
22 | {"name": "gpt-4.1", "max_input_tokens": 1047576},
23 | {"name": "gpt-4.5-preview", "max_input_tokens": 128000},
24 | {"name": "gpt-5", "max_input_tokens": 128000},
25 | {"name": "gpt-5-chat", "max_input_tokens": 128000},
26 | {"name": "o1", "max_input_tokens": 200000},
27 | {"name": "o1-mini", "max_input_tokens": 128000},
28 | {"name": "o3", "max_input_tokens": 200000},
29 | {"name": "o3-mini", "max_input_tokens": 200000},
30 | ],
31 | },
32 | "anthropic": {
33 | "client_class": AnthropicClient,
34 | "models": [
35 | {"name": "claude-opus-4", "max_input_tokens": 200000},
36 | {"name": "claude-sonnet-4", "max_input_tokens": 200000},
37 | {"name": "claude-3-7-sonnet", "max_input_tokens": 200000},
38 | ],
39 | },
40 | "deepseek": {
41 | "client_class": DeepSeekClient,
42 | "models": [
43 | {"name": "deepseek-chat", "max_input_tokens": 128000}, # DeepSeek-V3.1
44 | {"name": "DeepSeek-V3.1", "max_input_tokens": 128000},
45 | {
46 | "name": "deepseek-reasoner",
47 | "max_input_tokens": 128000,
48 | }, # DeepSeek-V3.1 thinking
49 | ],
50 | },
51 | "qwen": {
52 | "client_class": QwenClient,
53 | "models": [
54 | {"name": "qwen3-32b", "max_input_tokens": 128000},
55 | {"name": "qwen3-30b-a3b", "max_input_tokens": 128000},
56 | {"name": "qwen3-235b-a22b", "max_input_tokens": 128000},
57 | ],
58 | },
59 | "moonshot": {
60 | "client_class": MoonshotClient,
61 | "models": [
62 | {"name": "kimi-k2-0711-preview", "max_input_tokens": 131072},
63 | {"name": "kimi-k2-0905-preview", "max_input_tokens": 131072},
64 | ],
65 | },
66 | "gemini": {
67 | "client_class": GeminiClient,
68 | "models": [
69 | {"name": "gemini-2.5-pro", "max_input_tokens": 1048576},
70 | ],
71 | },
72 | }
73 |
74 |
75 | def get_available_models_and_providers() -> tuple[list[str], list[str]]:
76 | """Get available models and providers from LLM_CLIENT_MAP."""
77 | available_providers = list(LLM_CLIENT_MAP.keys())
78 | available_models = []
79 | for provider_info in LLM_CLIENT_MAP.values():
80 | available_models.extend([model["name"] for model in provider_info["models"]])
81 | return available_providers, available_models
82 |
83 |
84 | def init_llm_client(
85 | model: str, model_provider: str, **model_kwargs: Any
86 | ) -> BaseModelClient:
87 | """Initialize an LLM client.
88 |
89 | Args:
90 | model: Model name
91 | model_provider: Provider name (openai, anthropic)
92 | **model_kwargs: Additional model arguments
93 |
94 | Returns:
95 | Initialized LLM client
96 |
97 | Raises:
98 | ValueError: If the model provider or model is not supported
99 | """
100 | if model_provider not in LLM_CLIENT_MAP:
101 | raise ValueError(
102 | f"Unsupported model provider: {model_provider}. "
103 | f"Supported providers: {list(LLM_CLIENT_MAP.keys())}"
104 | )
105 |
106 | provider_info = LLM_CLIENT_MAP[model_provider]
107 |
108 | _, model_list = get_available_models_and_providers()
109 |
110 |     if model not in model_list:
111 |         logger.warning(
112 |             f"Model {model} not in known models for {model_provider}. "
113 |             f"Known models: {provider_info['models']}. Proceeding anyway..."
114 |         )
115 |
116 | client_class = provider_info["client_class"]
117 | return client_class(model, model_provider, **model_kwargs)
118 |
119 |
120 | def parse_model_args(model_args_str: str | None) -> dict[str, Any]:
121 | """Parse model arguments string into a dictionary.
122 |
123 | Args:
124 | model_args_str: Comma-separated string of key=value pairs
125 |
126 | Returns:
127 | Dictionary of parsed arguments
128 |
129 | Example:
130 | "top_p=0.95,temperature=0.70" -> {"top_p": 0.95, "temperature": 0.70}
131 | """
132 | if not model_args_str:
133 | return {}
134 |
135 | args = {}
136 | for pair in model_args_str.split(","):
137 | if "=" not in pair:
138 | logger.warning(f"Skipping invalid model argument: {pair}")
139 | continue
140 |
141 | key, value = pair.split("=", 1)
142 | key = key.strip()
143 | value = value.strip()
144 |
145 | # Try to convert to appropriate type
146 | try:
147 | # Try int first
148 | if value.isdigit() or (value.startswith("-") and value[1:].isdigit()):
149 | args[key] = int(value)
150 | # Try float
151 | elif "." in value and value.replace(".", "").replace("-", "").isdigit():
152 | args[key] = float(value)
153 | # Try boolean
154 | elif value.lower() in ("true", "false"):
155 | args[key] = value.lower() == "true"
156 | # Keep as string
157 | else:
158 | args[key] = value
159 | except ValueError:
160 | args[key] = value
161 |
162 | return args
163 |
164 |
165 | def get_model_info(model_name: str) -> tuple[str, int]:
166 | """Get normalized model name and max input tokens for a given model.
167 |
168 | Args:
169 | model_name: Model name that might be from different providers
170 | (e.g., 'us.anthropic.claude-sonnet-4-20250514-v1:0')
171 |
172 | Returns:
173 | Tuple of (normalized_model_name, max_input_tokens)
174 |
175 | Raises:
176 | ValueError: If model cannot be found
177 | """
178 | # Normalize model name by checking for known patterns
179 | normalized_name = None
180 | max_tokens = None
181 |
182 | # Check each provider's models
183 | for provider_info in LLM_CLIENT_MAP.values():
184 | for model_info in provider_info["models"]:
185 | model_canonical_name = model_info["name"]
186 |
187 | # Check if the canonical name is contained in the input model name
188 | if model_canonical_name in model_name:
189 | normalized_name = model_canonical_name
190 | max_tokens = model_info["max_input_tokens"]
191 | break
192 |
193 | # Also check for partial matches (e.g., "claude-sonnet-4" in "claude-sonnet-4-20250514")
194 | if "-" in model_canonical_name:
195 | parts = model_canonical_name.split("-")
196 | # Try different combinations of parts
197 | for i in range(len(parts), 0, -1):
198 | partial = "-".join(parts[:i])
199 | if (
200 | partial in model_name and len(partial) > 3
201 | ): # Avoid too short matches
202 | normalized_name = model_canonical_name
203 | max_tokens = model_info["max_input_tokens"]
204 | break
205 |
206 | if normalized_name:
207 | break
208 |
209 | if not normalized_name:
210 | # If no exact match found, log warning and return defaults
211 | logger.warning(
212 | f"Could not find exact match for model '{model_name}' in LLM_CLIENT_MAP. "
213 | "Using default token limit."
214 | )
215 | # Return the original model name with a conservative default
216 | return model_name, 128000 # Conservative default
217 |
218 | return normalized_name, max_tokens
219 |
--------------------------------------------------------------------------------
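
A minimal sketch of how the helpers in the module above fit together; the model name, provider, and sampling arguments are placeholders, and it assumes the client classes in `clients.py` accept such keyword arguments.

```python
from swe_care.utils.llm_models import (
    get_model_info,
    init_llm_client,
    parse_model_args,
)

# Parse CLI-style model arguments into kwargs for the client.
model_kwargs = parse_model_args("temperature=0.0,top_p=0.95")
# -> {"temperature": 0.0, "top_p": 0.95}

# Construct a client for a known provider/model pair (placeholder values).
client = init_llm_client("gpt-4o", "openai", **model_kwargs)

# Normalize a provider-prefixed model id and look up its context window.
name, max_input_tokens = get_model_info("us.anthropic.claude-sonnet-4-20250514-v1:0")
# -> ("claude-sonnet-4", 200000)
```
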
/src/swe_care/inference/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from pathlib import Path
4 |
5 | from loguru import logger
6 |
7 | import swe_care.inference.create_code_review_text
8 | import swe_care.inference.run_api
9 | from swe_care.inference.create_code_review_text import create_code_review_text
10 | from swe_care.inference.run_api import run_api
11 | from swe_care.utils.llm_models import get_available_models_and_providers
12 |
13 | # Mapping of subcommands to their functions and help text
14 | SUBCOMMAND_MAP = {
15 | "create_code_review_text": {
16 | "function": create_code_review_text,
17 | "help": swe_care.inference.create_code_review_text.__doc__,
18 | },
19 | "run_api": {
20 | "function": run_api,
21 | "help": swe_care.inference.run_api.__doc__,
22 | },
23 | }
24 |
25 |
26 | def create_global_parser():
27 | """Create a parser with global arguments that can be used as a parent parser."""
28 | global_parser = argparse.ArgumentParser(add_help=False)
29 | global_parser.add_argument(
30 | "--output-dir",
31 | type=Path,
32 | required=True,
33 | help="Path to output directory",
34 | )
35 | return global_parser
36 |
37 |
38 | def get_args():
39 | # Parse command line manually to handle flexible argument order
40 | args = sys.argv[1:]
41 |
42 | # Find the subcommand
43 | subcommands = list(SUBCOMMAND_MAP.keys())
44 | subcommand = None
45 | subcommand_index = None
46 |
47 | for i, arg in enumerate(args):
48 | if arg in subcommands:
49 | subcommand = arg
50 | subcommand_index = i
51 | break
52 |
53 | # Create global parser
54 | global_parser = create_global_parser()
55 |
56 | if subcommand is None:
57 | # No subcommand found, use normal argparse
58 | parser = argparse.ArgumentParser(
59 | prog="swe_care.inference",
60 | description="Inference tools for SWE-CARE",
61 | parents=[global_parser],
62 | )
63 |
64 | subparsers = parser.add_subparsers(dest="command", help="Available commands")
65 | for cmd, info in SUBCOMMAND_MAP.items():
66 | subparsers.add_parser(cmd, help=info["help"])
67 |
68 | return parser.parse_args(args)
69 |
70 | # Create the appropriate subcommand parser with global parser as parent
71 | match subcommand:
72 | case "create_code_review_text":
73 | sub_parser = argparse.ArgumentParser(
74 | prog=f"swe_care.inference {subcommand}",
75 | parents=[global_parser],
76 | description=SUBCOMMAND_MAP[subcommand]["help"],
77 | )
78 | sub_parser.add_argument(
79 | "--dataset-name-or-path",
80 | type=str,
81 | required=False,
82 | default="inclusionAI/SWE-CARE",
83 | help="Path to the input SWE-CARE dataset file, or Hugging Face dataset name (default: inclusionAI/SWE-CARE)",
84 | )
85 | sub_parser.add_argument(
86 | "--file-source",
87 | type=str,
88 | choices=["none", "oracle", "bm25", "all"],
89 | required=True,
90 | help="Source strategy for files: 'none' (no files), 'oracle' (ground truth), 'bm25' (retrieval), or 'all' (all available)",
91 | )
92 | sub_parser.add_argument(
93 | "--k",
94 | type=int,
95 | required=False,
96 | default=None,
97 | help="Maximum number of files to use (required when --file-source is 'bm25' or 'all')",
98 | )
99 | sub_parser.add_argument(
100 | "--retrieval-output-dir",
101 | type=Path,
102 | default=None,
103 | help="Output directory for retrieval operations (required when --file-source is 'bm25' or 'all')",
104 | )
105 | sub_parser.add_argument(
106 | "--tokens",
107 | type=str,
108 | nargs="*",
109 | default=None,
110 | help="GitHub API token(s) to be used randomly for fetching data",
111 | )
112 | sub_parser.add_argument(
113 | "--jobs",
114 | type=int,
115 | default=2,
116 | help="Number of parallel jobs for multithreaded processing (default: 2)",
117 | )
118 | sub_parser.add_argument(
119 | "--skip-existing",
120 | action="store_true",
121 | default=False,
122 | help="Skip existing instances in the output file based on instance_id (default: False)",
123 | )
124 | sub_parser.add_argument(
125 | "--use-skeleton",
126 | action="store_true",
127 | default=False,
128 | help="Use TreeSitter-based Python stubs for file contents (default: False)",
129 | )
130 |
131 | case "run_api":
132 | sub_parser = argparse.ArgumentParser(
133 | prog=f"swe_care.inference {subcommand}",
134 | parents=[global_parser],
135 | description=SUBCOMMAND_MAP[subcommand]["help"],
136 | )
137 | sub_parser.add_argument(
138 | "--dataset-file",
139 | type=Path,
140 | required=True,
141 | help="Path to the input dataset file containing CodeReviewInferenceInstance objects",
142 | )
143 |
144 | available_providers, available_models = get_available_models_and_providers()
145 |
146 | sub_parser.add_argument(
147 | "--model",
148 | type=str,
149 | required=True,
150 | help=f"Model name to use for inference. Available models: {', '.join(available_models)}",
151 | )
152 | sub_parser.add_argument(
153 | "--model-provider",
154 | type=str,
155 | required=True,
156 | choices=available_providers,
157 | default="openai" if "openai" in available_providers else None,
158 | help=f"Model provider. Available providers: {', '.join(available_providers)}",
159 | )
160 | sub_parser.add_argument(
161 | "--model-args",
162 | type=str,
163 | required=False,
164 | default=None,
165 | help="List of model arguments separated by commas (e.g., 'top_p=0.95,temperature=0.70')",
166 | )
167 | sub_parser.add_argument(
168 | "--jobs",
169 | type=int,
170 | default=2,
171 | help="Number of parallel jobs for multithreaded inference (default: 2)",
172 | )
173 | sub_parser.add_argument(
174 | "--skip-existing",
175 | action="store_true",
176 | default=False,
177 | help="Whether to skip existing predictions in the output file (default: False)",
178 | )
179 |
180 | # Parse all arguments with the subcommand parser
181 | # This will include both global and subcommand-specific arguments
182 | # Remove the subcommand itself from args
183 | args_without_subcommand = args[:subcommand_index] + args[subcommand_index + 1 :]
184 | final_namespace = sub_parser.parse_args(args_without_subcommand)
185 | final_namespace.command = subcommand
186 |
187 | return final_namespace
188 |
189 |
190 | def main():
191 | args = get_args()
192 |
193 | if args.command in SUBCOMMAND_MAP:
194 | # Get the function from the mapping
195 | cmd_info = SUBCOMMAND_MAP[args.command]
196 | function = cmd_info["function"]
197 |
198 | # Prepare common arguments
199 | common_kwargs = {"output_dir": args.output_dir}
200 |
201 | # Add specific arguments based on subcommand
202 | match args.command:
203 | case "create_code_review_text":
204 | function(
205 | dataset_name_or_path=args.dataset_name_or_path,
206 | file_source=args.file_source,
207 | k=args.k,
208 | retrieval_output_dir=args.retrieval_output_dir,
209 | tokens=args.tokens,
210 | jobs=args.jobs,
211 | skip_existing=args.skip_existing,
212 | use_skeleton=args.use_skeleton,
213 | **common_kwargs,
214 | )
215 | case "run_api":
216 | function(
217 | dataset_file=args.dataset_file,
218 | model=args.model,
219 | model_provider=args.model_provider,
220 | model_args=args.model_args,
221 | jobs=args.jobs,
222 | skip_existing=args.skip_existing,
223 | **common_kwargs,
224 | )
225 | else:
226 | logger.info("Please specify a command. Use --help for available commands.")
227 |
228 |
229 | if __name__ == "__main__":
230 | main()
231 |
--------------------------------------------------------------------------------
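
A hypothetical end-to-end invocation of the `run_api` subcommand defined above; every path and the model choice are placeholders.

```python
import subprocess
import sys

# Run the inference CLI as a module; flags mirror the run_api parser above.
subprocess.run(
    [
        sys.executable, "-m", "swe_care.inference", "run_api",
        "--dataset-file", "outputs/code_review_text.jsonl",  # placeholder path
        "--model", "gpt-4o",                                  # placeholder model
        "--model-provider", "openai",
        "--model-args", "temperature=0.0",
        "--jobs", "2",
        "--output-dir", "outputs/predictions",                # placeholder path
    ],
    check=True,
)
```
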
/src/swe_care/utils/estimate.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from loguru import logger
4 | from tenacity import retry, retry_if_exception_type, stop_after_attempt
5 |
6 | from swe_care.schema.dataset import ReferenceReviewComment
7 | from swe_care.utils.llm_models.clients import BaseModelClient
8 | from swe_care.utils.prompt_loader import load_prompt
9 |
10 |
11 | class InvalidResponseError(Exception):
12 | """Raised when LLM returns an invalid response that should trigger a retry."""
13 |
14 | pass
15 |
16 |
17 | @retry(
18 | stop=stop_after_attempt(3),
19 | retry=retry_if_exception_type(InvalidResponseError),
20 | reraise=True,
21 | )
22 | def classify_problem_domain(client: BaseModelClient, problem_statement: str) -> str:
23 |     """Classify the problem domain of a code review task."""
24 | valid_categories = {
25 | "Bug Fixes",
26 | "New Feature Additions",
27 | "Code Refactoring / Architectural Improvement",
28 | "Documentation Updates",
29 | "Test Suite / CI Enhancements",
30 | "Performance Optimizations",
31 | "Security Patches / Vulnerability Fixes",
32 | "Dependency Updates & Env Compatibility",
33 | "Code Style, Linting, Formatting Fixes",
34 | }
35 |
36 | system_prompt, user_prompt = load_prompt(
37 | "classify_problem_domain", problem_statement=problem_statement
38 | )
39 |
40 | messages = [
41 | {"role": "system", "content": system_prompt},
42 | {"role": "user", "content": user_prompt},
43 | ]
44 |
45 | try:
46 | response = client.create_completion(messages).strip()
47 | if response in valid_categories:
48 | return response
49 | else:
50 | logger.warning(f"Invalid category response: '{response}'")
51 | raise InvalidResponseError(f"Invalid response: '{response}'")
52 | except Exception as e:
53 | if isinstance(e, InvalidResponseError):
54 | raise
55 | logger.warning(f"Error in classify_problem_domain: {e}")
56 | raise InvalidResponseError(f"LLM call failed: {e}")
57 |
58 |
59 | @retry(
60 | stop=stop_after_attempt(3),
61 | retry=retry_if_exception_type(InvalidResponseError),
62 | reraise=True,
63 | )
64 | def estimate_difficulty(
65 | client: BaseModelClient, pr_data: dict[str, Any], commit_message: str, patch: str
66 | ) -> str:
67 | """Estimate the implementation difficulty for a pull request task."""
68 |
69 | title = pr_data.get("title", "")
70 | body = pr_data.get("body", "")
71 | changed_files = pr_data.get("changedFiles", "")
72 |
73 | system_prompt, user_prompt = load_prompt(
74 | "estimate_difficulty",
75 | title=title,
76 | body=body,
77 | commit_message=commit_message,
78 | changed_files=str(changed_files),
79 | patch=patch,
80 | )
81 |
82 | messages = [
83 | {"role": "system", "content": system_prompt},
84 | {"role": "user", "content": user_prompt},
85 | ]
86 |
87 | try:
88 | response = client.create_completion(messages).strip()
89 | difficulty = int(response)
90 | if difficulty in [1, 2, 3]:
91 | if difficulty == 1:
92 | return "low"
93 | elif difficulty == 2:
94 | return "medium"
95 | elif difficulty == 3:
96 | return "high"
97 | else:
98 | logger.warning(f"Invalid difficulty response: '{response}'")
99 | raise InvalidResponseError(
100 | f"Invalid response: '{response}' (expected 1, 2, or 3)"
101 | )
102 | except ValueError:
103 | logger.warning(f"Could not parse difficulty response: {response}")
104 | raise InvalidResponseError(f"Could not parse response: '{response}'")
105 | except Exception as e:
106 | if isinstance(e, InvalidResponseError):
107 | raise
108 | logger.warning(f"Error in estimate_difficulty: {e}")
109 | raise InvalidResponseError(f"LLM call failed: {e}")
110 |
111 |
112 | def classify_review_effort(
113 | client: BaseModelClient, pr_data: dict[str, Any], commit_message: str, patch: str
114 | ) -> int:
115 | """Estimate review effort for a code review task."""
116 | labels = pr_data.get("labels", [])
117 | if "nodes" in labels:
118 | for node in labels["nodes"]:
119 |             if "Review effort" in node.get("name", ""):
120 | try:
121 | effort = node.get("name").split("Review effort")[-1].strip()
122 | if "/" in effort:
123 | effort = effort.split("/")[0]
124 | elif ":" in effort:
125 | effort = effort.split(":")[-1].strip()
126 | return int(effort)
127 | except Exception as e:
128 | logger.warning(
129 | f"Error parsing review effort, fallback to LLM estimate: {e}"
130 | )
131 | break
132 |
133 | title = pr_data.get("title", "")
134 | body = pr_data.get("body", "")
135 |
136 | system_prompt, user_prompt = load_prompt(
137 | "classify_review_effort",
138 | title=title,
139 | body=body,
140 | commit_message=commit_message,
141 | patch=patch,
142 | )
143 |
144 | @retry(
145 | stop=stop_after_attempt(3),
146 | retry=retry_if_exception_type(InvalidResponseError),
147 | reraise=True,
148 | )
149 | def _llm_call():
150 | messages = [
151 | {"role": "system", "content": system_prompt},
152 | {"role": "user", "content": user_prompt},
153 | ]
154 |
155 | try:
156 | response = client.create_completion(messages).strip()
157 | effort = int(response)
158 | if effort in [1, 2, 3, 4, 5]:
159 | return effort
160 | else:
161 | logger.warning(f"Invalid effort response: '{response}'")
162 | raise InvalidResponseError(
163 | f"Invalid response: '{response}' (expected 1-5)"
164 | )
165 | except ValueError:
166 | logger.warning(f"Could not parse effort response: {response}")
167 | raise InvalidResponseError(f"Could not parse response: '{response}'")
168 | except Exception as e:
169 | if isinstance(e, InvalidResponseError):
170 | raise
171 | logger.warning(f"Error in classify_review_effort: {e}")
172 | raise InvalidResponseError(f"LLM call failed: {e}")
173 |
174 | try:
175 | return _llm_call()
176 | except InvalidResponseError:
177 | raise ValueError(
178 | "Failed to get valid review effort estimate (1-5) after 3 attempts"
179 | )
180 |
181 |
182 | @retry(
183 | stop=stop_after_attempt(3),
184 | retry=retry_if_exception_type(InvalidResponseError),
185 | reraise=True,
186 | )
187 | def classify_relevant_review_comment(
188 | client: BaseModelClient, review_comment: ReferenceReviewComment
189 | ) -> bool:
190 | """Classify if a review comment is relevant (likely to lead to code changes).
191 |
192 | Args:
193 | client: The LLM client to use for classification
194 | review_comment: The review comment to classify
195 |
196 | Returns:
197 | True if the comment is relevant (likely to lead to code changes),
198 | False if irrelevant (unlikely to lead to code changes)
199 | """
200 |
201 | # Determine the line number to check for line change detection
202 | line_to_check = None
203 | if review_comment.original_line is not None:
204 | line_to_check = review_comment.original_line
205 | elif review_comment.line is not None:
206 | line_to_check = review_comment.line
207 | elif review_comment.start_line is not None:
208 | line_to_check = review_comment.start_line
209 | elif review_comment.original_start_line is not None:
210 | line_to_check = review_comment.original_start_line
211 |
212 | system_prompt, user_prompt = load_prompt(
213 | "classify_relevant_review_comment",
214 | comment_text=review_comment.text,
215 | diff_hunk=review_comment.diff_hunk,
216 | path=review_comment.path,
217 | line=line_to_check,
218 | )
219 |
220 | messages = [
221 | {"role": "system", "content": system_prompt},
222 | {"role": "user", "content": user_prompt},
223 | ]
224 |
225 | try:
226 | response = client.create_completion(messages).strip().lower()
227 | if response == "relevant":
228 | return True
229 | elif response == "irrelevant":
230 | return False
231 | else:
232 | logger.warning(
233 | f"Invalid relevance classification response: '{response}' "
234 | f"(expected 'relevant' or 'irrelevant')"
235 | )
236 | raise InvalidResponseError(
237 | f"Invalid response: '{response}' (expected 'relevant' or 'irrelevant')"
238 | )
239 | except Exception as e:
240 | if isinstance(e, InvalidResponseError):
241 | raise
242 | logger.warning(f"Error in classify_relevant_review_comment: {e}")
243 | raise InvalidResponseError(f"LLM call failed: {e}")
244 |
--------------------------------------------------------------------------------
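
A small illustration of the "Review effort" label short-circuit in `classify_review_effort` above: when a label of that form is present, the value is parsed directly and the LLM is never consulted. The label strings are hypothetical examples of formats the parser handles, and `None` is passed as the client only because it is unused on this path.

```python
from swe_care.utils.estimate import classify_review_effort

# Label in "[1-5]: N" form (hypothetical) -> parsed via the ":" branch.
pr_data = {"labels": {"nodes": [{"name": "Review effort [1-5]: 3"}]}}
print(classify_review_effort(None, pr_data, commit_message="", patch=""))  # 3

# Label in "N/5" form (hypothetical) -> parsed via the "/" branch.
pr_data = {"labels": {"nodes": [{"name": "Review effort 4/5"}]}}
print(classify_review_effort(None, pr_data, commit_message="", patch=""))  # 4
```
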
/src/swe_care/harness/code_review_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Run evaluation on code review predictions.
3 | """
4 |
5 | import threading
6 | from concurrent.futures import ThreadPoolExecutor, as_completed
7 | from datetime import datetime
8 | from enum import Enum
9 | from pathlib import Path
10 | from typing import Any, Optional
11 |
12 | from loguru import logger
13 | from tqdm import tqdm
14 |
15 | from swe_care.harness.evaluators import Evaluator
16 | from swe_care.harness.evaluators.code_review import (
17 | LLMEvaluator,
18 | RuleBasedEvaluator,
19 | )
20 | from swe_care.harness.evaluators.repo_level import (
21 | RepoLevelLLMEvaluator,
22 | )
23 | from swe_care.schema.evaluation import (
24 | CodeReviewEvaluationResult,
25 | EvaluatorResult,
26 | )
27 | from swe_care.utils.llm_models import init_llm_client, parse_model_args
28 | from swe_care.utils.llm_models.clients import BaseModelClient
29 | from swe_care.utils.load import load_code_review_dataset, load_code_review_predictions
30 |
31 |
32 | class EvaluatorType(str, Enum):
33 | """The types of the evaluators."""
34 |
35 | LLM_EVALUATOR = "llm_evaluator"
36 | """The LLM evaluator."""
37 |
38 | RULE_BASED_EVALUATOR = "rule_based_evaluator"
39 | """The rule-based evaluator."""
40 |
41 | # REPO_LEVEL_LLM_EVALUATOR = "repo_level_llm_evaluator"
42 | # """The repo-level LLM evaluator."""
43 |
44 |
45 | _EVALUATOR_MAP: dict[EvaluatorType, type[Evaluator]] = {
46 | EvaluatorType.LLM_EVALUATOR: LLMEvaluator,
47 | EvaluatorType.RULE_BASED_EVALUATOR: RuleBasedEvaluator,
48 | # EvaluatorType.REPO_LEVEL_LLM_EVALUATOR: RepoLevelLLMEvaluator,
49 | }
50 |
51 |
52 | def load_evaluator(
53 | evaluator_type: EvaluatorType,
54 | *,
55 | model_client: Optional[BaseModelClient] = None,
56 | **kwargs: Any,
57 | ) -> Evaluator:
58 | """Load the evaluator based on the type."""
59 | if evaluator_type not in _EVALUATOR_MAP:
60 | raise ValueError(
61 | f"Unknown evaluator type: {evaluator_type}"
62 | f"\nValid types are: {list(_EVALUATOR_MAP.keys())}"
63 | )
64 | evaluator_cls = _EVALUATOR_MAP[evaluator_type]
65 | if issubclass(evaluator_cls, (LLMEvaluator, RepoLevelLLMEvaluator)):
66 | if model_client is None:
67 | raise ValueError("LLM model client is required for LLM evaluator")
68 | evaluator = evaluator_cls(model_client=model_client, **kwargs)
69 | else:
70 | evaluator = evaluator_cls(**kwargs)
71 |
72 | logger.info(f"Loaded evaluator {evaluator_type} with kwargs: {kwargs}")
73 | return evaluator
74 |
75 |
76 | def code_review_eval_instance(
77 | instance,
78 | predictions: list,
79 | evaluators: list[Evaluator],
80 | ) -> CodeReviewEvaluationResult:
81 | """Process a single instance and return the evaluation result."""
82 | prediction = [p for p in predictions if p.instance_id == instance.instance_id]
83 | if not prediction:
84 | raise ValueError(f"No prediction found for instance {instance.instance_id}")
85 | prediction = prediction[0]
86 |
87 | evaluation_results: list[EvaluatorResult] = []
88 | for evaluator in evaluators:
89 | evaluation = None
90 | try:
91 | evaluation = evaluator.evaluate(
92 | prediction=prediction,
93 | reference=instance.reference_review_comments,
94 | input=instance,
95 | )
96 |
97 | except Exception as e:
98 | raise ValueError(
99 | f"Error evaluating instance {instance.instance_id} with {evaluator.evaluation_name}: {e}"
100 | )
101 |
102 | evaluation_results.append(
103 | EvaluatorResult(
104 | evaluator=evaluator.evaluation_name,
105 | evaluation=evaluation,
106 | )
107 | )
108 |
109 | score = sum(
110 | [
111 | evaluation.evaluation.get("score", 0)
112 | for evaluation in evaluation_results
113 | if evaluation.evaluation.get("score", 0) is not None
114 | ]
115 | ) / len(evaluation_results)
116 |
117 | evaluation_result = CodeReviewEvaluationResult(
118 | instance_id=instance.instance_id,
119 | score=score,
120 | evaluations=evaluation_results,
121 | )
122 |
123 | return evaluation_result
124 |
125 |
126 | def code_review_eval(
127 | predictions_path: Path | str,
128 | output_dir: Path | str,
129 | evaluator_types: list[EvaluatorType],
130 | dataset_name_or_path: Path | str = "inclusionAI/SWE-CARE",
131 | model: Optional[str] = None,
132 | model_provider: Optional[str] = None,
133 | model_args: Optional[str] = None,
134 | evaluator_kwargs: Optional[dict[str, dict[str, Any]]] = None,
135 | jobs: int = 2,
136 | ) -> None:
137 | """
138 | Run evaluation on code review predictions.
139 |
140 | Args:
141 | predictions_path: Path to predictions file or directory containing predictions
142 | output_dir: Directory where the final_report.json will be saved
143 | evaluator_types: List of evaluator types to use
144 | dataset_name_or_path: Path to the dataset file or Hugging Face dataset name (default: inclusionAI/SWE-CARE)
145 | model: Model name to use for LLM evaluation (required if using LLM evaluator)
146 | model_provider: Model provider (required if using LLM evaluator)
147 | model_args: Comma-separated model arguments
148 | evaluator_kwargs: Dict mapping evaluator types to their kwargs
149 | jobs: Number of parallel jobs to run (default: 2)
150 | """
151 | if isinstance(predictions_path, str):
152 | predictions_path = Path(predictions_path)
153 | if isinstance(output_dir, str):
154 | output_dir = Path(output_dir)
155 |
156 | instances = load_code_review_dataset(dataset_name_or_path)
157 | predictions = load_code_review_predictions(predictions_path)
158 |
159 | # Initialize LLM client if needed
160 | model_client = None
161 | llm_evaluator_types = [
162 | EvaluatorType.LLM_EVALUATOR,
163 | # EvaluatorType.REPO_LEVEL_LLM_EVALUATOR,
164 | ]
165 | if any(et in evaluator_types for et in llm_evaluator_types):
166 | if not model or not model_provider:
167 | raise ValueError("Model and model provider are required for LLM evaluator")
168 |
169 | model_kwargs = parse_model_args(model_args)
170 | model_client = init_llm_client(model, model_provider, **model_kwargs)
171 | logger.info(
172 | f"Initialized {model_provider} client with model {model} using model arguments: {model_kwargs}"
173 | )
174 |
175 | # Initialize evaluator_kwargs if not provided
176 | if evaluator_kwargs is None:
177 | evaluator_kwargs = {}
178 |
179 | evaluators = []
180 | for evaluator_type in evaluator_types:
181 | # Get kwargs for this specific evaluator type
182 | kwargs = evaluator_kwargs.get(evaluator_type.value, {})
183 | evaluator = load_evaluator(
184 | evaluator_type,
185 | model_client=model_client,
186 | **kwargs,
187 | )
188 | evaluators.append(evaluator)
189 |
190 | # Create output directory if it doesn't exist
191 | output_dir.mkdir(parents=True, exist_ok=True)
192 | output_file = (
193 | output_dir
194 | / f"{predictions_path.stem}_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
195 | )
196 |
197 | # Thread-safe file writing
198 | write_lock = threading.Lock()
199 |
200 | # Initialize the output file (truncate if exists)
201 | with open(output_file, "w"):
202 | pass # Just create/truncate the file
203 |
204 | # Counters for tracking progress
205 | successful_evaluations = 0
206 | failed_evaluations = 0
207 |
208 | with ThreadPoolExecutor(max_workers=jobs) as executor:
209 | # Submit all tasks
210 | future_to_instance = {
211 | executor.submit(
212 | code_review_eval_instance,
213 | instance=instance,
214 | predictions=predictions,
215 | evaluators=evaluators,
216 | ): instance
217 | for instance in instances
218 | }
219 |
220 | # Process completed tasks with progress bar
221 | with tqdm(
222 | total=len(instances),
223 | desc=f"Evaluating instances with [{', '.join([e.value for e in evaluator_types])}] ({jobs} threads)",
224 | ) as pbar:
225 | for future in as_completed(future_to_instance):
226 | instance = future_to_instance[future]
227 |
228 | try:
229 | result = future.result()
230 | with write_lock:
231 | with open(output_file, "a") as f:
232 | f.write(result.to_json() + "\n")
233 | successful_evaluations += 1
234 |
235 | except Exception as e:
236 | failed_evaluations += 1
237 | logger.error(
238 | f"Exception evaluating instance {instance.instance_id}: {e}"
239 | )
240 |
241 | pbar.update(1)
242 | pbar.set_postfix(
243 | {
244 | "success": successful_evaluations,
245 | "failed": failed_evaluations,
246 | }
247 | )
248 |
249 | logger.info(
250 | f"Evaluation completed. Results saved to {output_file}. "
251 | f"Success: {successful_evaluations}, Failed: {failed_evaluations}"
252 | )
253 |
--------------------------------------------------------------------------------
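
A hypothetical programmatic call to `code_review_eval` above, combining the rule-based and LLM evaluators; all paths and the model choice are placeholders.

```python
from swe_care.harness.code_review_eval import EvaluatorType, code_review_eval

code_review_eval(
    predictions_path="outputs/predictions/gpt-4o__SWE-CARE.jsonl",  # placeholder path
    output_dir="outputs/reports",                                   # placeholder path
    evaluator_types=[
        EvaluatorType.RULE_BASED_EVALUATOR,
        EvaluatorType.LLM_EVALUATOR,
    ],
    model="gpt-4o",            # required because an LLM evaluator is requested
    model_provider="openai",
    model_args="temperature=0.0",
    jobs=4,
)
```
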
/src/swe_care/utils/github.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import time
4 | import urllib.parse
5 | from typing import Any, Optional
6 |
7 | import requests
8 | from loguru import logger
9 | from tenacity import (
10 | after_log,
11 | before_sleep_log,
12 | retry,
13 | retry_if_exception_type,
14 | stop_after_attempt,
15 | wait_exponential,
16 | )
17 |
18 |
19 | class MaxNodeLimitExceededError(Exception):
20 | """Exception raised when the maximum node limit is exceeded."""
21 |
22 | def __init__(self, error: dict):
23 | self.error = error
24 | super().__init__(f"Max node limit exceeded: {error}")
25 |
26 |
27 | class GitHubAPI:
28 | """GitHub API client with rate limiting, retries, and proper error handling."""
29 |
30 | def __init__(
31 | self,
32 | max_retries: int = 5,
33 | timeout: int = 60,
34 | tokens: Optional[list[str]] = None,
35 | ):
36 | """
37 | Initialize GitHub API client.
38 |
39 | Args:
40 | max_retries: Maximum number of retries for API calls
41 | tokens: Optional list of GitHub tokens for authentication
42 | """
43 | self.max_retries = max_retries
44 | self.tokens = tokens or []
45 | self.timeout = timeout
46 | self.graphql_endpoint = "https://api.github.com/graphql"
47 | self.rest_api_base = "https://api.github.com"
48 |
49 | def _get_token(self) -> Optional[str]:
50 | """Get a randomly selected token from the available tokens."""
51 | if not self.tokens:
52 | return None
53 | return random.choice(self.tokens)
54 |
55 | def _get_headers(self, content_type: str = "application/json") -> dict[str, str]:
56 | """Get headers for API requests."""
57 | headers = {
58 | "Content-Type": content_type,
59 | "Accept": "application/vnd.github+json",
60 | }
61 |
62 | token = self._get_token()
63 | if token:
64 | headers["Authorization"] = f"Bearer {token}"
65 |
66 | return headers
67 |
68 | def _handle_rate_limit(self, response: requests.Response) -> None:
69 | """Handle rate limiting based on response headers."""
70 | if "X-RateLimit-Remaining" in response.headers:
71 | remaining = int(response.headers["X-RateLimit-Remaining"])
72 | if remaining < 10: # If less than 10 requests remaining
73 | reset_time = int(response.headers.get("X-RateLimit-Reset", 0))
74 | current_time = int(time.time())
75 | wait_time = max(0, reset_time - current_time)
76 | if wait_time > 0:
77 | logger.info(
78 | f"Rate limit approaching. Waiting {wait_time} seconds..."
79 | )
80 | time.sleep(wait_time)
81 |
82 | def _retry_wrapper(self, func):
83 | """Create a retry wrapper using the instance's max_retries setting."""
84 | return retry(
85 | reraise=True,
86 | stop=stop_after_attempt(self.max_retries),
87 | wait=wait_exponential(multiplier=2, min=4, max=60),
88 | retry=retry_if_exception_type(
89 | (
90 | requests.exceptions.RequestException,
91 | requests.exceptions.HTTPError,
92 | )
93 | ),
94 | before_sleep=before_sleep_log(logger, "WARNING"),
95 | after=after_log(logger, "WARNING"),
96 | )(func)
97 |
98 | def execute_graphql_query(
99 | self, query: str, variables: dict[str, Any]
100 | ) -> dict[str, Any]:
101 | """
102 | Execute a GraphQL query with retry mechanism.
103 |
104 | Args:
105 | query: GraphQL query string
106 | variables: Variables for the GraphQL query
107 |
108 | Returns:
109 | JSON response from the GraphQL API
110 |
111 | Raises:
112 | ValueError: If GraphQL errors are returned
113 | requests.exceptions.RequestException: If the request fails
114 | """
115 |
116 | def _execute_query():
117 | headers = self._get_headers()
118 | payload = {
119 | "query": query,
120 | "variables": variables,
121 | }
122 |
123 | response = requests.post(
124 | self.graphql_endpoint,
125 | headers=headers,
126 | data=json.dumps(payload),
127 | timeout=self.timeout,
128 | )
129 | response.raise_for_status()
130 |
131 | # Handle rate limiting
132 | self._handle_rate_limit(response)
133 |
134 | results = response.json()
135 |
136 | # Check for GraphQL errors
137 | if "errors" in results:
138 | # Check if it's a node limit error that we can handle
139 | for error in results["errors"]:
140 | if error.get("type") == "MAX_NODE_LIMIT_EXCEEDED":
141 | raise MaxNodeLimitExceededError(error)
142 |
143 | raise ValueError(f"GraphQL errors: {results['errors']}")
144 |
145 | return results
146 |
147 | return self._retry_wrapper(_execute_query)()
148 |
149 | def call_api(
150 | self,
151 | url: str,
152 | method: str = "GET",
153 | params: Optional[dict[str, Any]] = None,
154 | data: Optional[dict[str, Any]] = None,
155 | ) -> requests.Response:
156 | """
157 | Call GitHub REST API endpoint with retry mechanism.
158 |
159 | Args:
160 | url: Full URL or path relative to GitHub API base
161 | method: HTTP method (GET, POST, etc.)
162 | params: Query parameters
163 | data: Request body data
165 |
166 | Returns:
167 | Response object
168 |
169 | Raises:
170 | requests.exceptions.RequestException: If the request fails
171 | """
172 |
173 | def _call_api():
174 | # Handle both full URLs and relative paths
175 | full_url = url
176 | if not url.startswith("http"):
177 | full_url = f"{self.rest_api_base}/{url.lstrip('/')}"
178 |
179 | headers = self._get_headers()
180 |
181 | response = requests.request(
182 | method=method,
183 | url=full_url,
184 | headers=headers,
185 | params=params,
186 | json=data,
187 | timeout=self.timeout,
188 | )
189 | response.raise_for_status()
190 |
191 | # Handle rate limiting
192 | self._handle_rate_limit(response)
193 |
194 | return response
195 |
196 | return self._retry_wrapper(_call_api)()
197 |
198 | def get_patch(
199 | self,
200 | repo: str,
201 | *,
202 | pr_number: Optional[int] = None,
203 | base_commit: Optional[str] = None,
204 | head_commit: Optional[str] = None,
205 | ) -> str:
206 | """
207 | Get patch/diff for a PR or commit range.
208 |
209 | Args:
210 | repo: Repository in format 'owner/repo'
211 | pr_number: Pull request number (for PR patch)
212 | base_commit: Base commit for commit range
213 | head_commit: Head commit for commit range
214 |
215 | Returns:
216 | Patch/diff content as string
217 |
218 | Raises:
219 | ValueError: If neither pr_number nor commit range is provided
220 | requests.exceptions.RequestException: If the request fails
221 | """
222 |
223 | def _get_patch():
224 | if pr_number is not None:
225 | patch_url = f"https://github.com/{repo}/pull/{pr_number}.diff"
226 | elif base_commit and head_commit:
227 | patch_url = f"https://github.com/{repo}/compare/{base_commit}...{head_commit}.diff"
228 | else:
229 | raise ValueError(
230 | "Either pr_number or both base_commit and head_commit must be provided"
231 | )
232 |
233 | headers = self._get_headers(content_type="text/plain")
234 |
235 | response = requests.get(patch_url, headers=headers, timeout=self.timeout)
236 | response.raise_for_status()
237 |
238 | # Handle rate limiting
239 | self._handle_rate_limit(response)
240 |
241 | return response.text
242 |
243 | return self._retry_wrapper(_get_patch)()
244 |
245 | def get_file_content(self, repo: str, commit: str, file_path: str) -> str:
246 | """
247 | Get the content of a file at a specific commit.
248 |
249 | Args:
250 | repo: Repository in format 'owner/repo'
251 | commit: Commit SHA to fetch the file from
252 | file_path: Path to the file in the repository
253 |
254 | Returns:
255 | File content as string
256 |
257 | Raises:
258 | requests.exceptions.RequestException: If the request fails
259 | """
260 |
261 | encoded_path = urllib.parse.quote(file_path, safe="")
262 | content_response = self.call_api(
263 | f"repos/{repo}/contents/{encoded_path}", params={"ref": commit}
264 | )
265 | content_data = content_response.json()
266 |
267 | # Decode base64 content
268 | if "content" in content_data and content_data.get("encoding") == "base64":
269 | import base64
270 |
271 | content = base64.b64decode(content_data["content"]).decode("utf-8")
272 | return content
273 | else:
274 | logger.warning(f"Unable to decode content for {file_path}")
275 | return ""
276 |
--------------------------------------------------------------------------------
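
A rough usage sketch for the `GitHubAPI` client above. The PR number and base commit are taken from the questionnaire demo later in this dump; the token and file path are placeholders.

```python
from swe_care.utils.github import GitHubAPI

api = GitHubAPI(tokens=["ghp_xxxxxxxxxxxx"])  # placeholder token

# Diff of a pull request (served from github.com as plain text).
patch = api.get_patch("dbt-labs/dbt-core", pr_number=3100)

# Content of a single file at a specific commit via the REST contents API.
content = api.get_file_content(
    "dbt-labs/dbt-core",
    commit="934c23bf392c4d5db9b129d0f8953c2b7872f5b0",
    file_path="core/setup.py",  # placeholder path
)
```
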
/src/swe_care/harness/evaluators/repo_level.py:
--------------------------------------------------------------------------------
1 | """
2 | Repo-level LLM evaluator for code review predictions.
3 | """
4 |
5 | import json
6 | import re
7 | from typing import Any, Literal, Optional
8 |
9 | from loguru import logger
10 |
11 | from swe_care.harness.evaluators import Evaluator
12 | from swe_care.harness.evaluators.utils import extract_defects_from_review
13 | from swe_care.schema.dataset import (
14 | CodeReviewTaskInstance,
15 | ReferenceReviewComment,
16 | )
17 | from swe_care.schema.evaluation import CodeReviewPrediction
18 | from swe_care.utils.file_source_retrieval import get_relevant_files
19 | from swe_care.utils.llm_models.clients import BaseModelClient
20 | from swe_care.utils.prompt_loader import load_prompt
21 |
22 |
23 | class RepoLevelLLMEvaluator(Evaluator):
24 |     """Repo-level evaluator that uses an LLM to classify each predicted review comment as positive or negative."""
25 |
26 | def __init__(
27 | self,
28 | model_client: BaseModelClient,
29 | file_source: Literal[
30 | "none",
31 | "base_changed_files",
32 | "reviewed_file",
33 | "retrieved_base_changed_files",
34 | "retrieved_all_files",
35 | ] = "none",
36 | retrieval_max_files: int = 5,
37 | retrieval_output_dir: Optional[str] = None,
38 | tokens: Optional[list[str]] = None,
39 | **kwargs,
40 | ):
41 | """Initialize the Repo-level LLM evaluator.
42 |
43 | Args:
44 | model_client: The LLM client to use for evaluation
45 | file_source: Source for file content
46 | retrieval_max_files: Maximum number of files for retrieval
47 | retrieval_output_dir: Output directory for retrieval operations
48 | tokens: GitHub API tokens
49 | """
50 | super().__init__(**kwargs)
51 | self.model_client = model_client
52 | self.file_source = file_source
53 | self.retrieval_max_files = retrieval_max_files
54 | self.retrieval_output_dir = retrieval_output_dir
55 | self.tokens = tokens
56 |
57 | # Validate retrieval_output_dir requirement
58 | if file_source == "retrieved_all_files" and retrieval_output_dir is None:
59 | raise ValueError(
60 | "retrieval_output_dir is required when file_source is 'retrieved_all_files'"
61 | )
62 |
63 | @property
64 | def requires_input(self) -> bool:
65 | """Need the input to get problem statement and base commit"""
66 | return True
67 |
68 | @property
69 | def requires_reference(self) -> bool:
70 | """Reference is optional for this evaluator"""
71 | return False
72 |
73 | def _parse_json(self, text: str) -> dict:
74 | """Parse JSON from LLM response."""
75 | # First, try to find JSON string within triple backticks
76 | # Handle both ```json and ``` formats
77 | json_pattern = r"```(?:json)?\s*\n?(.*?)\n?```"
78 | matches = re.findall(json_pattern, text, re.DOTALL | re.MULTILINE)
79 |
80 | for match in matches:
81 | try:
82 | cleaned_match = match.strip()
83 | if cleaned_match:
84 | return json.loads(cleaned_match)
85 | except json.JSONDecodeError:
86 | continue
87 |
88 | # If no backtick-wrapped JSON found, try to extract JSON object directly
89 | # Look for content between { and }
90 | json_object_pattern = r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}"
91 | json_matches = re.findall(json_object_pattern, text, re.DOTALL)
92 |
93 | for match in json_matches:
94 | try:
95 | return json.loads(match)
96 | except json.JSONDecodeError:
97 | continue
98 |
99 | # Last resort: try parsing the entire text
100 | try:
101 | return json.loads(text.strip())
102 | except json.JSONDecodeError:
103 | # If all else fails, provide more detailed error
104 | logger.error(f"Failed to parse JSON from response: {text[:500]}...")
105 | raise ValueError("No valid JSON found in LLM response")
106 |
107 | def evaluate_on_reference_review_comment(
108 | self,
109 | review_comment: ReferenceReviewComment,
110 | input: CodeReviewTaskInstance,
111 | ) -> dict:
112 | """Evaluate a single reference review comment.
113 |
114 | Args:
115 | review_comment: The review comment to evaluate
116 | input: The input task instance
117 |
118 | Returns:
119 | Dictionary with label (0/1) and reason
120 | """
121 | # Get file content based on file_source
122 | relevant_files = self._get_relevant_files(review_comment, input)
123 |
124 | # Format the review comment with context
125 | formatted_review = self._format_review_comment_with_context(
126 | review_comment,
127 | relevant_files,
128 | )
129 |
130 | # Load prompts from YAML template
131 | system_prompt, user_prompt = load_prompt(
132 | "repo_level_evaluation",
133 | problem_statement=input.problem_statement,
134 | patch_to_review=input.commit_to_review.patch_to_review,
135 | formatted_review=formatted_review,
136 | )
137 |
138 | messages = [
139 | {"role": "system", "content": system_prompt},
140 | {"role": "user", "content": user_prompt},
141 | ]
142 |
143 | answer = self.model_client.create_completion(messages)
144 | result = self._parse_json(answer)
145 |
146 | # Validate result
147 | if "label" not in result or result["label"] not in [0, 1]:
148 | raise ValueError(f"Invalid label in result: {result}")
149 | if "reason" not in result:
150 | result["reason"] = "No reason provided"
151 |
152 | return result
153 |
154 | def _get_relevant_files(
155 | self,
156 | review_comment: ReferenceReviewComment,
157 | input: CodeReviewTaskInstance,
158 | ) -> dict[str, str]:
159 | """Get relevant files based on file_source strategy."""
160 | return get_relevant_files(
161 | review_comment=review_comment,
162 | file_source=self.file_source,
163 | repo=input.repo,
164 | base_commit=input.base_commit,
165 | patch_to_review=input.commit_to_review.patch_to_review,
166 | tokens=self.tokens,
167 | retrieval_max_files=self.retrieval_max_files,
168 | retrieval_output_dir=self.retrieval_output_dir,
169 | )
170 |
171 | def _format_review_comment_with_context(
172 | self,
173 | review_comment: ReferenceReviewComment,
174 | relevant_files: dict[str, str],
175 | ) -> str:
176 | """Format review comment with file context."""
177 | # Load and render the template
178 | return load_prompt(
179 | "rm_sample",
180 | relevant_files=relevant_files,
181 | diff_hunk=review_comment.diff_hunk or "",
182 | path=review_comment.path or "",
183 | line=review_comment.line or "",
184 | review_comment=review_comment.text.strip(),
185 | add_line_numbers=True,
186 | )
187 |
188 | def _evaluate(
189 | self,
190 | *,
191 | prediction: CodeReviewPrediction,
192 | reference: Any, # noqa: ARG002
193 | input: CodeReviewTaskInstance,
194 | ) -> dict:
195 | """Evaluate code review prediction by classifying extracted defects.
196 |
197 | Args:
198 | prediction: The code review prediction
199 | reference: Not used in this evaluator
200 | input: The input task instance
201 |
202 | Returns:
203 | Dictionary containing evaluation metrics
204 | """
205 | # Extract defects from the predicted review
206 | predicted_defects = extract_defects_from_review(prediction.review_text, input)
207 |
208 | if not predicted_defects:
209 | # No defects found means it's a positive review
210 | return {
211 | "score": 1.0,
212 | "num_defects": 0,
213 | "classifications": [],
214 | }
215 |
216 | # Classify each defect
217 | classifications = []
218 | positive_count = 0
219 | negative_count = 0
220 |
221 | for defect in predicted_defects:
222 | try:
223 | result = self.evaluate_on_reference_review_comment(defect, input)
224 | classifications.append(
225 | {
226 | "defect_text": defect.text,
227 | "label": result["label"],
228 | "reason": result["reason"],
229 | }
230 | )
231 |
232 | if result["label"] == 1:
233 | positive_count += 1
234 | else:
235 | negative_count += 1
236 |
237 | except Exception as e:
238 | logger.error(f"Failed to classify defect: {e}")
239 | classifications.append(
240 | {
241 | "defect_text": defect.text,
242 | "label": 0, # Default to negative on error
243 | "reason": f"Classification error: {str(e)}",
244 | }
245 | )
246 | negative_count += 1
247 |
248 | # Calculate overall score
249 | total_defects = len(predicted_defects)
250 | if total_defects > 0:
251 | score = positive_count / total_defects
252 | else:
253 | score = 1.0 # No defects means positive
254 |
255 | return {
256 | "score": score,
257 | "num_defects": total_defects,
258 | "num_positive": positive_count,
259 | "num_negative": negative_count,
260 | "classifications": classifications,
261 | }
262 |
--------------------------------------------------------------------------------
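
A small illustration of the JSON-extraction fallback in `RepoLevelLLMEvaluator._parse_json` above, which also accepts a bare JSON object embedded in prose. The reply text is invented, and the private method is called directly (bypassing `__init__`) only for demonstration.

```python
from swe_care.harness.evaluators.repo_level import RepoLevelLLMEvaluator

# Bypass __init__ since _parse_json does not touch instance state (demo only).
evaluator = RepoLevelLLMEvaluator.__new__(RepoLevelLLMEvaluator)

reply = 'Overall assessment: {"label": 1, "reason": "Flags a real defect in the changed hunk."}'
print(evaluator._parse_json(reply))
# -> {'label': 1, 'reason': 'Flags a real defect in the changed hunk.'}
```
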
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
--------------------------------------------------------------------------------
/docs/questionnaire-demo.md:
--------------------------------------------------------------------------------
1 | ### Code Review Annotation Questionnaire
2 |
3 | **Instance ID**: `dbt-labs__dbt-core-3100@4f8c10c`
4 | **Repository**: `dbt-labs/dbt-core`
5 | **Pull Request**: [`#3100`](https://github.com/dbt-labs/dbt-core/pull/3100)
6 | **Language**: `Python`
7 | **Created at**: `2021-02-13T16:24:06Z`
8 | **Base commit**: `934c23bf392c4d5db9b129d0f8953c2b7872f5b0`
9 | **Commit to review**: `4f8c10c1aab3630f175611e87903b6051afdd8d8`
10 | **Commit message**:
11 | > default to get_columns_in_relation if not specified in config
12 | >
13 | > Co-authored-by: Jeremy Cohen
14 |
15 | ---
16 |
17 | ### Context
18 |
19 | #### Problem Statement
20 |
21 | > Specify Columns to update on incremental Models
22 | > ### Describe the feature
23 | > Allow model developer the ability to choose the set of columns that require updating on incremental loads
24 | >
25 | > ### Who will this benefit?
26 | > This will benefit all users that are performing an incremental update of a wide table with few columns that are mutable
27 |
28 | #### Hints (pre-PR comments or linked issues context)
29 |
30 | > hey @richmintz! Can you say more about this? I think it's a good idea - I'm curious if you're interested in this feature for performance reasons, or something else.
31 | > Hi @drewbanin!
32 | > My primary goal would be performance, minimizing writes and index rebuilds. However I also think it will generate more concise sql for large tables when a merge is required.
33 | > Got it! Sure, this makes a lot of sense in `merge` statements. Is this on Snowflake/BigQuery? I think a similar paradigm would apply to Redshift/Postgres, but we'd modify the `update` statement instead of the `merge` statement.
34 | >
35 | > I can imagine a config like `update_columns` which could be supplied like:
36 | > ```
37 | > {{
38 | > config(
39 | > materialized='incremental',
40 | > update_columns=['order_total', 'count_orders']
41 | > )
42 | > }}
43 | > ```
44 | >
45 | > If `update_columns` is not provided, or if it is `null`, dbt will default to updating all of the columns in the model.
46 | >
47 | > You buy all of that?
48 | >
49 | >
50 | > As part of this proposed change, it would also be nice to be able to exclude the `when matched then update set` part of the merge altogether, as in some of my models I'm only interested in adding new rows since the source data is never updated (for event-based data for example or other append-only tables), and it makes the model execution faster (at least in BigQuery).
51 | >
52 | > It could be a separate config `no_update = True` or just a convention that doesn't include the SQL block when `update_columns` is empty.
53 | >
54 | > Please note I tested this and would work for BigQuery, other databases might need a different syntax to support no-op for updates.
55 | >
56 | > Any thoughts?
57 | > hey @bfil - if you do not supply a `unique_key` for your incremental model's config, then dbt should not inject a `when matched then update set....` clause to the merge statement.
58 | >
59 | > Check out the implementation here:
60 | > https://github.com/fishtown-analytics/dbt/blob/7a07017b96ac332c872725343833a94b49129c68/core/dbt/include/global_project/macros/materializations/common/merge.sql#L23-L48
61 | > @drewbanin I know, but I still want to compare and insert only the new data based on that unique key rather than merging on an always `FALSE` condition. Makes sense?
62 | > ah! Sure, that makes a lot of sense. What do you think about pushing this code down from the materialization layer and into the modeling layer. Could you do something like:
63 | >
64 | > ```
65 | > {{ config(materialized='incremental') }}
66 | >
67 | > select id, col1, col2 from {{ ref('some_table') }}
68 | > {% if is_incremental() %}
69 | > where id not in (select id from {{ this }})
70 | > {% endif %}
71 | > ```
72 | >
73 | > This is a kind of funny adaptation of the [typical incremental modeling approach](https://docs.getdbt.com/docs/configuring-incremental-models), but it should serve to only select the _new_ records from your source table, inserting them directly into the destination table.
74 | > @drewbanin : any update on where we are with this feature...i have a wide table [Snowflake] but i only need to update very few columns in the incremental model. any ETA's on the feature ?
75 | > hey @ee07dazn - no movement on this issue on our end to report. I think the spec I mentioned above (pasted below) is still a good idea:
76 | >
77 | > ```
78 | > {{
79 | > config(
80 | > materialized='incremental',
81 | > update_columns=['order_total', 'count_orders']
82 | > )
83 | > }
84 | > ```
85 | >
86 | > I think that this feature should only be supported on databases that support `merge` statements, as dbt's delete+insert approach doesn't really lend itself well to this approach. If anyone is interested in working on this one, I'd be happy to try to point you in the right direction.
87 | >
88 | > The key change here will be to replace the call to [`adapter.get_columns_in_relation`](https://github.com/fishtown-analytics/dbt/blob/f3565f3f70062c5818395d03a52c89e87616f956/plugins/snowflake/dbt/include/snowflake/macros/materializations/incremental.sql#L59) with `config.get('update_columns')`. We'll want to implement this for both the Snowflake and BigQuery incremental materializations when the `merge` strategy is used.
89 | > Thanks @drewbanin for the information and a pointer. Apart from what you had suggested, i had to make change to the default__get_merge_sql macro to make that approach work. It works now but i am just not happy to make a change to something as low level as default__get_merge_sql. Probably i think i can update the snowflake_get_merge_sql to do this job. Thinking out loud...but thanks for the help.
90 |
91 | #### Resolved Issues
92 |
93 |
94 |
95 |
96 | ##### Issue #1862
97 |
98 | - Title:
99 |
100 | > Specify Columns to update on incremental Models
101 |
102 | - Body:
103 |
104 | > ### Describe the feature
105 | > Allow model developer the ability to choose the set of columns that require updating on incremental loads
106 | >
107 | > ### Who will this benefit?
108 | > This will benefit all users that are performing an incremental update of a wide table with few columns that are mutable
109 |
110 |
111 |
112 | #### Patch to Review (diff)
113 |
114 | ```diff
115 | diff --git a/plugins/bigquery/dbt/include/bigquery/macros/materializations/incremental.sql b/plugins/bigquery/dbt/include/bigquery/macros/materializations/incremental.sql
116 | index f4ad80d5a5c..a48178f88a3 100644
117 | --- a/plugins/bigquery/dbt/include/bigquery/macros/materializations/incremental.sql
118 | +++ b/plugins/bigquery/dbt/include/bigquery/macros/materializations/incremental.sql
119 | @@ -112,7 +112,10 @@
120 | {% endif %}
121 | {% set build_sql = create_table_as(False, target_relation, sql) %}
122 | {% else %}
123 | - {% set dest_columns = adapter.get_columns_in_relation(existing_relation) %}
124 | + {% set dest_columns = config.get('update_columns', none) %}
125 | + {% if dest_columns is none %}
126 | + {% set dest_columns = adapter.get_columns_in_relation(existing_relation) %}
127 | + {% endif %}
128 |
129 | {#-- if partitioned, use BQ scripting to get the range of partition values to be updated --#}
130 | {% if strategy == 'insert_overwrite' %}
131 | diff --git a/plugins/snowflake/dbt/include/snowflake/macros/materializations/incremental.sql b/plugins/snowflake/dbt/include/snowflake/macros/materializations/incremental.sql
132 | index 188795b1537..6d0d3bf9805 100644
133 | --- a/plugins/snowflake/dbt/include/snowflake/macros/materializations/incremental.sql
134 | +++ b/plugins/snowflake/dbt/include/snowflake/macros/materializations/incremental.sql
135 | @@ -58,7 +58,7 @@
136 | {% do adapter.expand_target_column_types(
137 | from_relation=tmp_relation,
138 | to_relation=target_relation) %}
139 | - {% set dest_columns = adapter.get_columns_in_relation(target_relation) %}
140 | + {% set dest_columns = config.get('update_columns') %}
141 | {% set build_sql = dbt_snowflake_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key, dest_columns) %}
142 | {% endif %}
143 |
144 |
145 | ```
146 |
147 | #### Reference Review Comments (from the PR)
148 |
149 | Count: 1
150 |
151 |
152 | ---
153 |
154 | ##### Comment 1
155 |
156 | - Review comment:
157 |
158 | > @user1:
159 | > As you did for BigQuery, this should be:
160 | > ```suggestion
161 | > {% set dest_columns = config.get('update_columns', none) %}
162 | > {% if dest_columns is none %}
163 | > {% set dest_columns = adapter.get_columns_in_relation(existing_relation) %}
164 | > {% endif %}
165 | > ```
166 | > I believe this is what's causing the failing integration test:
167 | > ```
168 | > 'NoneType' object is not iterable
169 | > ```
170 | >
171 | > @author:
172 | > of course 🤦
173 |
174 | - Diff hunk the review comment applied to:
175 |
176 | ```diff
177 | @@ -58,7 +58,7 @@
178 | {% do adapter.expand_target_column_types(
179 | from_relation=tmp_relation,
180 | to_relation=target_relation) %}
181 | - {% set dest_columns = adapter.get_columns_in_relation(target_relation) %}
182 | + {% set dest_columns = config.get('update_columns') %}
183 | ```
184 |
185 | - Path: `plugins/snowflake/dbt/include/snowflake/macros/materializations/incremental.sql`
186 | - Line: n/a | Start line: n/a
187 | - Original line: 61 | Original start line: n/a
188 |
189 |
190 | ---
191 |
192 | ### Section 1 — Problem Statement and Patch Alignment
193 |
194 | - **Q1.1 — Is the problem statement sufficiently well-specified for reviewing this patch?**
195 |
196 | _Your notes:_
197 |
198 | - **Q1.2 — Explain your choice above. Reference specific files/functions when relevant.**
199 |
200 | _Your notes:_
201 |
202 | - **Q1.3 — Does the patch address the stated problem?**
203 | - [ ] Yes
204 | - [ ] Partially
205 | - [ ] No
206 |
207 | - **Why?**
208 |
209 | ---
210 |
211 | ### Section 2 — Review Scope and Comment Coverage
212 |
213 | - **Q2.1 — Do the provided reference review comments appropriately cover the patch (scope and correctness)?**
214 |
215 | _Your notes:_
216 |
217 | - **Q2.2 — Explain your choice. Identify any mis-scoped or missing review topics.**
218 |
219 | _Your notes:_
220 |
221 | - **Q2.3 — Reference Review Comments Assessment**
222 |
223 | For each reference comment, mark whether it is a true positive (valid issue), false positive (not actually an issue), or informational. Briefly justify.
224 |
225 | | Index | Category (Functionality/Correctness/Performance/Security/Maintainability/Style/Docs) | Verdict (TP/FP/Info) | Notes |
226 | | --- | --- | --- | --- |
227 | | 1 | | | |
228 |
229 | - **Q2.4 — Missing Review Points**
230 |
231 | List important review findings not covered by the reference comments.
232 |
233 | - [ ] Item 1:
234 | - [ ] Item 2:
235 | - [ ] Item 3:
236 |
237 | ---
238 |
239 | ### Section 3 — Defects Identified in the Patch
240 |
241 | Repeat the following block for each defect you identify.
242 |
243 | ```text
244 | Defect N
245 | - Category: [ ] Functionality [ ] Correctness [ ] Performance [ ] Security [ ] Maintainability [ ] Style [ ] Documentation
246 | - Severity (1-5):
247 | - Files/Locations:
248 | - Short description:
249 | - Suggested fix (optional):
250 | ```
251 |
252 | ---
253 |
254 | ### Section 4 — Difficulty and Review Effort
255 |
256 | - **Q4.1 — Difficulty to understand and correctly review the patch**
257 | - [ ] <15 min
258 | - [ ] 15 min - 1 hour
259 | - [ ] 1-4 hours
260 | - [ ] >4 hours
261 |
262 | - **Why?**
263 |
264 | - Time estimate:
265 |
266 | - **Q4.2 — Estimated review effort (1-5)**
267 | - [ ] 1: Trivial, almost no cognitive load
268 | - [ ] 2: Small effort, localized changes
269 | - [ ] 3: Moderate effort, multiple files/concerns
270 | - [ ] 4: High effort, complex interactions or risk
271 | - [ ] 5: Very high effort, wide impact or intricate logic
272 |
273 | - **Rationale:**
274 |
275 | - Effort level:
276 |
277 | ---
278 |
279 | ### Section 5 — Overall Patch Quality and Risk
280 |
281 | - **Q5.1 — Does the patch meet acceptance criteria for the stated problem?**
282 | - [ ] Yes
283 | - [ ] Yes, with minor issues
284 | - [ ] No
285 |
286 | - **Notes:**
287 |
288 | - **Q5.2 — Regression risk assessment**
289 | - [ ] Low
290 | - [ ] Medium
291 | - [ ] High
292 |
293 | - **Why?**
294 |
295 | - Risk level:
296 |
297 | - **Q5.3 — Additional observations (tests, docs, API changes, compatibility)**
298 |
299 | _Your notes:_
300 |
301 | ---
302 |
303 | ### Section 6 — Dataset Suitability
304 |
305 | - **Q6.1 — Is this instance suitable for evaluating code review quality?**
306 | - [ ] Yes
307 | - [ ] No
308 |
309 | - If 'No', explain (e.g., ambiguous problem, non-diff artifacts, excessive binary/auto-generated changes, missing context):
310 |
311 | - **Q6.2 — Any compliance, privacy, or licensing concerns?**
312 | - [ ] No
313 | - [ ] Yes
314 |
315 |   - If 'Yes', explain:
316 |
317 | ---
318 |
319 | ### Section 7 — Confidence
320 |
321 | - **Q7.1 — How confident are you in your annotations? (1 = very low, 5 = very high)**
322 |
323 | Select one: [ ] 1 [ ] 2 [ ] 3 [ ] 4 [ ] 5
--------------------------------------------------------------------------------
/src/swe_care/utils/llm_models/clients.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from abc import ABC, abstractmethod
4 | from typing import Any
5 |
6 | from loguru import logger
7 |
8 | try:
9 | import openai
10 | except ImportError:
11 | raise ImportError(
12 | "OpenAI package not found. Please install it with: pip install openai"
13 | )
14 |
15 | try:
16 | import anthropic
17 | except ImportError:
18 | raise ImportError(
19 | "Anthropic package not found. Please install it with: pip install anthropic"
20 | )
21 |
22 | try:
23 | import tiktoken
24 | except ImportError:
25 | raise ImportError(
26 | "Tiktoken package not found. Please install it with: pip install tiktoken"
27 | )
28 |
29 | DEFAULT_MAX_RETRIES = 10
30 |
31 |
32 | class BaseModelClient(ABC):
33 | """Abstract base class for LLM model clients."""
34 |
35 | client: Any
36 | model: str
37 | model_provider: str
38 | model_kwargs: dict[str, Any]
39 | max_retries: int
40 |
41 | def __init__(
42 | self,
43 | model: str,
44 | model_provider: str,
45 | max_retries: int = DEFAULT_MAX_RETRIES,
46 | **model_kwargs: Any,
47 | ):
48 | self.model = model
49 | self.model_provider = model_provider
50 | self.model_kwargs = model_kwargs
51 | self.max_retries = max_retries
52 |
53 | @abstractmethod
54 | def create_completion(self, messages: list[dict[str, str]]) -> str:
55 | """Create a completion using the LLM API.
56 |
57 | Args:
58 | messages: List of messages in OpenAI format [{"role": "user", "content": "..."}]
59 |
60 | Returns:
61 | The generated completion text
62 | """
63 | pass
64 |
65 | @abstractmethod
66 | def create_completion_with_structured_output(
67 | self, messages: list[dict[str, str]], json_schema: dict
68 | ) -> dict:
69 | """Create a completion with structured output using the LLM API.
70 |
71 | Args:
72 | messages: List of messages in OpenAI format [{"role": "user", "content": "..."}]
73 | json_schema: JSON Schema that defines the expected output structure
74 |
75 | Returns:
76 | The generated completion as a dictionary matching the schema
77 | """
78 | pass
79 |
80 | @abstractmethod
81 | def count_tokens_from_text(self, text: str) -> int:
82 | """Count the number of tokens in the text."""
83 | pass
84 |
85 | @abstractmethod
86 | def count_tokens_from_messages(self, messages: list[dict[str, str]]) -> int:
87 | """Count the number of tokens in the messages."""
88 | pass
89 |
90 |
91 | class OpenAIClient(BaseModelClient):
92 | """OpenAI API client."""
93 |
94 | def __init__(
95 | self,
96 | model: str,
97 | model_provider: str,
98 | max_retries: int = DEFAULT_MAX_RETRIES,
99 | **model_kwargs: Any,
100 | ):
101 | super().__init__(model, model_provider, max_retries, **model_kwargs)
102 |
103 | # Initialize the OpenAI client
104 | api_key = os.getenv("OPENAI_API_KEY")
105 | if not api_key:
106 | raise ValueError("OPENAI_API_KEY environment variable not set")
107 |
108 | self.client: openai.OpenAI = openai.OpenAI(
109 | api_key=api_key, max_retries=self.max_retries
110 | )
111 |
112 | def create_completion(self, messages: list[dict[str, str]]) -> str:
113 | """Create a completion using OpenAI API."""
114 | try:
115 | response = self.client.chat.completions.create(
116 | model=self.model, messages=messages, **self.model_kwargs
117 | )
118 | return response.choices[0].message.content
119 | except Exception as e:
120 | logger.error(f"Error creating OpenAI completion: {e}")
121 | raise e
122 |
123 | def create_completion_with_structured_output(
124 | self, messages: list[dict[str, str]], json_schema: dict, strict: bool = False
125 | ) -> dict:
126 | """Create a completion with structured output using OpenAI API."""
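        # Uses OpenAI's JSON-schema response_format; with strict=True the API enforces
        # exact schema adherence (the schema must then satisfy Structured Outputs limits).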
127 | try:
128 | response = self.client.chat.completions.create(
129 | model=self.model,
130 | messages=messages,
131 | response_format={
132 | "type": "json_schema",
133 | "json_schema": {
134 | "name": json_schema.get("name", "structured_response"),
135 | "strict": strict,
136 | "schema": json_schema,
137 | },
138 | },
139 | **self.model_kwargs,
140 | )
141 | return json.loads(response.choices[0].message.content)
142 | except Exception as e:
143 | logger.error(f"Error creating OpenAI structured completion: {e}")
144 | raise e
145 |
146 | def count_tokens_from_text(self, text: str) -> int:
147 | """Count the number of tokens in the text using tiktoken."""
148 | try:
149 | encoding = tiktoken.encoding_for_model(self.model)
150 | except KeyError:
151 | # Fall back to o200k_base encoding if model not found
152 | encoding = tiktoken.get_encoding("o200k_base")
153 |
154 | return len(encoding.encode(text, disallowed_special=()))
155 |
156 | def count_tokens_from_messages(self, messages: list[dict[str, str]]) -> int:
157 | """Count the number of tokens in the messages using tiktoken."""
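        # The per-message and per-name constants below follow OpenAI's published
        # token-counting heuristic for chat models; the total is an estimate.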
158 | tokens_per_message = 3
159 | tokens_per_name = 1
160 |
161 | num_tokens = 0
162 | for message in messages:
163 | num_tokens += tokens_per_message
164 | for key, value in message.items():
165 | num_tokens += self.count_tokens_from_text(value)
166 | if key == "name":
167 | num_tokens += tokens_per_name
168 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
169 | return num_tokens
170 |
171 |
172 | class DeepSeekClient(OpenAIClient):
173 | """DeepSeek API client."""
174 |
175 | def __init__(
176 | self,
177 | model: str,
178 | model_provider: str,
179 | max_retries: int = DEFAULT_MAX_RETRIES,
180 | **model_kwargs: Any,
181 | ):
182 | super().__init__(model, model_provider, max_retries, **model_kwargs)
183 |
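        # Note: the API key is still the one read from OPENAI_API_KEY in
        # OpenAIClient.__init__; copy() only overrides the base_url here.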
184 |         self.client = self.client.copy(
185 |             base_url="https://api.deepseek.com/v1"
186 |         )
187 |
188 |
189 | class QwenClient(OpenAIClient):
190 | """Qwen API client."""
191 |
192 | def __init__(self, model: str, model_provider: str, **model_kwargs: Any):
193 | # Handle enable_thinking
194 | if "enable_thinking" in model_kwargs:
195 | enable_thinking = model_kwargs.pop("enable_thinking")
196 | model_kwargs["extra_body"] = {"enable_thinking": enable_thinking}
197 | else:
198 | model_kwargs["extra_body"] = {"enable_thinking": False}
199 |
200 | super().__init__(model, model_provider, **model_kwargs)
201 |
202 |         self.client = self.client.copy(
203 |             base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
204 |         )
205 |
206 | def create_completion(self, messages: list[dict[str, str]]) -> str:
207 | """Create a completion using Qwen API with streaming support for enable_thinking."""
208 | # If enable_thinking is True, we need to use streaming
209 | if self.model_kwargs.get("extra_body", {}).get("enable_thinking", False):
210 | response = self.client.chat.completions.create(
211 | model=self.model,
212 | messages=messages,
213 | stream=True,
214 | **self.model_kwargs,
215 | )
216 |
217 | # Collect the streamed response
218 | content = ""
219 | for chunk in response:
220 | if chunk.choices[0].delta.content is not None:
221 | content += chunk.choices[0].delta.content
222 |
223 | return content
224 | else:
225 | # Use parent implementation for non-thinking calls
226 | return super().create_completion(messages)
227 |
228 |
229 | class MoonshotClient(OpenAIClient):
230 |     """Moonshot API client."""
231 |
232 | def __init__(
233 | self,
234 | model: str,
235 | model_provider: str,
236 | max_retries: int = DEFAULT_MAX_RETRIES,
237 | **model_kwargs: Any,
238 | ):
239 | super().__init__(model, model_provider, max_retries, **model_kwargs)
240 |
241 |         self.client = self.client.copy(
242 |             base_url="https://api.moonshot.cn/v1"
243 |         )
244 |
245 |
246 | class GeminiClient(OpenAIClient):
247 | """Gemini API client using Google's OpenAI-compatible API."""
248 |
249 | def __init__(
250 | self,
251 | model: str,
252 | model_provider: str,
253 | max_retries: int = DEFAULT_MAX_RETRIES,
254 | **model_kwargs: Any,
255 | ):
256 | super().__init__(model, model_provider, max_retries, **model_kwargs)
257 |
258 |         self.client = self.client.copy(
259 |             base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
260 |         )
262 |
263 |
264 | class AnthropicClient(BaseModelClient):
265 | """Anthropic API client."""
266 |
267 | def __init__(
268 | self,
269 | model: str,
270 | model_provider: str,
271 | max_retries: int = DEFAULT_MAX_RETRIES,
272 | **model_kwargs: Any,
273 | ):
274 | super().__init__(model, model_provider, max_retries, **model_kwargs)
275 |
276 | # Initialize the Anthropic client
277 | api_key = os.getenv("ANTHROPIC_API_KEY")
278 | if not api_key:
279 | raise ValueError("ANTHROPIC_API_KEY environment variable not set")
280 |
281 | self.client: anthropic.Anthropic = anthropic.Anthropic(
282 | api_key=api_key, max_retries=self.max_retries
283 | )
284 |
285 | def _convert_to_anthropic_format(
286 | self, messages: list[dict[str, str]]
287 | ) -> tuple[list[dict[str, str]], str | None]:
288 | """Convert messages from OpenAI format to Anthropic format.
289 |
290 | Args:
291 | messages: List of messages in OpenAI format
292 |
293 | Returns:
294 | Tuple of (anthropic_messages, system_message)
295 | """
296 | system_message = None
297 | anthropic_messages = []
298 |
299 | # Extract system message if present
300 | if messages and messages[0]["role"] == "system":
301 | system_message = messages[0]["content"]
302 | messages = messages[1:]
303 |
304 | # Format remaining messages
305 | for msg in messages:
306 | anthropic_messages.append({"role": msg["role"], "content": msg["content"]})
307 |
308 | return anthropic_messages, system_message
309 |
310 | def create_completion(self, messages: list[dict[str, str]]) -> str:
311 | """Create a completion using Anthropic API."""
312 | try:
313 | # Convert OpenAI format to Anthropic format
314 | anthropic_messages, system_message = self._convert_to_anthropic_format(
315 | messages
316 | )
317 |
318 | kwargs = self.model_kwargs.copy()
319 | if system_message:
320 | kwargs["system"] = system_message
321 |
322 | response = self.client.messages.create(
323 | model=self.model,
324 | messages=anthropic_messages,
325 | max_tokens=kwargs.pop("max_tokens", 4096),
326 | **kwargs,
327 | )
328 | return response.content[0].text
329 | except Exception as e:
330 | logger.error(f"Error creating Anthropic completion: {e}")
331 | raise e
332 |
333 | def create_completion_with_structured_output(
334 | self, messages: list[dict[str, str]], json_schema: dict
335 | ) -> dict:
336 | """Create a completion with structured output using Anthropic API."""
337 | try:
338 | # Convert OpenAI format to Anthropic format
339 | anthropic_messages, system_message = self._convert_to_anthropic_format(
340 | messages
341 | )
342 |
343 | kwargs = self.model_kwargs.copy()
344 | if system_message:
345 | kwargs["system"] = system_message
346 |
347 | # Create a tool definition from the JSON schema
348 | tool_name = json_schema.get("name", "record_output")
349 | tool = {
350 | "name": tool_name,
351 | "description": f"Record output using the schema: {json_schema.get('description', 'Structured output')}",
352 | "input_schema": json_schema,
353 | }
354 |
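            # Forcing tool_choice to this tool makes the model reply with a tool call
            # whose input should follow input_schema.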
355 | response = self.client.messages.create(
356 | model=self.model,
357 | messages=anthropic_messages,
358 | max_tokens=kwargs.pop("max_tokens", 4096),
359 | tools=[tool],
360 | tool_choice={"type": "tool", "name": tool_name},
361 | **kwargs,
362 | )
363 |
364 | # Extract the tool use result
365 | for content in response.content:
366 | if content.type == "tool_use" and content.name == tool_name:
367 | return content.input
368 |
369 | raise ValueError("No tool use found in response")
370 | except Exception as e:
371 | logger.error(f"Error creating Anthropic structured completion: {e}")
372 | raise e
373 |
374 | def count_tokens_from_text(self, text: str) -> int:
375 | """Count the number of tokens in the text using Anthropic's API."""
376 | try:
377 | # Wrap the text in a user message for token counting
378 | result = self.client.messages.count_tokens(
379 | model=self.model, messages=[{"role": "user", "content": text}]
380 | )
381 | return result.usage.input_tokens
382 | except Exception as e:
383 | logger.error(f"Error counting tokens with Anthropic API: {e}")
384 | raise e
385 |
386 | def count_tokens_from_messages(self, messages: list[dict[str, str]]) -> int:
387 | """Count the number of tokens in the messages using Anthropic's API."""
388 | try:
389 | # Convert OpenAI format to Anthropic format
390 | anthropic_messages, system_message = self._convert_to_anthropic_format(
391 | messages
392 | )
393 |
394 | # Count tokens with Anthropic API
395 | kwargs = {}
396 | if system_message:
397 | kwargs["system"] = system_message
398 |
399 | result = self.client.messages.count_tokens(
400 | model=self.model, messages=anthropic_messages, **kwargs
401 | )
402 | return result.usage.input_tokens
403 | except Exception as e:
404 | logger.error(f"Error counting tokens with Anthropic API: {e}")
405 | raise e
406 |
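407 | # Illustrative usage sketch (assumes OPENAI_API_KEY is set; the model name and
408 | # temperature below are placeholders, not defaults of this module):
409 | #
410 | #   client = OpenAIClient(model="gpt-4o", model_provider="openai", temperature=0)
411 | #   answer = client.create_completion([{"role": "user", "content": "Say hello"}])
412 | #   n_tokens = client.count_tokens_from_text(answer)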
--------------------------------------------------------------------------------
/src/swe_care/collect/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from pathlib import Path
4 |
5 | from loguru import logger
6 |
7 | import swe_care.collect.build_code_review_dataset
8 | import swe_care.collect.classify_prs_data
9 | import swe_care.collect.convert_to_rm_samples
10 | import swe_care.collect.get_graphql_prs_data
11 | import swe_care.collect.get_top_repos
12 | from swe_care.collect.build_code_review_dataset import build_code_review_dataset
13 | from swe_care.collect.classify_prs_data import classify_prs_data
14 | from swe_care.collect.convert_to_rm_samples import convert_to_rm_samples
15 | from swe_care.collect.get_graphql_prs_data import get_graphql_prs_data
16 | from swe_care.collect.get_top_repos import get_top_repos
17 |
18 | # Mapping of subcommands to their function names
19 | SUBCOMMAND_MAP = {
20 | "get_top_repos": {
21 | "function": get_top_repos,
22 | "help": swe_care.collect.get_top_repos.__doc__,
23 | },
24 | "get_graphql_prs_data": {
25 | "function": get_graphql_prs_data,
26 | "help": swe_care.collect.get_graphql_prs_data.__doc__,
27 | },
28 | "classify_prs_data": {
29 | "function": classify_prs_data,
30 | "help": swe_care.collect.classify_prs_data.__doc__,
31 | },
32 | "build_code_review_dataset": {
33 | "function": build_code_review_dataset,
34 | "help": swe_care.collect.build_code_review_dataset.__doc__,
35 | },
36 | "convert_to_rm_samples": {
37 | "function": convert_to_rm_samples,
38 | "help": swe_care.collect.convert_to_rm_samples.__doc__,
39 | },
40 | }
41 |
42 |
43 | def create_global_parser():
44 | """Create a parser with global arguments that can be used as a parent parser."""
45 | global_parser = argparse.ArgumentParser(add_help=False)
46 | global_parser.add_argument(
47 | "--tokens",
48 | type=str,
49 | nargs="*",
50 | default=None,
51 | help="GitHub API token(s) to be used randomly for fetching data",
52 | )
53 | global_parser.add_argument(
54 | "--output-dir",
55 | type=Path,
56 | required=True,
57 | help="Path to output directory",
58 | )
59 | return global_parser
60 |
61 |
62 | def get_args():
63 | # Parse command line manually to handle flexible argument order
64 | args = sys.argv[1:]
65 |
66 | # Find the subcommand
67 | subcommands = list(SUBCOMMAND_MAP.keys())
68 | subcommand = None
69 | subcommand_index = None
70 |
71 | for i, arg in enumerate(args):
72 | if arg in subcommands:
73 | subcommand = arg
74 | subcommand_index = i
75 | break
76 |
77 | # Create global parser
78 | global_parser = create_global_parser()
79 |
80 | if subcommand is None:
81 | # No subcommand found, use normal argparse
82 | parser = argparse.ArgumentParser(
83 | prog="swe_care.collect",
84 | description="Data collection tools for SWE-CARE",
85 | parents=[global_parser],
86 | )
87 |
88 | subparsers = parser.add_subparsers(dest="command", help="Available commands")
89 | for cmd, info in SUBCOMMAND_MAP.items():
90 | subparsers.add_parser(cmd, help=info["help"])
91 |
92 | return parser.parse_args(args)
93 |
94 | # Create the appropriate subcommand parser with global parser as parent
95 | match subcommand:
96 | case "get_top_repos":
97 | sub_parser = argparse.ArgumentParser(
98 | prog=f"swe_care.collect {subcommand}",
99 | parents=[global_parser],
100 | description=SUBCOMMAND_MAP[subcommand]["help"],
101 | )
102 | sub_parser.add_argument(
103 | "--language",
104 | type=str,
105 | required=True,
106 | help="Programming language to search for",
107 | )
108 | sub_parser.add_argument(
109 | "--top-n",
110 | type=int,
111 | required=True,
112 | help="Number of top repositories to fetch",
113 | )
114 |
115 | case "get_graphql_prs_data":
116 | sub_parser = argparse.ArgumentParser(
117 | prog=f"swe_care.collect {subcommand}",
118 | parents=[global_parser],
119 | description=SUBCOMMAND_MAP[subcommand]["help"],
120 | )
121 | repo_group = sub_parser.add_mutually_exclusive_group(required=True)
122 | repo_group.add_argument(
123 | "--repo-file",
124 | type=Path,
125 | help="Path to repository JSONL file containing repository information, each line should be a JSON object with at least a 'name' field in format 'owner/repo'. Optionally, include 'pr_cursor' field to resume fetching from a specific cursor for each repository.",
126 | default=None,
127 | )
128 | repo_group.add_argument(
129 | "--repo",
130 | type=str,
131 | help="Repository in format 'owner/repo'",
132 | default=None,
133 | )
134 | sub_parser.add_argument(
135 | "--max-number",
136 | type=int,
137 | default=None,
138 | help="Maximum number of PRs to fetch per page (ignored when specific_prs is provided). If not provided, all PRs will be fetched.",
139 | )
140 | sub_parser.add_argument(
141 | "--jobs",
142 | type=int,
143 | default=2,
144 | help="Number of concurrent jobs/threads to use (default: 2)",
145 | )
146 | sub_parser.add_argument(
147 | "--specific-prs",
148 | type=int,
149 | nargs="*",
150 | default=None,
151 | help="Specific PR numbers to fetch (if not specified, fetches all PRs with closing issues)",
152 | )
153 | sub_parser.add_argument(
154 | "--after-pr-cursor",
155 | type=str,
156 | default=None,
157 | help="Resume fetching PRs after this cursor (useful for resuming interrupted runs). When used with --repo-file, acts as fallback for repositories without pr_cursor field in the file.",
158 | )
159 |
160 | case "classify_prs_data":
161 | sub_parser = argparse.ArgumentParser(
162 | prog=f"swe_care.collect {subcommand}",
163 | parents=[global_parser],
164 | description=SUBCOMMAND_MAP[subcommand]["help"],
165 | )
166 | sub_parser.add_argument(
167 | "--graphql-prs-data-file",
168 | type=Path,
169 | required=True,
170 | help="Path to GraphQL PRs data file or directory containing *_graphql_prs_data.jsonl files",
171 | )
172 | sub_parser.add_argument(
173 | "--jobs",
174 | type=int,
175 | default=2,
176 | help="Number of concurrent jobs/threads to use (default: 2)",
177 | )
178 |
179 | case "build_code_review_dataset":
180 | sub_parser = argparse.ArgumentParser(
181 | prog=f"swe_care.collect {subcommand}",
182 | parents=[global_parser],
183 | description=SUBCOMMAND_MAP[subcommand]["help"],
184 | )
185 | sub_parser.add_argument(
186 | "--graphql-prs-data-file",
187 | type=Path,
188 | required=True,
189 | help="Path to GraphQL PRs data file or directory containing *_graphql_prs_data.jsonl files",
190 | )
191 | sub_parser.add_argument(
192 | "--pr-classification-file",
193 | type=Path,
194 | required=True,
195 | help="Path to PR classification file or directory containing *_pr_classification.jsonl files",
196 | )
197 | sub_parser.add_argument(
198 | "--model",
199 | type=str,
200 | required=True,
201 | help="Model name for metadata classification (e.g., gpt-4o, claude-3-5-sonnet-20241022)",
202 | )
203 | sub_parser.add_argument(
204 | "--model-provider",
205 | type=str,
206 | required=True,
207 | help="Model provider for metadata classification (e.g., openai, anthropic)",
208 | )
209 | sub_parser.add_argument(
210 | "--model-args",
211 | type=str,
212 | default=None,
213 | help="Comma-separated model arguments for metadata classification (e.g., temperature=0.7,top_p=0.9)",
214 | )
215 | sub_parser.add_argument(
216 | "--skip-existing",
217 | action="store_true",
218 | default=False,
219 |             help="Skip instances whose instance_id already exists in the output file (default: False)",
220 | )
221 | sub_parser.add_argument(
222 | "--jobs",
223 | type=int,
224 | default=2,
225 | help="Number of concurrent jobs/threads to use (default: 2)",
226 | )
227 | sub_parser.add_argument(
228 | "--keep-empty-reference-review-comments",
229 | action="store_true",
230 | default=False,
231 |             help="Keep instances whose reference review comments list is empty (default: False)",
232 | )
233 |
234 | case "convert_to_rm_samples":
235 | sub_parser = argparse.ArgumentParser(
236 | prog=f"swe_care.collect {subcommand}",
237 | parents=[global_parser],
238 | description=SUBCOMMAND_MAP[subcommand]["help"],
239 | )
240 | sub_parser.add_argument(
241 | "--graphql-prs-data-file",
242 | type=Path,
243 | required=True,
244 | help="Path to GraphQL PRs data file or directory containing *_graphql_prs_data.jsonl files",
245 | )
246 | sub_parser.add_argument(
247 | "--pr-classification-file",
248 | type=Path,
249 | required=True,
250 | help="Path to PR classification file or directory containing *_pr_classification.jsonl files",
251 | )
252 | sub_parser.add_argument(
253 | "--file-source",
254 | type=str,
255 | choices=[
256 | "none",
257 | "base_changed_files",
258 | "reviewed_file",
259 | "retrieved_base_changed_files",
260 | "retrieved_all_files",
261 | ],
262 | default="none",
263 |             help="Source for file content in review samples. Choices: 'none' (no file content in review samples), 'base_changed_files' (include the contents of files changed between the base commit and the commit to review), 'reviewed_file' (include the content of the changed file that the review comment applies to), 'retrieved_base_changed_files' (use BM25 to retrieve relevant files from the changed files based on the diff_hunk), 'retrieved_all_files' (use BM25 to retrieve relevant files from the entire repository based on the diff_hunk). Default: none",
264 | )
265 | sub_parser.add_argument(
266 | "--jobs",
267 | type=int,
268 | default=2,
269 | help="Number of concurrent jobs/threads to use (default: 2)",
270 | )
271 | sub_parser.add_argument(
272 | "--retrieval-max-files",
273 | type=int,
274 | default=5,
275 | help="Maximum number of files to use for retrieval when file_source is 'retrieved_base_changed_files' or 'retrieved_all_files' (default: 5)",
276 | )
277 | sub_parser.add_argument(
278 | "--retrieval-output-dir",
279 | type=Path,
280 | help="Output directory for retrieval operations when file_source is 'retrieved_all_files' (required when file_source is 'retrieved_all_files')",
281 | )
282 | sub_parser.add_argument(
283 | "--skip-existing",
284 | action="store_true",
285 | default=False,
286 |             help="Skip PRs (identified by PR number) already present in the existing output for the repo",
287 | )
288 |
289 | # Parse all arguments with the subcommand parser
290 | # This will include both global and subcommand-specific arguments
291 | # Remove the subcommand itself from args
292 | args_without_subcommand = args[:subcommand_index] + args[subcommand_index + 1 :]
293 | final_namespace = sub_parser.parse_args(args_without_subcommand)
294 | final_namespace.command = subcommand
295 |
296 | return final_namespace
297 |
298 |
299 | def main():
300 | args = get_args()
301 |
302 | if args.command in SUBCOMMAND_MAP:
303 | # Get the function from the mapping
304 | cmd_info = SUBCOMMAND_MAP[args.command]
305 | function = cmd_info["function"]
306 |
307 | # Prepare common arguments
308 | common_kwargs = {"output_dir": args.output_dir, "tokens": args.tokens}
309 |
310 | # Add specific arguments based on subcommand
311 | match args.command:
312 | case "get_top_repos":
313 | function(language=args.language, top_n=args.top_n, **common_kwargs)
314 | case "get_graphql_prs_data":
315 | function(
316 | repo_file=args.repo_file,
317 | repo=args.repo,
318 | max_number=args.max_number,
319 | specific_prs=args.specific_prs,
320 | jobs=args.jobs,
321 | after_pr_cursor=args.after_pr_cursor,
322 | **common_kwargs,
323 | )
324 | case "classify_prs_data":
325 | function(
326 | graphql_prs_data_file=args.graphql_prs_data_file,
327 | jobs=args.jobs,
328 | **common_kwargs,
329 | )
330 | case "build_code_review_dataset":
331 | function(
332 | graphql_prs_data_file=args.graphql_prs_data_file,
333 | pr_classification_file=args.pr_classification_file,
334 | model=args.model,
335 | model_provider=args.model_provider,
336 | model_args=args.model_args,
337 | skip_existing=args.skip_existing,
338 | jobs=args.jobs,
339 | keep_empty_reference_review_comments=args.keep_empty_reference_review_comments,
340 | **common_kwargs,
341 | )
342 | case "convert_to_rm_samples":
343 | function(
344 | graphql_prs_data_file=args.graphql_prs_data_file,
345 | pr_classification_file=args.pr_classification_file,
346 | file_source=args.file_source,
347 | jobs=args.jobs,
348 | retrieval_max_files=args.retrieval_max_files,
349 | retrieval_output_dir=args.retrieval_output_dir,
350 | skip_existing=args.skip_existing,
351 | **common_kwargs,
352 | )
353 | else:
354 | logger.info("Please specify a command. Use --help for available commands.")
355 |
356 |
357 | if __name__ == "__main__":
358 | main()
359 |
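360 | # Illustrative invocations (paths, counts, and the language are placeholders):
361 | #
362 | #   python -m swe_care.collect get_top_repos \
363 | #       --language Python --top-n 100 --output-dir ./data
364 | #   python -m swe_care.collect classify_prs_data \
365 | #       --graphql-prs-data-file ./data --jobs 4 --output-dir ./data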
--------------------------------------------------------------------------------