├── src ├── __init__.py ├── requirements.txt ├── smartfix │ ├── shared │ │ ├── __init__.py │ │ ├── llm_providers.py │ │ ├── coding_agents.py │ │ └── failure_categories.py │ ├── domains │ │ ├── analysis │ │ │ └── __init__.py │ │ ├── workflow │ │ │ ├── __init__.py │ │ │ ├── formatter.py │ │ │ ├── build_runner.py │ │ │ ├── credit_tracking.py │ │ │ └── session_handler.py │ │ ├── integrations │ │ │ └── __init__.py │ │ ├── scm │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── agents │ │ │ ├── __init__.py │ │ │ ├── coding_agent.py │ │ │ ├── agent_factory.py │ │ │ └── agent_session.py │ │ ├── vulnerability │ │ │ ├── __init__.py │ │ │ └── models.py │ │ └── providers │ │ │ └── __init__.py │ ├── telemetry │ │ └── __init__.py │ ├── config │ │ └── __init__.py │ ├── extensions │ │ └── __init__.py │ └── __init__.py ├── github │ ├── __init__.py │ ├── github_scm_provider.py │ ├── github_api_client.py │ └── agent_factory.py ├── build_output_analyzer.py ├── merge_handler.py └── version_check.py ├── test ├── __init__.py ├── conftest.py ├── setup_test_env.py ├── test_setup_test_env.py ├── run_tests.sh ├── test_utils_normalize_host.py ├── test.py ├── test_vulnerability_models.py ├── test_utils_error_exit.py ├── test_aws_bearer_token_bedrock.py ├── test_contrast_llm_config.py ├── test_smartfix_llm_agent.py ├── test_contrast_api_failures.py ├── test_config_integration.py ├── test_main.py ├── test_contrast_message_handling.py ├── test_telemetry_attributes.py ├── test_session_handler.py └── test_version_check.py ├── .beads ├── metadata.json └── issues.jsonl ├── Makefile ├── security.md ├── .flake8 ├── .github ├── workflows │ └── build.yml └── copilot-instructions.md ├── hooks ├── pre-commit └── pre-push ├── setup-hooks.sh ├── .gitignore ├── README.md ├── CONTRIBUTING.md └── docs └── contrast-ai-smartfix.yml.template /src/__init__.py: -------------------------------------------------------------------------------- 1 | # This file makes the src directory a Python package 2 | 
-------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # This file makes the test directory a Python package 2 | -------------------------------------------------------------------------------- /.beads/metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "database": "beads.db", 3 | "jsonl_export": "issues.jsonl" 4 | } -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | requests==2.32.4 2 | google-adk==1.10.0 3 | google-generativeai==0.8.5 4 | litellm==1.77.5 5 | boto3==1.38.19 6 | deprecated==1.2.14 7 | packaging==25.0 -------------------------------------------------------------------------------- /src/smartfix/shared/__init__.py: -------------------------------------------------------------------------------- 1 | """Shared Utilities and Common Functionality 2 | 3 | This package contains shared utilities, common functionality, and 4 | helper classes that are used across multiple domains. 5 | 6 | Key Components (to be implemented): 7 | - Configuration utilities and validation 8 | - Logging and telemetry helpers 9 | - Common data structures and types 10 | - Utility functions and decorators 11 | """ 12 | 13 | __all__ = [ 14 | # Shared utilities will be exported as they are implemented 15 | ] 16 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | """Test configuration and setup for pytest. 2 | 3 | This file is automatically loaded by pytest and sets up the Python path 4 | so that all test files can import from src without path manipulation. 
5 | """ 6 | 7 | import sys 8 | from pathlib import Path 9 | 10 | # Add the project root to Python path so that 'src' imports work 11 | project_root = Path(__file__).parent.parent 12 | sys.path.insert(0, str(project_root)) 13 | 14 | # Add test directory to path for test helpers 15 | test_dir = Path(__file__).parent 16 | sys.path.insert(0, str(test_dir)) 17 | -------------------------------------------------------------------------------- /src/smartfix/domains/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | """Code Analysis Domain 2 | 3 | This domain handles build execution, code analysis, and result processing 4 | for vulnerability remediation workflows. 5 | 6 | Key Components (to be implemented): 7 | - BuildExecutor: Build command execution and management 8 | - CodeAnalyzer: Code change analysis and error extraction 9 | - BuildContext: Build environment and configuration 10 | - BuildResult: Structured build outcome representation 11 | """ 12 | 13 | __all__ = [ 14 | # Components will be exported as they are implemented 15 | ] 16 | -------------------------------------------------------------------------------- /src/smartfix/domains/workflow/__init__.py: -------------------------------------------------------------------------------- 1 | """Workflow Orchestration Domain 2 | 3 | This domain handles workflow execution, step management, and 4 | orchestration of the complete vulnerability remediation process. 
5 | 6 | Key Components (to be implemented): 7 | - RemediationWorkflow: End-to-end remediation orchestration 8 | - WorkflowEngine: Step-based workflow execution framework 9 | - WorkflowContext: Context and state management for workflows 10 | - WorkflowStep: Individual workflow step definitions and execution 11 | """ 12 | 13 | __all__ = [ 14 | # Components will be exported as they are implemented 15 | ] 16 | -------------------------------------------------------------------------------- /src/smartfix/domains/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | """External Integration Domain 2 | 3 | This domain handles integrations with external services including 4 | the Contrast API, vulnerability fetching, and notification systems. 5 | 6 | Key Components (to be implemented): 7 | - ContrastApiClient: Contrast API integration and authentication 8 | - VulnerabilityFetcher: Vulnerability data retrieval from backend 9 | - RemediationNotifier: Lifecycle event notifications 10 | - LLMProviderClient: LLM provider abstraction interface 11 | """ 12 | 13 | __all__ = [ 14 | # Components will be exported as they are implemented 15 | ] 16 | -------------------------------------------------------------------------------- /src/smartfix/domains/scm/__init__.py: -------------------------------------------------------------------------------- 1 | """Source Control Management Domain 2 | 3 | This domain provides SCM-agnostic abstractions for repository operations, 4 | branch management, and pull request handling across different providers. 
5 | 6 | Key Components (to be implemented): 7 | - Repository: Repository operations and workspace management 8 | - PullRequest: Pull request lifecycle and metadata management 9 | - Branch: Branch operations and state tracking 10 | - ScmProvider: Abstract interface for SCM provider implementations 11 | """ 12 | 13 | __all__ = [ 14 | # Components will be exported as they are implemented 15 | ] 16 | -------------------------------------------------------------------------------- /src/smartfix/telemetry/__init__.py: -------------------------------------------------------------------------------- 1 | """Telemetry and Cost Tracking System 2 | 3 | This package contains telemetry collection, cost analysis, and 4 | observability components that work across all SCM providers. 5 | 6 | Key Components (to be implemented): 7 | - TelemetryCollector: Main telemetry collection and processing 8 | - CostAnalyzer: LLM cost tracking and analysis 9 | - WorkflowTracker: Workflow execution tracking and metrics 10 | - PerformanceMonitor: System performance monitoring 11 | - ObservabilityClient: Integration with external monitoring systems 12 | """ 13 | 14 | __all__ = [ 15 | # Telemetry components will be exported as they are implemented 16 | ] 17 | -------------------------------------------------------------------------------- /src/smartfix/config/__init__.py: -------------------------------------------------------------------------------- 1 | """Configuration Management System 2 | 3 | This package contains configuration management components including 4 | base configuration classes, domain-specific configurations, and 5 | dependency injection functionality. 
6 | 7 | Key Components (to be implemented): 8 | - SmartFixConfig: Main configuration aggregate root 9 | - BaseConfig: Abstract base for configuration classes 10 | - AgentConfig: AI agent configuration and settings 11 | - BuildConfig: Build system configuration 12 | - ScmConfig: Source control management configuration 13 | - TelemetryConfig: Telemetry and observability configuration 14 | """ 15 | 16 | __all__ = [ 17 | # Configuration components will be exported as they are implemented 18 | ] 19 | -------------------------------------------------------------------------------- /src/smartfix/domains/__init__.py: -------------------------------------------------------------------------------- 1 | """SmartFix Domain-Driven Design Modules 2 | 3 | This package contains the core domain modules organized according to 4 | Domain-Driven Design principles. Each domain represents a distinct 5 | business capability within the SmartFix system. 6 | 7 | Domains: 8 | - vulnerability: Vulnerability data models and processing logic 9 | - agents: AI agent orchestration and coding strategies 10 | - analysis: Code analysis, build execution, and result processing 11 | - scm: Source control management abstractions and providers 12 | - integrations: External service integrations (APIs, notifications) 13 | - workflow: Orchestration and workflow execution engine 14 | """ 15 | 16 | __all__ = [ 17 | # Domain modules will be exported as they are implemented 18 | ] 19 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for contrast-ai-smartfix-action 2 | # 3 | # Available targets: 4 | # make test - Run all tests using the test runner 5 | # make lint - Run linting via pre-push hook 6 | # make help - Show this help message 7 | 8 | .PHONY: test lint help 9 | 10 | # Run tests using the test runner script 11 | test: 12 | @echo "Running tests..." 
13 | ./test/run_tests.sh 14 | 15 | # Run linting via the pre-push hook 16 | lint: 17 | @echo "Running linter..." 18 | ./.git/hooks/pre-push 19 | 20 | # Show help message 21 | help: 22 | @echo "Available make targets:" 23 | @echo " test - Run all tests using ./test/run_tests.sh" 24 | @echo " lint - Run linting via ./.git/hooks/pre-push" 25 | @echo " help - Show this help message" 26 | 27 | # Default target 28 | .DEFAULT_GOAL := help -------------------------------------------------------------------------------- /test/setup_test_env.py: -------------------------------------------------------------------------------- 1 | """Test environment helper utilities. 2 | 3 | This module provides helper functions for test setup, including 4 | temporary directory management. 5 | """ 6 | 7 | from pathlib import Path 8 | 9 | 10 | def create_temp_repo_dir(): 11 | """ 12 | Create a temporary directory for repository testing. 13 | 14 | Returns: 15 | pathlib.Path: Path to temporary directory 16 | """ 17 | import tempfile 18 | return Path(tempfile.mkdtemp()) 19 | 20 | 21 | def cleanup_temp_dir(temp_dir): 22 | """ 23 | Clean up temporary directory. 24 | 25 | Args: 26 | temp_dir: Path to temporary directory to clean up 27 | """ 28 | import shutil 29 | if temp_dir and temp_dir.exists(): 30 | shutil.rmtree(temp_dir, ignore_errors=True) 31 | -------------------------------------------------------------------------------- /src/smartfix/domains/agents/__init__.py: -------------------------------------------------------------------------------- 1 | """AI Agent Orchestration Domain 2 | 3 | This domain handles AI agent management, coding strategies, and 4 | agent coordination for vulnerability remediation. 
5 | 6 | Key Components: 7 | - SmartFixAgent: Main agent interface and implementation 8 | - AgentSession: Stateful agent interaction management 9 | - CodingAgentStrategy: Strategy pattern for different coding agents 10 | - AgentFactory: Factory for creating and configuring agents 11 | """ 12 | 13 | from .coding_agent import CodingAgentStrategy 14 | from src.smartfix.shared.coding_agents import CodingAgents 15 | from .smartfix_agent import SmartFixAgent 16 | from .agent_factory import AgentFactory 17 | from .agent_session import AgentSession 18 | 19 | __all__ = [ 20 | 'CodingAgentStrategy', 21 | 'CodingAgents', 22 | 'SmartFixAgent', 23 | 'AgentFactory', 24 | 'AgentSession', 25 | ] 26 | -------------------------------------------------------------------------------- /src/github/__init__.py: -------------------------------------------------------------------------------- 1 | """GitHub Provider Implementation 2 | 3 | This package contains GitHub-specific implementations for the SmartFix system, 4 | including SCM provider implementations, API clients, and GitHub Action integrations. 
5 | 6 | Key Components: 7 | - GitHubScmProvider: GitHub implementation of the ScmProvider interface 8 | - GitHubApiClient: GitHub API integration and operations 9 | - ExternalCodingAgent: GitHub Copilot integration (moved from src/) 10 | """ 11 | 12 | # Import classes for easy access 13 | try: 14 | from .external_coding_agent import ExternalCodingAgent # noqa: F401 15 | 16 | __all__ = [ 17 | "ExternalCodingAgent", 18 | ] 19 | except ImportError: 20 | # During development, dependencies may not be available 21 | __all__ = [] 22 | 23 | # TODO: Add other GitHub components as they are implemented: 24 | # - GitHubScmProvider 25 | # - GitHubApiClient 26 | -------------------------------------------------------------------------------- /security.md: -------------------------------------------------------------------------------- 1 | # Reporting Security Issues 2 | 3 | Contrast takes security vulnerabilities seriously. We appreciate your efforts to responsibly disclose your findings, and will make every effort to acknowledge your contributions. 4 | 5 | To report a security issue, please see our official 6 | [Vulnerability Disclosure Policy](https://www.contrastsecurity.com/disclosure-policy). 7 | 8 | Contrast will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance. 9 | 10 | Report security bugs in third-party modules to the person or team maintaining the module. 11 | 12 | ## Learning More About Security 13 | 14 | To learn more about securing your applications with Contrast, please see [our docs](https://docs.contrastsecurity.com/?lang=en). 
15 | -------------------------------------------------------------------------------- /test/test_setup_test_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Tests for test environment setup helper.""" 3 | 4 | import unittest 5 | 6 | # Test setup imports (path is set up by conftest.py) 7 | from setup_test_env import ( 8 | create_temp_repo_dir, 9 | cleanup_temp_dir 10 | ) 11 | 12 | 13 | class TestSetupTestEnv(unittest.TestCase): 14 | """Test cases for test environment setup helper.""" 15 | 16 | def test_create_and_cleanup_temp_dir(self): 17 | """Test temporary directory creation and cleanup.""" 18 | # Create temp directory 19 | temp_dir = create_temp_repo_dir() 20 | 21 | # Should exist and be a directory 22 | self.assertTrue(temp_dir.exists()) 23 | self.assertTrue(temp_dir.is_dir()) 24 | 25 | # Clean up 26 | cleanup_temp_dir(temp_dir) 27 | 28 | # Should no longer exist 29 | self.assertFalse(temp_dir.exists()) 30 | 31 | 32 | if __name__ == '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /src/smartfix/shared/llm_providers.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security's commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. 
The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 17 | # #L% 18 | # 19 | 20 | from enum import Enum 21 | 22 | 23 | class LlmProvider(Enum): 24 | """Enumeration of LLM provider types.""" 25 | CONTRAST = "CONTRAST" 26 | BYOLLM = "BYOLLM" 27 | -------------------------------------------------------------------------------- /src/smartfix/shared/coding_agents.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security's commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 
17 | # #L% 18 | # 19 | 20 | from enum import Enum 21 | 22 | 23 | class CodingAgents(Enum): 24 | """Enumeration of available coding agent types.""" 25 | SMARTFIX = "SMARTFIX" 26 | GITHUB_COPILOT = "GITHUB_COPILOT" 27 | CLAUDE_CODE = "CLAUDE_CODE" 28 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Configuration for flake8 Python linter 3 | 4 | # Maximum line length 5 | max-line-length = 180 6 | 7 | # Exclude directories and files from linting 8 | exclude = 9 | .git, 10 | __pycache__, 11 | .venv, 12 | venv, 13 | env, 14 | .env, 15 | build, 16 | dist, 17 | *.egg-info 18 | 19 | # Ignore specific error codes 20 | ignore = 21 | # E203: whitespace before ':' (conflicts with black formatter) 22 | E203, 23 | # W503: line break before binary operator (conflicts with black formatter) 24 | W503, 25 | # E501: line too long (we set max-line-length instead) 26 | # E501 27 | 28 | # Select specific error codes to check (optional - if not specified, checks most things) 29 | select = 30 | # E: pycodestyle errors, W: pycodestyle warnings 31 | E, W, 32 | # F: pyflakes, C: mccabe complexity 33 | F, C, 34 | 35 | # Maximum complexity allowed 36 | max-complexity = 15 37 | 38 | # Show the source code for each error 39 | show-source = True 40 | 41 | # Count the number of occurrences of each error/warning code 42 | statistics = True 43 | 44 | # Enable showing the pep8 source for each error 45 | show-pep8 = True 46 | -------------------------------------------------------------------------------- /src/smartfix/extensions/__init__.py: -------------------------------------------------------------------------------- 1 | """SmartFix LLM Extensions 2 | 3 | This package contains enhanced LLM integrations and extensions 4 | that provide advanced functionality like prompt caching, cost tracking, 5 | and multi-provider support. 
6 | 7 | Key Components: 8 | - SmartFixLiteLlm: Enhanced LiteLLM with prompt caching and cost tracking 9 | - SmartFixLlmAgent: Enhanced LLM agent with advanced capabilities 10 | """ 11 | 12 | import warnings 13 | 14 | # Suppress specific Pydantic field shadowing warnings from ADK library 15 | # This suppresses warnings about SequentialAgent.config_type shadowing BaseAgent.config_type 16 | warnings.filterwarnings('ignore', message='.*config_type.*shadows.*', category=UserWarning) 17 | warnings.filterwarnings('ignore', message='.*shadows an attribute.*', category=UserWarning) 18 | 19 | # Import classes for easy access 20 | try: 21 | from src.smartfix.extensions.smartfix_litellm import SmartFixLiteLlm # noqa: F401 22 | from src.smartfix.extensions.smartfix_llm_agent import SmartFixLlmAgent # noqa: F401 23 | 24 | __all__ = [ 25 | "SmartFixLiteLlm", 26 | "SmartFixLlmAgent", 27 | ] 28 | except ImportError: 29 | # During development, dependencies may not be available 30 | __all__ = [] 31 | -------------------------------------------------------------------------------- /src/smartfix/domains/vulnerability/__init__.py: -------------------------------------------------------------------------------- 1 | """Vulnerability Processing Domain 2 | 3 | This domain handles vulnerability data models, remediation contexts, 4 | and vulnerability processing logic independent of specific SCM providers. 
5 | 6 | Key Components: 7 | - Vulnerability: Core vulnerability data model with severity and status management 8 | - RemediationContext: Context for vulnerability remediation with configuration and creation methods 9 | - VulnerabilitySeverity: Enumeration of vulnerability severity levels 10 | - PromptConfiguration: AI prompts configuration for remediation 11 | - BuildConfiguration: Build and testing configuration 12 | - RepositoryConfiguration: Repository and SCM configuration 13 | """ 14 | 15 | # Import core domain classes 16 | from .models import ( 17 | Vulnerability, 18 | VulnerabilitySeverity 19 | ) 20 | 21 | from .context import ( 22 | RemediationContext, 23 | PromptConfiguration, 24 | BuildConfiguration, 25 | RepositoryConfiguration 26 | ) 27 | 28 | __all__ = [ 29 | # Core vulnerability models 30 | "Vulnerability", 31 | "VulnerabilitySeverity", 32 | 33 | # Remediation context and configuration 34 | "RemediationContext", 35 | "PromptConfiguration", 36 | "BuildConfiguration", 37 | "RepositoryConfiguration", 38 | ] 39 | -------------------------------------------------------------------------------- /src/smartfix/shared/failure_categories.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security's commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. 
The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 17 | # #L% 18 | # 19 | 20 | from enum import Enum 21 | 22 | 23 | class FailureCategory(Enum): 24 | """Define failure categories as an enum to ensure consistency.""" 25 | INITIAL_BUILD_FAILURE = "INITIAL_BUILD_FAILURE" 26 | EXCEEDED_QA_ATTEMPTS = "EXCEEDED_QA_ATTEMPTS" 27 | QA_AGENT_FAILURE = "QA_AGENT_FAILURE" 28 | GIT_COMMAND_FAILURE = "GIT_COMMAND_FAILURE" 29 | AGENT_FAILURE = "AGENT_FAILURE" 30 | GENERATE_PR_FAILURE = "GENERATE_PR_FAILURE" 31 | GENERAL_FAILURE = "GENERAL_FAILURE" 32 | EXCEEDED_TIMEOUT = "EXCEEDED_TIMEOUT" 33 | EXCEEDED_AGENT_EVENTS = "EXCEEDED_AGENT_EVENTS" 34 | INVALID_LLM_CONFIG = "INVALID_LLM_CONFIG" 35 | -------------------------------------------------------------------------------- /src/github/github_scm_provider.py: -------------------------------------------------------------------------------- 1 | """GitHub SCM Provider Implementation 2 | 3 | This module provides the GitHub-specific implementation of the ScmProvider interface, 4 | including GitHub API operations, repository management, and GitHub Action integration. 5 | """ 6 | 7 | # TODO: Import will be added when ScmProvider interface is implemented 8 | # from src.smartfix.domains.scm.provider import ScmProvider 9 | 10 | 11 | class GitHubScmProvider: 12 | """ 13 | GitHub implementation of the ScmProvider interface. 14 | 15 | This class provides GitHub-specific implementations for all SCM operations 16 | including repository management, branch operations, pull request handling, 17 | and GitHub API integration. 
18 | 19 | Key Responsibilities: 20 | - GitHub repository operations (clone, branch, commit) 21 | - Pull request creation and management 22 | - GitHub API authentication and rate limiting 23 | - GitHub-specific error handling and recovery 24 | - Integration with GitHub Actions environment 25 | 26 | TODO: Implement ScmProvider interface when available 27 | """ 28 | 29 | def __init__(self): 30 | """Initialize GitHub SCM provider with configuration and API client.""" 31 | # TODO: Implementation will be added in Task 4.2.1 32 | pass 33 | 34 | # TODO: Implement all ScmProvider interface methods: 35 | # - create_repository() 36 | # - create_branch() 37 | # - commit_changes() 38 | # - create_pull_request() 39 | # - etc. 40 | -------------------------------------------------------------------------------- /src/smartfix/domains/providers/__init__.py: -------------------------------------------------------------------------------- 1 | import litellm 2 | import os 3 | from src.config import get_config 4 | from src.utils import normalize_host 5 | 6 | config = get_config() 7 | 8 | # Contrast LLM model constants 9 | CONTRAST_CLAUDE_SONNET_4_5 = "contrast/claude-sonnet-4-5" 10 | 11 | 12 | def setup_contrast_provider(): 13 | """Setup Contrast Bedrock proxy as a custom provider.""" 14 | 15 | # Register the model with litellm 16 | litellm.register_model({ 17 | CONTRAST_CLAUDE_SONNET_4_5: { 18 | # Model capabilities 19 | "max_tokens": 8192, 20 | "max_input_tokens": 200000, 21 | "max_output_tokens": 8192, 22 | 23 | # Pricing (per token) 24 | "input_cost_per_token": 0.000003, 25 | "output_cost_per_token": 0.000015, 26 | 27 | # Prompt caching pricing 28 | "cache_creation_input_token_cost": 0.00000375, 29 | "cache_read_input_token_cost": 0.0000003, 30 | 31 | # Provider configuration 32 | "litellm_provider": "anthropic", # Use Anthropic provider 33 | "mode": "chat", 34 | 35 | # Feature support 36 | "supports_function_calling": True, 37 | "supports_vision": True, 38 | 
"supports_prompt_caching": True, 39 | } 40 | }) 41 | 42 | # Configure to use Contrast proxy (still uses Anthropic API format) 43 | os.environ["ANTHROPIC_API_BASE"] = f"https://{normalize_host(config.CONTRAST_HOST)}/api/v4/llm-proxy/organizations/{config.CONTRAST_ORG_ID}" 44 | os.environ["ANTHROPIC_API_KEY"] = f"{config.CONTRAST_API_KEY}" 45 | -------------------------------------------------------------------------------- /src/github/github_api_client.py: -------------------------------------------------------------------------------- 1 | """GitHub API Client Implementation 2 | 3 | This module provides comprehensive GitHub API integration including authentication, 4 | API operations, rate limiting, and error handling for SmartFix operations. 5 | """ 6 | 7 | 8 | class GitHubApiClient: 9 | """ 10 | GitHub API client for SmartFix operations. 11 | 12 | This class handles all GitHub API interactions including authentication, 13 | rate limiting, error handling, and provides high-level operations for 14 | SmartFix functionality. 15 | 16 | Key Responsibilities: 17 | - GitHub API authentication (tokens, GitHub Actions) 18 | - Repository operations (create, clone, delete) 19 | - Pull request operations (create, update, merge, close) 20 | - Issue operations (create, update, assign, label) 21 | - Branch operations (create, delete, compare) 22 | - API rate limiting and optimization 23 | - Comprehensive error handling and retry logic 24 | 25 | Features: 26 | - Support for both GitHub.com and GitHub Enterprise 27 | - Automatic rate limit handling 28 | - Request/response logging and debugging 29 | - Connection pooling and caching 30 | """ 31 | 32 | def __init__(self, token: str = None, base_url: str = "https://api.github.com"): 33 | """ 34 | Initialize GitHub API client. 
35 | 36 | Args: 37 | token: GitHub authentication token 38 | base_url: GitHub API base URL (for Enterprise support) 39 | """ 40 | # TODO: Implementation will be added in Task 4.2.3 41 | self.token = token 42 | self.base_url = base_url 43 | 44 | # TODO: Implement API methods: 45 | # - authenticate() 46 | # - create_repository() 47 | # - create_pull_request() 48 | # - create_issue() 49 | # - add_labels() 50 | # - assign_reviewers() 51 | # - etc. 52 | -------------------------------------------------------------------------------- /src/smartfix/domains/agents/coding_agent.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security’s commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 17 | # #L% 18 | # 19 | 20 | from abc import ABC, abstractmethod 21 | 22 | from src.smartfix.domains.vulnerability import RemediationContext 23 | from .agent_session import AgentSession 24 | 25 | 26 | class CodingAgentStrategy(ABC): 27 | """ 28 | Abstract base class defining the interface for a coding agent. 29 | This allows for a "plug-and-play" architecture where different agents 30 | (e.g., SmartFix internal, GitHub Copilot) can be used interchangeably. 
31 | """ 32 | 33 | @abstractmethod 34 | def remediate(self, context: RemediationContext) -> AgentSession: 35 | """ 36 | The primary entry point for an agent to attempt a remediation. 37 | 38 | Args: 39 | context: The remediation context containing all necessary information 40 | about the vulnerability, repository, and configuration. 41 | 42 | Returns: 43 | An AgentSession object containing the complete remediation attempt, 44 | including success status, events, costs, and final PR content. 45 | """ 46 | pass 47 | -------------------------------------------------------------------------------- /src/smartfix/domains/agents/agent_factory.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security's commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 17 | # #L% 18 | # 19 | 20 | from .coding_agent import CodingAgentStrategy 21 | from src.smartfix.shared.coding_agents import CodingAgents 22 | from .smartfix_agent import SmartFixAgent 23 | 24 | 25 | class AgentFactory: 26 | """ 27 | Factory class for creating and configuring coding agent instances. 28 | 29 | Provides centralized agent creation with strategy selection logic. 30 | Configuration is passed through RemediationContext, not at agent creation. 
31 | """ 32 | 33 | @staticmethod 34 | def create_agent(agent_type: CodingAgents) -> CodingAgentStrategy: 35 | """ 36 | Create a coding agent instance based on the specified type. 37 | 38 | Args: 39 | agent_type: The type of agent to create 40 | 41 | Returns: 42 | CodingAgentStrategy: Configured coding agent instance 43 | 44 | Raises: 45 | ValueError: If agent_type is not supported 46 | """ 47 | if agent_type == CodingAgents.SMARTFIX: 48 | return SmartFixAgent() 49 | else: 50 | raise ValueError(f"Domain factory only supports SMARTFIX agents. Got: {agent_type}") 51 | -------------------------------------------------------------------------------- /src/smartfix/domains/agents/agent_session.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security’s commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 17 | # #L% 18 | # 19 | 20 | from dataclasses import dataclass 21 | from typing import Optional 22 | 23 | from src.smartfix.shared.failure_categories import FailureCategory 24 | 25 | 26 | @dataclass 27 | class AgentSession: 28 | """ 29 | Tracks the state and history of a single, complete remediation attempt. 
30 | """ 31 | qa_attempts: int = 0 32 | final_pr_body: Optional[str] = None 33 | failure_category: Optional[FailureCategory] = None 34 | is_complete: bool = False 35 | 36 | def complete_session(self, failure_category: Optional[FailureCategory] = None, 37 | pr_body: Optional[str] = None) -> None: 38 | """Marks the session as complete with optional failure category.""" 39 | self.failure_category = failure_category 40 | self.final_pr_body = pr_body 41 | self.is_complete = True 42 | 43 | @property 44 | def success(self) -> bool: 45 | """Returns True if the session completed successfully.""" 46 | return self.is_complete and self.failure_category is None 47 | 48 | @property 49 | def pr_body(self) -> Optional[str]: 50 | """Returns the final PR body content.""" 51 | return self.final_pr_body 52 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - '*' # Run on all branches 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: '3.11' 17 | cache: 'pip' # Enable built-in pip caching 18 | 19 | # Install uv tool for dependency management 20 | - name: Install uv 21 | shell: bash 22 | run: | 23 | echo "" 24 | echo "📦 Installing uv package manager..." 
25 | if [ "$RUNNER_OS" = "Windows" ]; then 26 | pip install uv==0.7.15 > uv-install.log || (cat uv-install.log && exit 1) 27 | else 28 | pip install uv==0.7.15 > /tmp/uv-install.log || (cat /tmp/uv-install.log && exit 1) 29 | fi 30 | echo "✅ uv installed successfully" 31 | 32 | # Install dependencies using uv (Linux/macOS) 33 | - name: Install dependencies with uv 34 | if: runner.os == 'Linux' || runner.os == 'macOS' 35 | shell: bash 36 | run: | 37 | echo "" 38 | echo "📦 Installing Python packages using uv..." 39 | echo "📎 Using lockfile for deterministic installation" 40 | uv pip sync --system ${{ github.workspace }}/src/requirements.lock > /tmp/uv-output.log || (cat /tmp/uv-output.log && exit 1) 41 | echo "✅ Python dependencies installed successfully" 42 | 43 | - name: Run Python linting 44 | shell: bash 45 | run: | 46 | echo "" 47 | echo "🔍 Running Python linting using pre-push hook..." 48 | # Use the pre-push hook as single source of truth for linting 49 | ./hooks/pre-push 50 | 51 | - name: Run tests 52 | shell: bash 53 | run: | 54 | echo "" 55 | echo "🧪 Running tests using test runner script..." 56 | # Use the test script as single source of truth for testing 57 | ./test/run_tests.sh --skip-install 58 | env: 59 | GITHUB_TOKEN: ${{ github.token }} 60 | -------------------------------------------------------------------------------- /hooks/pre-commit: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Pre-commit hook to clean whitespace 4 | # This script will clean whitespace and include it in the commit 5 | 6 | set -e 7 | 8 | echo "🔍 Running pre-commit checks..." 
9 | 10 | # Colors for output 11 | RED='\033[0;31m' 12 | GREEN='\033[0;32m' 13 | YELLOW='\033[1;33m' 14 | BLUE='\033[0;34m' 15 | NC='\033[0m' # No Color 16 | 17 | # Function to print colored output 18 | print_error() { 19 | echo -e "${RED}❌ $1${NC}" 20 | } 21 | 22 | print_success() { 23 | echo -e "${GREEN}✅ $1${NC}" 24 | } 25 | 26 | print_warning() { 27 | echo -e "${YELLOW}⚠️ $1${NC}" 28 | } 29 | 30 | print_info() { 31 | echo -e "${BLUE}ℹ️ $1${NC}" 32 | } 33 | 34 | # Check if we're in the right directory 35 | if [ ! -f "action.yml" ] || [ ! -d "src" ]; then 36 | print_error "This doesn't appear to be the contrast-ai-smartfix-action repository root" 37 | exit 1 38 | fi 39 | 40 | # Clean trailing whitespace from staged files 41 | echo "" 42 | print_info "Cleaning trailing whitespace from staged files..." 43 | 44 | # Get list of staged files 45 | STAGED_FILES=$(git diff --cached --name-only --diff-filter=ACM) 46 | 47 | if [ -n "$STAGED_FILES" ]; then 48 | echo "📁 Processing staged files:" 49 | echo "$STAGED_FILES" | sed 's/^/ /' 50 | 51 | WHITESPACE_CLEANED=false 52 | 53 | # Apply sed command to each staged file 54 | for file in $STAGED_FILES; do 55 | if [ -f "$file" ]; then 56 | echo " 🧹 Cleaning whitespace in: $file" 57 | # Clean trailing whitespace 58 | sed -i '' 's/[[:space:]]*$//' "$file" 59 | # Re-stage the file with cleaned whitespace 60 | git add "$file" 61 | WHITESPACE_CLEANED=true 62 | fi 63 | done 64 | 65 | if [ "$WHITESPACE_CLEANED" = true ]; then 66 | print_success "Whitespace cleanup completed and re-staged!" 67 | else 68 | print_info "No files needed whitespace cleanup" 69 | fi 70 | else 71 | print_info "No staged files detected for whitespace cleanup" 72 | fi 73 | 74 | echo "" 75 | print_success "All pre-commit checks completed successfully!" 76 | echo "✨ Commit proceeding..." 
77 | -------------------------------------------------------------------------------- /hooks/pre-push: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Pre-push hook to run Python linting 4 | # This script will block the push if linting fails 5 | 6 | set -e 7 | 8 | echo "🔍 Running pre-push checks..." 9 | 10 | # Colors for output 11 | RED='\033[0;31m' 12 | GREEN='\033[0;32m' 13 | YELLOW='\033[1;33m' 14 | BLUE='\033[0;34m' 15 | NC='\033[0m' # No Color 16 | 17 | # Function to print colored output 18 | print_error() { 19 | echo -e "${RED}❌ $1${NC}" 20 | } 21 | 22 | print_success() { 23 | echo -e "${GREEN}✅ $1${NC}" 24 | } 25 | 26 | print_warning() { 27 | echo -e "${YELLOW}⚠️ $1${NC}" 28 | } 29 | 30 | print_info() { 31 | echo -e "${BLUE}ℹ️ $1${NC}" 32 | } 33 | 34 | # Check if we're in the right directory 35 | if [ ! -f "action.yml" ] || [ ! -d "src" ]; then 36 | print_error "This doesn't appear to be the contrast-ai-smartfix-action repository root" 37 | exit 1 38 | fi 39 | 40 | # Run Python linting 41 | echo "" 42 | print_info "Running Python linter..." 43 | 44 | # Check if flake8 is installed, if not try to install it 45 | if ! command -v flake8 &> /dev/null; then 46 | print_warning "flake8 not found. Attempting to install..." 47 | if command -v pip &> /dev/null; then 48 | pip install flake8 49 | elif command -v pip3 &> /dev/null; then 50 | pip3 install flake8 51 | else 52 | print_error "Neither pip nor pip3 found. 
Please install flake8 manually:" 53 | print_error "pip install flake8" 54 | exit 1 55 | fi 56 | fi 57 | 58 | # Find all Python files in src/ and test/ directories 59 | PYTHON_FILES=$(find src/ test/ -name "*.py" 2>/dev/null || true) 60 | 61 | if [ -z "$PYTHON_FILES" ]; then 62 | print_warning "No Python files found in src/ or test/ directories" 63 | exit 0 64 | fi 65 | 66 | echo "📁 Found Python files:" 67 | echo "$PYTHON_FILES" | sed 's/^/ /' 68 | 69 | # Run flake8 on the Python files 70 | echo "" 71 | echo "🧹 Running flake8 linter..." 72 | 73 | if flake8 $PYTHON_FILES; then 74 | print_success "All Python files passed linting!" 75 | echo "" 76 | print_success "Pre-push linting check completed successfully!" 77 | echo "🚀 Push proceeding..." 78 | exit 0 79 | else 80 | echo "" 81 | print_error "Linting failed! Push blocked." 82 | echo "" 83 | echo "💡 To fix linting issues:" 84 | echo " 1. Review the errors above" 85 | echo " 2. Fix the issues in your code" 86 | echo " 3. Commit your fixes" 87 | echo " 4. Try pushing again" 88 | echo "" 89 | echo "🔧 To skip all pre-push checks (not recommended):" 90 | echo " git push --no-verify" 91 | echo "" 92 | exit 1 93 | fi 94 | -------------------------------------------------------------------------------- /src/github/agent_factory.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security's commercial offerings. 
Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 17 | # #L% 18 | # 19 | 20 | from typing import Optional 21 | from src.config import Config 22 | from src.smartfix.domains.agents.coding_agent import CodingAgentStrategy 23 | from src.smartfix.shared.coding_agents import CodingAgents 24 | from src.smartfix.domains.agents.agent_factory import AgentFactory 25 | from .external_coding_agent import ExternalCodingAgent 26 | 27 | 28 | class GitHubAgentFactory(AgentFactory): 29 | """ 30 | GitHub-specific agent factory that extends the domain factory capabilities. 31 | 32 | Provides creation of both domain agents (SMARTFIX) and GitHub-specific 33 | external agents (GITHUB_COPILOT, CLAUDE_CODE). 34 | """ 35 | 36 | @staticmethod 37 | def create_agent(agent_type: CodingAgents, config: Optional[Config] = None) -> CodingAgentStrategy: 38 | """ 39 | Create a coding agent instance based on the specified type. 40 | 41 | Supports both domain agents and GitHub-specific external agents. 
42 | 43 | Args: 44 | agent_type: The type of agent to create 45 | config: The application configuration object (required for external agents, not for SMARTFIX) 46 | 47 | Returns: 48 | CodingAgentStrategy: Configured coding agent instance 49 | 50 | Raises: 51 | ValueError: If agent_type is not supported or config is missing for external agents 52 | """ 53 | # Delegate domain agents to the parent factory (no config needed) 54 | if agent_type == CodingAgents.SMARTFIX: 55 | return AgentFactory.create_agent(agent_type) 56 | 57 | # Handle GitHub-specific external agents (config required) 58 | elif agent_type in (CodingAgents.GITHUB_COPILOT, CodingAgents.CLAUDE_CODE): 59 | if config is None: 60 | raise ValueError(f"Config is required for external agent type: {agent_type}") 61 | return ExternalCodingAgent(config) 62 | 63 | else: 64 | raise ValueError(f"Unsupported agent type: {agent_type}") 65 | -------------------------------------------------------------------------------- /setup-hooks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Setup script for installing git hooks 4 | # Run this once after cloning the repository to enable automatic linting 5 | 6 | set -e 7 | 8 | echo "🔧 Setting up git hooks for SmartFix development..." 9 | 10 | # Colors for output 11 | RED='\033[0;31m' 12 | GREEN='\033[0;32m' 13 | YELLOW='\033[1;33m' 14 | BLUE='\033[0;34m' 15 | NC='\033[0m' # No Color 16 | 17 | # Function to print colored output 18 | print_error() { 19 | echo -e "${RED}❌ $1${NC}" 20 | } 21 | 22 | print_success() { 23 | echo -e "${GREEN}✅ $1${NC}" 24 | } 25 | 26 | print_warning() { 27 | echo -e "${YELLOW}⚠️ $1${NC}" 28 | } 29 | 30 | print_info() { 31 | echo -e "${BLUE}ℹ️ $1${NC}" 32 | } 33 | 34 | # Check if we're in the right directory 35 | if [ ! -f "action.yml" ] || [ ! -d "src" ] || [ ! 
-d "hooks" ]; then 36 | print_error "This doesn't appear to be the contrast-ai-smartfix-action repository root" 37 | print_error "Please run this script from the repository root directory" 38 | exit 1 39 | fi 40 | 41 | # Check if .git directory exists 42 | if [ ! -d ".git" ]; then 43 | print_error "No .git directory found. Are you in a git repository?" 44 | exit 1 45 | fi 46 | 47 | echo "" 48 | print_info "Installing git hooks..." 49 | 50 | # Install pre-commit hook 51 | if [ -f "hooks/pre-commit" ]; then 52 | ln -sf ../../hooks/pre-commit .git/hooks/pre-commit 53 | chmod +x .git/hooks/pre-commit 54 | print_success "Pre-commit hook installed (whitespace cleanup)" 55 | else 56 | print_warning "hooks/pre-commit not found, skipping" 57 | fi 58 | 59 | # Install pre-push hook 60 | if [ -f "hooks/pre-push" ]; then 61 | ln -sf ../../hooks/pre-push .git/hooks/pre-push 62 | chmod +x .git/hooks/pre-push 63 | print_success "Pre-push hook installed (Python linting)" 64 | else 65 | print_warning "hooks/pre-push not found, skipping" 66 | fi 67 | 68 | echo "" 69 | print_info "Checking Python linting dependencies..." 70 | 71 | # Check if flake8 is available 72 | if command -v flake8 &> /dev/null; then 73 | print_success "flake8 is already installed" 74 | else 75 | print_warning "flake8 not found. Installing..." 76 | if command -v pip &> /dev/null; then 77 | pip install flake8 78 | print_success "flake8 installed successfully" 79 | elif command -v pip3 &> /dev/null; then 80 | pip3 install flake8 81 | print_success "flake8 installed successfully" 82 | else 83 | print_error "Neither pip nor pip3 found. Please install flake8 manually:" 84 | print_error "pip install flake8" 85 | exit 1 86 | fi 87 | fi 88 | 89 | echo "" 90 | print_success "Git hooks setup completed successfully!" 
91 | echo "" 92 | print_info "What happens now:" 93 | echo " 🧹 Pre-commit: Automatically cleans trailing whitespace" 94 | echo " 🔍 Pre-push: Runs Python linting before pushing" 95 | echo "" 96 | print_info "To bypass hooks temporarily (not recommended):" 97 | echo " git commit --no-verify" 98 | echo " git push --no-verify" 99 | echo "" 100 | print_success "Happy coding! 🚀" -------------------------------------------------------------------------------- /src/smartfix/domains/workflow/formatter.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security's commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 17 | # #L% 18 | # 19 | 20 | """ 21 | Formatter Module 22 | 23 | Handles code formatting operations for vulnerability remediation. 24 | """ 25 | 26 | import os 27 | from pathlib import Path 28 | from typing import Optional, List 29 | 30 | from src.utils import debug_log, log, error_exit, run_command 31 | 32 | 33 | def run_formatting_command(formatting_command: Optional[str], repo_root: Path, remediation_id: str) -> List[str]: 34 | """ 35 | Runs the formatting command if provided. 36 | 37 | Args: 38 | formatting_command: The formatting command to run (or None). 39 | repo_root: The repository root path. 
40 | remediation_id: Remediation ID for error tracking. 41 | 42 | Returns: 43 | List[str]: List of files changed by the formatting command, empty list if none or no command. 44 | """ 45 | changed_files = [] 46 | if not formatting_command: 47 | return changed_files 48 | 49 | log(f"\n--- Running Formatting Command: {formatting_command} ---") 50 | # Modified to match run_command signature which returns only stdout 51 | current_dir = os.getcwd() 52 | try: 53 | os.chdir(str(repo_root)) # Change to repo root directory 54 | try: 55 | format_output = run_command( 56 | formatting_command.split(), # Split string into list for Popen 57 | check=False # Don't exit on failure, we'll check status 58 | ) 59 | format_success = True # If no exception was raised, consider it successful 60 | except Exception as e: 61 | format_success = False 62 | format_output = str(e) 63 | finally: 64 | os.chdir(current_dir) # Change back to original directory 65 | 66 | if format_success: 67 | debug_log("Formatting command successful.") 68 | # NOTE: Git operations are handled by main.py after all agent work completes 69 | # We just track which files were changed by the formatter 70 | # The formatter modifies files in place, so we don't need to commit here 71 | log("Formatting command completed.") 72 | else: 73 | log(f"::error::Error executing formatting command: {formatting_command}") 74 | log(f"::error::Error details: {format_output}", is_error=True) 75 | error_exit(remediation_id) 76 | 77 | return changed_files 78 | -------------------------------------------------------------------------------- /src/smartfix/domains/workflow/build_runner.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 
6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security's commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 17 | # #L% 18 | # 19 | 20 | """ 21 | Build Runner Module 22 | 23 | Handles execution of build commands and test commands for vulnerability remediation. 24 | """ 25 | 26 | import subprocess 27 | from pathlib import Path 28 | from typing import Tuple 29 | 30 | from src.utils import debug_log, log, error_exit 31 | import src.telemetry_handler as telemetry_handler 32 | 33 | 34 | def run_build_command(command: str, repo_root: Path, remediation_id: str) -> Tuple[bool, str]: 35 | """ 36 | Runs the specified build command in the repository root. 37 | 38 | Args: 39 | command: The build command string (e.g., "mvn clean install"). 40 | repo_root: The Path object representing the repository root directory. 41 | remediation_id: Remediation ID for error tracking. 42 | 43 | Returns: 44 | A tuple containing: 45 | - bool: True if the command succeeded (exit code 0), False otherwise. 46 | - str: The combined stdout and stderr output of the command. 47 | """ 48 | log(f"\n--- Running Build Command: {command} ---") 49 | try: 50 | # Use shell=True if the command might contain shell operators like &&, ||, > etc. 51 | # Be cautious with shell=True if the command comes from untrusted input. 52 | # Here, it's from an environment variable, assumed to be controlled. 
53 | result = subprocess.run( 54 | command, 55 | cwd=repo_root, 56 | shell=True, 57 | check=False, # Don't raise exception on non-zero exit 58 | capture_output=True, 59 | text=True, 60 | encoding='utf-8', # Explicitly set encoding 61 | errors='replace' # Handle potential encoding errors in output 62 | ) 63 | telemetry_handler.update_telemetry("configInfo.buildCommandRunTestsIncluded", True) 64 | output = result.stdout + result.stderr 65 | if result.returncode == 0: 66 | log("Build command succeeded.") 67 | return True, output 68 | else: 69 | debug_log(f"Build command failed with exit code {result.returncode}.") 70 | 71 | return False, output 72 | except FileNotFoundError: 73 | log(f"Error: Build command '{command}' not found. Is it installed and in PATH?", is_error=True) 74 | error_exit(remediation_id) 75 | except Exception as e: 76 | log(f"An unexpected error occurred while running the build command: {e}", is_error=True) 77 | error_exit(remediation_id) 78 | -------------------------------------------------------------------------------- /test/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # run_tests.sh - Install dependencies with UV and run tests 4 | # 5 | # Usage: 6 | # ./run_tests.sh [--skip-install] [test_files...] 7 | # 8 | # Examples: 9 | # ./run_tests.sh # Install deps and run all tests 10 | # ./run_tests.sh test_main.py # Install deps and run specific test 11 | # ./run_tests.sh --skip-install # Skip installation, run all tests 12 | # 13 | 14 | set -e # Exit on error 15 | 16 | # Get the project root directory 17 | PROJECT_ROOT=$(cd "$(dirname "$0")/.." 
&& pwd) 18 | REQUIREMENTS_LOCK="$PROJECT_ROOT/src/requirements.lock" 19 | SKIP_INSTALL=0 20 | 21 | # Process arguments 22 | TEST_FILES=() 23 | for arg in "$@"; do 24 | if [[ "$arg" == "--skip-install" ]]; then 25 | SKIP_INSTALL=1 26 | else 27 | TEST_FILES+=("$arg") 28 | fi 29 | done 30 | 31 | # Change to project root for proper imports 32 | cd "$PROJECT_ROOT" 33 | 34 | # Install dependencies if not skipped 35 | if [[ $SKIP_INSTALL -eq 0 ]]; then 36 | if [[ ! -f "$REQUIREMENTS_LOCK" ]]; then 37 | echo "Error: Requirements lock file not found at $REQUIREMENTS_LOCK" >&2 38 | exit 1 39 | fi 40 | 41 | echo "Installing dependencies from $REQUIREMENTS_LOCK..." 42 | 43 | # Check if UV is installed 44 | if ! command -v uv &> /dev/null; then 45 | echo "Error: UV is not installed. Please install it first:" 46 | echo " pip install uv" 47 | echo "or" 48 | echo " curl -sSf https://install.uv.dev | python3 -" 49 | exit 1 50 | fi 51 | 52 | # Create virtual environment if it doesn't exist 53 | VENV_DIR="$PROJECT_ROOT/.venv" 54 | if [[ ! -d "$VENV_DIR" ]]; then 55 | echo "Creating virtual environment..." 56 | if ! uv venv "$VENV_DIR"; then 57 | echo "Error creating virtual environment" >&2 58 | exit 1 59 | fi 60 | fi 61 | 62 | # Install dependencies in virtual environment 63 | if ! 
uv pip install -r "$REQUIREMENTS_LOCK"; then 64 | echo "Error installing dependencies" >&2 65 | exit 1 66 | fi 67 | fi 68 | 69 | # Set essential environment variables before running tests 70 | export BASE_BRANCH="main" 71 | export CONTRAST_HOST="test.contrastsecurity.com" 72 | export CONTRAST_ORG_ID="test-org-id" 73 | export CONTRAST_APP_ID="test-app-id" 74 | export CONTRAST_AUTHORIZATION_KEY="test-auth-key" 75 | export CONTRAST_API_KEY="test-api-key" 76 | export GITHUB_TOKEN="mock-github-token" 77 | export GITHUB_REPOSITORY="mock/repo" 78 | export GITHUB_SERVER_URL="https://mockhub.com" 79 | export GITHUB_EVENT_PATH="/tmp/github_event.json" 80 | export GITHUB_WORKSPACE="/tmp" 81 | export REPO_ROOT="/tmp/test_repo" 82 | export BUILD_COMMAND="echo 'test build command'" 83 | export FORMATTING_COMMAND="echo 'test format command'" 84 | export DEBUG_MODE="true" 85 | export TESTING="true" 86 | export ENABLE_ANTHROPIC_PROMPT_CACHING="true" 87 | 88 | # Run tests 89 | VENV_DIR="$PROJECT_ROOT/.venv" 90 | PYTHON_CMD="python" 91 | if [[ -d "$VENV_DIR" ]]; then 92 | PYTHON_CMD="$VENV_DIR/bin/python" 93 | fi 94 | 95 | if [[ ${#TEST_FILES[@]} -eq 0 ]]; then 96 | echo "Running all tests..." 97 | "$PYTHON_CMD" -m unittest discover -s test 98 | else 99 | echo "Running specific tests: ${TEST_FILES[*]}" 100 | "$PYTHON_CMD" -m unittest "${TEST_FILES[@]}" 101 | fi 102 | -------------------------------------------------------------------------------- /src/smartfix/__init__.py: -------------------------------------------------------------------------------- 1 | """Contrast AI SmartFix Library 2 | 3 | This library provides reusable SmartFix functionality for vulnerability remediation 4 | that is independent of specific SCM providers and deployment environments. 
5 | 6 | The library is organized into domain-driven modules: 7 | - domains.vulnerability: Vulnerability processing and remediation logic 8 | - domains.agents: AI agent orchestration and management 9 | - domains.analysis: Code analysis and build management 10 | - domains.scm: Source control management abstractions 11 | - domains.integrations: External service integrations 12 | - domains.workflow: Workflow orchestration and execution 13 | - extensions: Enhanced LLM integrations and extensions 14 | - config: Configuration management and dependency injection 15 | - telemetry: Telemetry collection and cost tracking 16 | - shared: Shared utilities and common functionality 17 | 18 | Example usage: 19 | ```python 20 | from smartfix import SmartFixAgent, SmartFixConfig 21 | from smartfix.extensions import SmartFixLiteLlm, SmartFixLlmAgent 22 | 23 | # Create configuration 24 | config = SmartFixConfig( 25 | llm_provider="anthropic", 26 | model="claude-3-sonnet-20240229" 27 | ) 28 | 29 | # Create agent 30 | agent = SmartFixAgent(config=config) 31 | 32 | # Process vulnerability 33 | result = agent.process_vulnerability( 34 | repo_path="/path/to/repo", 35 | vulnerability_data=vuln_data, 36 | prompts=prompts 37 | ) 38 | ``` 39 | 40 | For testing harness integration: 41 | ```python 42 | # Import the core agent for batch processing 43 | from smartfix import SmartFixAgent 44 | from smartfix.config import SmartFixConfig 45 | 46 | # Create agent in testing mode 47 | config = SmartFixConfig(testing_mode=True) 48 | agent = SmartFixAgent(config=config) 49 | 50 | # Batch process multiple vulnerabilities 51 | results = [] 52 | for vuln in vulnerabilities: 53 | result = agent.process_vulnerability( 54 | repo_path=test_repo_path, 55 | vulnerability_data=vuln, 56 | prompts=prompts 57 | ) 58 | results.append(result) 59 | ``` 60 | """ 61 | 62 | __version__ = "1.0.0-dev" 63 | __author__ = "Contrast Security" 64 | __description__ = "AI-powered vulnerability remediation library" 65 | 66 | # Core 
library exports - these will be available as components are implemented 67 | __all__ = [ 68 | # Metadata 69 | "__version__", 70 | "__author__", 71 | "__description__", 72 | 73 | # Core interfaces (to be implemented in later tasks) 74 | # "SmartFixAgent", # Main agent interface 75 | # "SmartFixConfig", # Configuration management 76 | # "CodingAgentStrategy", # Agent strategy interface 77 | # "RemediationWorkflow", # Workflow orchestration 78 | # "RemediationContext", # Vulnerability context 79 | # "BuildResult", # Build execution results 80 | # "ScmProvider", # SCM abstraction interface 81 | ] 82 | 83 | # Enhanced LLM extensions are already available 84 | try: 85 | from .extensions import SmartFixLiteLlm, SmartFixLlmAgent # noqa: F401 86 | __all__.extend(["SmartFixLiteLlm", "SmartFixLlmAgent"]) 87 | except ImportError: 88 | # Extensions may not be available in all environments 89 | pass 90 | 91 | # Future imports will be added as components are implemented: 92 | # from .domains.agents import SmartFixAgent, CodingAgentStrategy 93 | # from .domains.workflow import RemediationWorkflow 94 | # from .domains.vulnerability import RemediationContext 95 | # from .domains.analysis import BuildResult 96 | # from .domains.scm import ScmProvider 97 | # from .config import SmartFixConfig 98 | -------------------------------------------------------------------------------- /test/test_utils_normalize_host.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # - 3 | # #%L 4 | # Contrast AI SmartFix 5 | # %% 6 | # Copyright (C) 2025 Contrast Security, Inc. 7 | # %% 8 | # Contact: support@contrastsecurity.com 9 | # License: Commercial 10 | # NOTICE: This Software and the patented inventions embodied within may only be 11 | # used as part of Contrast Security's commercial offerings. 
Even though it is 12 | # made available through public repositories, use of this Software is subject to 13 | # the applicable End User Licensing Agreement found at 14 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 15 | # between Contrast Security and the End User. The Software may not be reverse 16 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 17 | # way not consistent with the End User License Agreement. 18 | # #L% 19 | # 20 | 21 | import unittest 22 | 23 | # Test setup imports (path is set up by conftest.py) 24 | from src.utils import normalize_host 25 | 26 | 27 | class TestNormalizeHost(unittest.TestCase): 28 | """Tests for the normalize_host function in utils.py""" 29 | 30 | def test_normalize_host_with_https(self): 31 | """Test that https:// prefix is removed""" 32 | input_host = "https://example.com" 33 | expected = "example.com" 34 | result = normalize_host(input_host) 35 | self.assertEqual(result, expected) 36 | 37 | def test_normalize_host_with_http(self): 38 | """Test that http:// prefix is removed""" 39 | input_host = "http://example.com" 40 | expected = "example.com" 41 | result = normalize_host(input_host) 42 | self.assertEqual(result, expected) 43 | 44 | def test_normalize_host_without_protocol(self): 45 | """Test that host without protocol is unchanged""" 46 | input_host = "example.com" 47 | expected = "example.com" 48 | result = normalize_host(input_host) 49 | self.assertEqual(result, expected) 50 | 51 | def test_normalize_host_with_port(self): 52 | """Test that host with port works correctly""" 53 | input_host = "https://example.com:8080" 54 | expected = "example.com:8080" 55 | result = normalize_host(input_host) 56 | self.assertEqual(result, expected) 57 | 58 | def test_normalize_host_with_path(self): 59 | """Test that host with path works correctly""" 60 | input_host = "https://example.com/api/v1" 61 | expected = "example.com/api/v1" 62 | result = normalize_host(input_host) 63 | 
self.assertEqual(result, expected) 64 | 65 | def test_normalize_host_with_subdomain(self): 66 | """Test that host with subdomain works correctly""" 67 | input_host = "https://api.example.com" 68 | expected = "api.example.com" 69 | result = normalize_host(input_host) 70 | self.assertEqual(result, expected) 71 | 72 | def test_normalize_host_multiple_protocols(self): 73 | """Test that multiple protocol prefixes are handled correctly""" 74 | input_host = "http://https://example.com" 75 | expected = "example.com" 76 | result = normalize_host(input_host) 77 | self.assertEqual(result, expected) 78 | 79 | def test_normalize_host_empty_string(self): 80 | """Test that empty string is handled correctly""" 81 | input_host = "" 82 | expected = "" 83 | result = normalize_host(input_host) 84 | self.assertEqual(result, expected) 85 | 86 | def test_normalize_host_protocol_in_middle(self): 87 | """Test that protocol in middle of string is also removed""" 88 | input_host = "example.com/https://path" 89 | expected = "example.com/path" 90 | result = normalize_host(input_host) 91 | self.assertEqual(result, expected) 92 | 93 | 94 | if __name__ == '__main__': 95 | unittest.main() 96 | -------------------------------------------------------------------------------- /src/build_output_analyzer.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security's commercial offerings. 
Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 17 | # #L% 18 | # 19 | 20 | def extract_build_errors(build_output): 21 | """ 22 | Extract the most relevant error information from build output. 23 | 24 | This function captures error blocks with context before and after errors, 25 | and merges consecutive errors whose context windows overlap or are separated by at most one line. 26 | 27 | Args: 28 | build_output: The complete output from the build command (must be a str; None raises TypeError) 29 | 30 | Returns: 31 | str: build_output unchanged if it is shorter than 2000 characters; otherwise up to the last 3 merged error regions found in the final 500 lines, or the final 50 lines when no error keyword matches 32 | """ 33 | # If output is small enough, just return it all 34 | if len(build_output) < 2000: 35 | return build_output 36 | 37 | lines = build_output.splitlines() 38 | 39 | # Only scan the last 500 lines (where errors typically appear); all indices below refer to tail_lines 40 | tail_lines = lines[-500:] if len(lines) > 500 else lines 41 | 42 | # Common error indicators across build systems; matched as case-insensitive substrings, so lines such as '0 errors' also count 43 | error_indicators = ["error", "exception", "failed", "failure", "fatal"] 44 | 45 | # Process the lines to find error regions with their context 46 | context_size = 5 # Lines of context kept before and after each error line 47 | error_regions = [] # Inclusive (start, end) index pairs into tail_lines 48 | 49 | # First pass: identify all error lines 50 | error_line_indices = [] 51 | 
for i, line in enumerate(tail_lines): 52 | line_lower = line.lower() 53 | if any(indicator in line_lower for indicator in error_indicators): 54 | error_line_indices.append(i) 55 | 56 | # Second pass: merge nearby errors into regions 57 | if error_line_indices: 58 | current_region_start = max(0, error_line_indices[0] -
context_size) 59 | current_region_end = error_line_indices[0] + context_size 60 | 61 | for idx in error_line_indices[1:]: 62 | # If this error is within or close to current region, extend the region 63 | if idx - context_size <= current_region_end + 2: # Merge even if one uncovered line separates the two windows 64 | current_region_end = idx + context_size 65 | else: 66 | # This error is far from the previous region, save current region 67 | # and start a new one 68 | error_regions.append((current_region_start, min(current_region_end, len(tail_lines) - 1))) 69 | current_region_start = max(0, idx - context_size) 70 | current_region_end = idx + context_size 71 | 72 | # Don't forget the last region (region ends are clamped to the last valid index) 73 | error_regions.append((current_region_start, min(current_region_end, len(tail_lines) - 1))) 74 | 75 | # Extract the text from each error region 76 | error_blocks = [] 77 | for start, end in error_regions: 78 | region_lines = tail_lines[start:end + 1] 79 | error_blocks.append("\n".join(region_lines)) 80 | 81 | # If we found error blocks, return them (up to 3 most recent) 82 | if error_blocks: 83 | result_blocks = error_blocks[-3:] if len(error_blocks) > 3 else error_blocks 84 | return "BUILD FAILURE - KEY ERRORS:\n\n" + "\n\n...\n\n".join(result_blocks) 85 | 86 | # Fallback (no error keyword matched): return the last 50 scanned lines 87 | return "BUILD FAILURE - LAST OUTPUT:\n\n" + "\n".join(tail_lines[-50:]) 88 | -------------------------------------------------------------------------------- /src/smartfix/domains/vulnerability/models.py: -------------------------------------------------------------------------------- 1 | """Vulnerability Data Models 2 | 3 | This module contains the core vulnerability data models that represent 4 | security vulnerabilities based on the actual API structure used by the system. 5 | 6 | Simplified to match the real-world usage in main.py and contrast_api.py.
7 | """ 8 | 9 | from dataclasses import dataclass 10 | from enum import Enum 11 | from typing import Optional, Dict, Any 12 | 13 | 14 | class VulnerabilitySeverity(Enum): 15 | """Enumeration of vulnerability severity levels.""" 16 | CRITICAL = "CRITICAL" 17 | HIGH = "HIGH" 18 | MEDIUM = "MEDIUM" 19 | LOW = "LOW" 20 | NOTE = "NOTE" 21 | 22 | 23 | @dataclass 24 | class Vulnerability: 25 | """ 26 | Domain model representing a security vulnerability. 27 | 28 | Simplified to match actual API data structure and system usage. 29 | Based on the actual fields received from Contrast API. 30 | 31 | Note: remediation_id is tracked in RemediationContext, not here. 32 | """ 33 | uuid: str 34 | title: str 35 | rule_name: str 36 | severity: VulnerabilitySeverity 37 | description: Optional[str] = None 38 | cwe_id: Optional[str] = None 39 | metadata: Optional[Dict[str, Any]] = None 40 | 41 | def __post_init__(self) -> None: 42 | """Validate required fields, then default metadata to an empty dict.""" 43 | self._validate_required_fields() 44 | if self.metadata is None: 45 | self.metadata = {} 46 | 47 | def _validate_required_fields(self) -> None: 48 | """ 49 | Validate that uuid, title and rule_name are non-empty (severity is not checked here). 50 | 51 | Raises: 52 | ValueError: If any of those fields is empty or None 53 | """ 54 | if not self.uuid: 55 | raise ValueError("Vulnerability UUID is required") 56 | if not self.title: 57 | raise ValueError("Vulnerability title is required") 58 | if not self.rule_name: 59 | raise ValueError("Vulnerability rule name is required") 60 | 61 | @classmethod 62 | def from_api_data(cls, api_data: Dict[str, Any]) -> "Vulnerability": 63 | """ 64 | Create a Vulnerability instance from Contrast API response data.
65 | 66 | Args: 67 | api_data: Dictionary containing vulnerability data from API; must contain vulnerabilityUuid, vulnerabilityTitle, vulnerabilityRuleName and vulnerabilitySeverity (description, cweId and metadata are optional) 68 | 69 | Returns: 70 | Vulnerability instance 71 | 72 | Raises: 73 | ValueError: If uuid, title or rule name is empty (raised during __post_init__) 74 | KeyError: If any of the four required API keys is absent 75 | """ 76 | # Map API severity strings to enum values 77 | severity_mapping = { 78 | 'CRITICAL': VulnerabilitySeverity.CRITICAL, 79 | 'HIGH': VulnerabilitySeverity.HIGH, 80 | 'MEDIUM': VulnerabilitySeverity.MEDIUM, 81 | 'LOW': VulnerabilitySeverity.LOW, 82 | 'NOTE': VulnerabilitySeverity.NOTE, 83 | } 84 | 85 | # Extract required fields 86 | uuid = api_data['vulnerabilityUuid'] 87 | title = api_data['vulnerabilityTitle'] 88 | rule_name = api_data['vulnerabilityRuleName'] 89 | 90 | # Map severity; exact-match lookup, so unrecognized values (including lower/mixed case) silently fall back to MEDIUM 91 | severity_str = api_data['vulnerabilitySeverity'] 92 | severity = severity_mapping.get(severity_str, VulnerabilitySeverity.MEDIUM) 93 | 94 | return cls( 95 | uuid=uuid, 96 | title=title, 97 | rule_name=rule_name, 98 | severity=severity, 99 | description=api_data.get('description'), 100 | cwe_id=api_data.get('cweId'), 101 | metadata=api_data.get('metadata', {}) 102 | ) 103 | 104 | def to_dict(self) -> Dict[str, Any]: 105 | """ 106 | Convert vulnerability to dictionary representation.
107 | 108 | Returns: 109 | Dictionary with snake_case keys; severity is serialized as its string value and metadata is the same dict object (not a copy) 110 | """ 111 | return { 112 | 'uuid': self.uuid, 113 | 'title': self.title, 114 | 'rule_name': self.rule_name, 115 | 'severity': self.severity.value, 116 | 'description': self.description, 117 | 'cwe_id': self.cwe_id, 118 | 'metadata': self.metadata 119 | } 120 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | ### IntelliJ IDEA ### 36 | .idea 37 | *.iws 38 | *.iml 39 | *.ipr 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # UV 104 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | #uv.lock 108 | 109 | # poetry 110 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 111 | # This is especially recommended for binary packages to ensure reproducibility, and is more 112 | # commonly ignored for libraries. 
113 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 114 | #poetry.lock 115 | 116 | # pdm 117 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 118 | #pdm.lock 119 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 120 | # in version control. 121 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 122 | .pdm.toml 123 | .pdm-python 124 | .pdm-build/ 125 | 126 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 127 | __pypackages__/ 128 | 129 | # Celery stuff 130 | celerybeat-schedule 131 | celerybeat.pid 132 | 133 | # SageMath parsed files 134 | *.sage.py 135 | 136 | # Environments 137 | .env 138 | .venv 139 | env/ 140 | venv/ 141 | ENV/ 142 | env.bak/ 143 | venv.bak/ 144 | 145 | # Spyder project settings 146 | .spyderproject 147 | .spyproject 148 | 149 | # Rope project settings 150 | .ropeproject 151 | 152 | # mkdocs documentation 153 | /site 154 | 155 | # mypy 156 | .mypy_cache/ 157 | .dmypy.json 158 | dmypy.json 159 | 160 | # Pyre type checker 161 | .pyre/ 162 | 163 | # pytype static type analyzer 164 | .pytype/ 165 | 166 | # Cython debug symbols 167 | cython_debug/ 168 | 169 | # PyCharm 170 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 171 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 172 | # and can be added to the global gitignore or merged into this file. For a more nuclear 173 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
174 | #.idea/ 175 | 176 | # Ruff stuff: 177 | .ruff_cache/ 178 | 179 | # PyPI configuration file 180 | .pypirc 181 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Contrast AI SmartFix \- User Documentation 2 | 3 | ## Legal Disclaimer 4 | 5 | When you use Contrast AI SmartFix, you agree that your code and other data will be submitted to an LLM of your choice. Both the submission of data to the LLM and the output generated by the LLM will be subject to the terms of service of that LLM. Use of Contrast AI SmartFix is entirely at your own risk. 6 | 7 | ## Introduction 8 | 9 | Welcome to Contrast AI SmartFix\! SmartFix is an AI-powered agent that automatically generates code fixes for vulnerabilities identified by Contrast Assess. It integrates into your developer workflow via GitHub Actions, creating Pull Requests (PRs) with proposed remediations. 10 | 11 | **Key Benefits:** 12 | 13 | * **Automated Remediation:** Reduces the manual effort and time required to fix vulnerabilities. 14 | * **Developer-Focused:** Delivers fixes as PRs directly in your GitHub repository, fitting naturally into existing workflows. 15 | * **Runtime Context:** Leverages Contrast Assess's runtime analysis (IAST) to provide more accurate and relevant fixes. 16 | 17 | ## Getting Started 18 | 19 | ### Coding Agent 20 | 21 | SmartFix supports three distinct coding agents for vulnerability remediation on GitHub: 22 | 23 | * **SmartFix Agent (Recommended):** Uses Contrast vulnerability data with a team of agentic AIs to analyze, fix, and validate vulnerability remediations. This agent creates a complete fix, ensures your project builds successfully, and verifies that existing tests continue to pass. Use Contrast LLM, or you can bring your own LLM provider. 24 | 25 | * **GitHub Copilot Agent:** Leverages GitHub Copilot for vulnerability fixes through GitHub Issues. 
SmartFix creates a detailed GitHub Issue with vulnerability information and assigns it to GitHub Copilot for resolution. Copilot then attempts the fix and creates a Pull Request. Requires your repository to enable GitHub Issues and Copilot. 26 | 27 | * **Claude Code Agent:** Leverages Anthropic's Claude Code bot for vulnerability fixes through GitHub Issues. SmartFix creates a detailed GitHub Issue with vulnerability information and mentions the Claude Code bot in the Issue title. Claude Code then attempts the fix and creates a Pull Request. Requires your repository to install the Claude Code GitHub App. 28 | 29 | Follow the setup instructions for the coding agent of your choice: 30 | * [SmartFix Coding Agent](https://github.com/Contrast-Security-OSS/contrast-ai-smartfix-action/blob/main/docs/smartfix_coding_agent.md) 31 | * [GitHub Copilot](https://github.com/Contrast-Security-OSS/contrast-ai-smartfix-action/blob/main/docs/github_copilot.md) 32 | * [Claude Code](https://github.com/Contrast-Security-OSS/contrast-ai-smartfix-action/blob/main/docs/claude_code.md) 33 | 34 | ## Key Features 35 | 36 | * **Support for Multiple Coding Agents:** Choose the internal SmartFix coding agent, GitHub Copilot, or Claude Code to remediate your project's vulnerabilities. 37 | * **Contrast LLM or Bring Your Own LLM (BYOLLM):** For the SmartFix Coding Agent, use Contrast's managed LLM service for seamless setup with enterprise-grade security or, if you prefer your own LLM provider, the optional Bring Your Own LLM (BYOLLM) configuration. [Learn more about Contrast LLM](docs/contrast_llm_early_access.md). Note: Contrast LLM is only available with the SmartFix Coding Agent. 38 | * **Configurable PR Throttling:** Control the volume of automated PRs using `max_open_prs`. 39 | * **Debug Mode:** Enable `debug_mode: 'true'` for verbose logging in the GitHub Action output.
40 | 41 | ## FAQ 42 | 43 | * **Q: Can I use SmartFix if I don't use Contrast Assess?** 44 | * A: No, SmartFix relies on vulnerability data from Contrast Assess. We plan to support additional vulnerability data sources in the future. 45 | * **Q: How often does SmartFix run?** 46 | * A: This is determined by the `schedule` trigger in your GitHub Actions workflow file. You can customize it. 47 | * **Q: What happens if the AI cannot generate a fix?** 48 | * A: The agent will log this, and no PR will be created for that specific vulnerability attempt. It will retry on a future run. 49 | * **Q: Can SmartFix fix multiple vulnerabilities in one PR?** 50 | * A: No, each PR addresses a single vulnerability. 51 | * **Q: Will SmartFix add new library dependencies?** 52 | * A: Generally, SmartFix aims to use existing libraries and frameworks. We have instructed it not to make major architectural changes or add new dependencies. 53 | 54 | --- 55 | 56 | For further assistance or to provide feedback on SmartFix, please contact your Contrast Security representative. 57 | -------------------------------------------------------------------------------- /.github/copilot-instructions.md: -------------------------------------------------------------------------------- 1 | # GitHub Copilot Instructions for Contrast AI SmartFix 2 | 3 | ## FUNDAMENTAL RULE: Test-First Simplicity 4 | 5 | **EVERY change follows this pattern:** 6 | 1. **Write tests FIRST** - Before writing any production code 7 | 2. **Question every requirement** - Is this actually needed? (*Exception: When executing planned refactoring tasks from ai-dev/tasks.md*) 8 | 3. **Delete unnecessary parts** - Remove complexity, don't add it 9 | 4. **Make tests pass** - Simplest implementation that works 10 | 5. **Look around and refactor downstream** - Fix related code that uses your changes 11 | 12 | ## Working with Planned vs.
Unplanned Work 13 | 14 | ### Planned Tasks (ai-dev/tasks.md): 15 | - Follow the task plan - requirements pre-analyzed 16 | - Test each task's deliverables before implementing 17 | - Keep implementations simple even if architecture is complex 18 | - Question implementation details, not the overall plan 19 | 20 | ### Unplanned Work: 21 | - Question everything - apply full YAGNI principles 22 | - Start minimal - build only what's needed right now 23 | - Delete unused code immediately 24 | 25 | ## Core Workflow 26 | 27 | ### Testing (DO THIS FIRST) 28 | - Write tests before production code - no exceptions 29 | - Run `./test/run_tests.sh` after every change 30 | - All tests must pass before moving on 31 | - Delete tests for deleted features 32 | 33 | ### Complete Refactoring (Not Partial) 34 | - When you touch code, look around - what else uses this? 35 | - Follow the dependency chain - update ALL callers 36 | - Clean up as you go - remove dead imports, unused variables 37 | 38 | ### Anti-Over-Engineering 39 | - YAGNI - don't build for imaginary future requirements 40 | - Question every class/method - can this be a simple function? 
41 | - Delete unused code immediately 42 | - Favor simple functions over complex class hierarchies 43 | 44 | ### Red Flags - Stop and Simplify 45 | - Classes with 1 method → use a function 46 | - Enums/constants nobody uses → delete them 47 | - More than 3 levels of abstraction → flatten it 48 | - Configuration for imaginary features → remove it 49 | 50 | ## Code Standards 51 | - All changes must pass flake8 linting 52 | - Type hints for public interfaces 53 | - Max line length: 180 characters 54 | - Fix whitespace: `sed -i '' 's/[[:space:]]*$//' path/to/file.py` 55 | 56 | ## File Management - CLEAN UP INTERMEDIATE FILES 57 | - **Never leave duplicate files** - If you create `file_clean.py`, `file_backup.py`, `file_new.py` while editing, DELETE the extras 58 | - **Check for duplicates before finishing** - Run `ls -la` to spot files with similar names 59 | - **One source of truth** - Each logical unit should have exactly one file 60 | - **Remove failed attempts** - If you mess up editing and start over, delete the corrupted version 61 | - **Common duplicate patterns to avoid:** 62 | - `test_something.py` and `test_something_clean.py` 63 | - `module.py` and `module_backup.py` 64 | - `config.py` and `config_new.py` 65 | 66 | ## When Refactoring Existing Code 67 | 1. **Write tests for current behavior first** - Capture existing functionality 68 | 2. **Check for planned migration path** - Is there a tasks.md plan for this area? 69 | 3. **For planned refactoring**: Follow the task sequence and dependencies 70 | 4. **For unplanned refactoring**: Question what can be deleted - Unused code, over-complex abstractions 71 | 5. **Refactor incrementally** - Small, safe changes with tests between each 72 | 6. **Follow the chain** - Update ALL code that depends on your changes 73 | 7. 
**Clean up afterwards** - Remove dead imports, fix related issues 74 | 75 | ## Additional Red Flags - Stop and Simplify (Even in Planned Work) 76 | - **Complex class hierarchies** - Can the implementation be simpler? 77 | - **More than 3 levels of abstraction** - Flatten the implementation 78 | - **Circular dependencies** - Break them up during implementation 79 | - **Classes with 1 method** - Could this be a function instead? 80 | - **Enums or constants nobody uses** - Delete them even if planned 81 | - **Configuration for imaginary future features** - Remove it 82 | 83 | ## Balancing Planning vs. Simplicity 84 | - **Respect the architecture plan** - Don't change overall structure during task execution 85 | - **Keep implementations simple** - Complex architecture ≠ complex implementation 86 | - **Suggest alternatives** - If you see over-engineering opportunities during implementation 87 | - **Focus on YAGNI for implementation details** - Don't add features not required by the current task 88 | - **Test thoroughly** - Especially important when following complex architectural plans 89 | -------------------------------------------------------------------------------- /test/test.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import contextlib 4 | import unittest 5 | import tempfile 6 | from unittest.mock import patch, MagicMock 7 | 8 | # Test setup imports (path is set up by conftest.py) 9 | from src.config import reset_config 10 | from src.main import main 11 | 12 | 13 | class TestSmartFixAction(unittest.TestCase): # Smoke test: runs main() with env vars, subprocess, git config, Contrast API, HTTP and sys.exit patched 14 | 15 | def setUp(self): 16 | # Create a temporary directory for HOME to fix git config issues 17 | self.temp_home = tempfile.mkdtemp() 18 | 19 | # Set up mock environment variables for testing 20 | self.env_patcher = patch.dict('os.environ', { 21 | 'HOME': self.temp_home, # Set HOME for git config 22 | 'GITHUB_WORKSPACE': self.temp_home, # 
Required by config.py 23 | 'BUILD_COMMAND':
'echo "Mock build command"', 24 | 'FORMATTING_COMMAND': 'echo "Mock formatting command"', 25 | 'GITHUB_TOKEN': 'mock-github-token', 26 | 'GITHUB_REPOSITORY': 'mock/repository', 27 | 'GITHUB_SERVER_URL': 'https://mockhub.com', 28 | 'CONTRAST_HOST': 'mock.contrastsecurity.com', 29 | 'CONTRAST_ORG_ID': 'mock-org-id', 30 | 'CONTRAST_APP_ID': 'mock-app-id', 31 | 'CONTRAST_AUTHORIZATION_KEY': 'mock-auth-key', 32 | 'CONTRAST_API_KEY': 'mock-api-key', 33 | 'BASE_BRANCH': 'main', 34 | 'DEBUG_MODE': 'true', 35 | 'RUN_TASK': 'generate_fix' # Add RUN_TASK to prevent missing env var errors 36 | }) 37 | self.env_patcher.start() 38 | 39 | # Mock subprocess to prevent actual command execution 40 | self.subprocess_patcher = patch('subprocess.run') 41 | self.mock_subprocess_run = self.subprocess_patcher.start() 42 | mock_process = MagicMock() 43 | mock_process.returncode = 0 44 | mock_process.stdout = "Mock process output" 45 | mock_process.stderr = "" 46 | mock_process.communicate.return_value = (b"Mock stdout", b"Mock stderr") 47 | self.mock_subprocess_run.return_value = mock_process 48 | 49 | # Mock git_handler's configure_git_user to prevent git config errors 50 | self.git_config_patcher = patch('src.git_handler.configure_git_user') 51 | self.mock_git_config = self.git_config_patcher.start() 52 | 53 | # Mock API calls to prevent network issues 54 | self.api_patcher = patch('src.contrast_api.get_vulnerability_with_prompts') 55 | self.mock_api = self.api_patcher.start() 56 | self.mock_api.return_value = None # No vulnerabilities by default 57 | 58 | # Mock HTTP POST requests (GET is only patched for version_check below) 59 | self.requests_patcher = patch('requests.post') 60 | self.mock_requests_post = self.requests_patcher.start() 61 | mock_post_response = MagicMock() 62 | mock_post_response.status_code = 404 # Not found, to avoid further processing 63 | self.mock_requests_post.return_value = mock_post_response 64 | 65 | # Mock version check requests 66 | 
self.version_requests_patcher = patch('src.version_check.requests.get')
67 | self.mock_requests_get = self.version_requests_patcher.start() 68 | mock_response = MagicMock() 69 | mock_response.json.return_value = [{'name': 'v1.0.0'}] 70 | mock_response.raise_for_status.return_value = None 71 | self.mock_requests_get.return_value = mock_response 72 | 73 | # Mock sys.exit to prevent test termination 74 | self.exit_patcher = patch('sys.exit') 75 | self.mock_exit = self.exit_patcher.start() 76 | 77 | def tearDown(self): 78 | # Stop every patcher started in setUp, then reset the cached config 79 | self.env_patcher.stop() 80 | self.subprocess_patcher.stop() 81 | self.git_config_patcher.stop() 82 | self.api_patcher.stop() 83 | self.requests_patcher.stop() 84 | self.version_requests_patcher.stop() 85 | self.exit_patcher.stop() 86 | reset_config() 87 | 88 | # Clean up temp directory if it exists 89 | if hasattr(self, 'temp_home') and os.path.exists(self.temp_home): 90 | import shutil 91 | try: 92 | shutil.rmtree(self.temp_home) 93 | except Exception: 94 | pass # Best-effort cleanup; a leftover temp dir should not fail the test 95 | 96 | def test_main_output(self): 97 | # Run main() with stdout captured and check that the startup banner is printed 98 | with io.StringIO() as stdout, contextlib.redirect_stdout(stdout): 99 | main() 100 | output = stdout.getvalue().strip() 101 | self.assertIn("--- Starting Contrast AI SmartFix Script ---", output) 102 | 103 | 104 | if __name__ == '__main__': 105 | unittest.main() 106 | -------------------------------------------------------------------------------- /.beads/issues.jsonl: -------------------------------------------------------------------------------- 1 | {"id":"contrast-ai-smartfix-action-8re","title":"Integrate validation into Config class","description":"Integrate command validation into src/config.py Config.__init__():\n- Import command_validator module\n- Add _validate_command() method to Config class\n- Call validation for BUILD_COMMAND after reading from environment (line ~82)\n- Call validation 
for FORMATTING_COMMAND after reading from environment (line ~84)\n- Ensure CommandValidationError is caught and converted to ConfigurationError\n- Follow existing
config validation patterns (similar to _check_contrast_config_values_exist)\n- Maintain fail-fast behavior with clear error messages\n\nRelated to JIRA: AIML-337","status":"closed","priority":1,"issue_type":"task","created_at":"2025-12-23T09:50:37.859629-05:00","updated_at":"2025-12-23T10:34:27.5546-05:00","closed_at":"2025-12-23T10:34:27.5546-05:00","dependencies":[{"issue_id":"contrast-ai-smartfix-action-8re","depends_on_id":"contrast-ai-smartfix-action-gad","type":"blocks","created_at":"2025-12-23T09:51:11.594923-05:00","created_by":"daemon"}]} 2 | {"id":"contrast-ai-smartfix-action-gad","title":"Create command validation module","description":"Create src/smartfix/config/command_validator.py with:\n- ALLOWED_COMMANDS constant (full allowlist for Java, .NET, Python, PHP, NodeJS, build tools, shell utilities)\n- BLOCKED_PATTERNS regex list (command substitution, eval, exec, dangerous operations)\n- validate_shell_command() - ensures sh/bash only execute .sh files, blocks -c flag\n- validate_redirect() - allows relative paths only, blocks absolute paths and .. 
traversal\n- contains_dangerous_patterns() - checks command against blocked patterns\n- split_command_chain() - parses commands by operators (\u0026\u0026, ||, ;, |)\n- parse_command_segment() - extracts executable and arguments from command segment\n- extract_redirects() - finds file redirects in command\n- validate_command() - main validation function that orchestrates all checks\n- CommandValidationError exception class\n\nRelated to JIRA: AIML-337","status":"closed","priority":1,"issue_type":"task","created_at":"2025-12-23T09:49:50.169966-05:00","updated_at":"2025-12-23T10:28:03.234049-05:00","closed_at":"2025-12-23T10:28:03.234049-05:00"} 3 | {"id":"contrast-ai-smartfix-action-q9y","title":"Create comprehensive test suite for command validation","description":"Create test/test_command_validation.py with comprehensive test coverage:\n- Test all allowed commands validate successfully (Java, .NET, Python, PHP, NodeJS, build tools)\n- Test blocked executables are rejected with proper error messages\n- Test command chaining with operators (\u0026\u0026, ||, ;, |)\n- Test shell script execution: sh/bash with .sh files allowed, -c flag blocked\n- Test redirect validation: relative paths pass, absolute paths and .. 
traversal fail\n- Test dangerous pattern detection: command substitution, eval, exec, piping to shell\n- Test edge cases: empty commands, whitespace, special characters, complex chains\n- Test error messages are clear and actionable\n- Ensure all tests pass and maintain existing test coverage (165+ tests)\n\nRelated to JIRA: AIML-337","status":"closed","priority":2,"issue_type":"task","created_at":"2025-12-23T09:51:20.598647-05:00","updated_at":"2025-12-23T10:35:10.53909-05:00","closed_at":"2025-12-23T10:35:10.53909-05:00","dependencies":[{"issue_id":"contrast-ai-smartfix-action-q9y","depends_on_id":"contrast-ai-smartfix-action-gad","type":"blocks","created_at":"2025-12-23T09:51:56.723262-05:00","created_by":"daemon"},{"issue_id":"contrast-ai-smartfix-action-q9y","depends_on_id":"contrast-ai-smartfix-action-8re","type":"blocks","created_at":"2025-12-23T09:51:56.843412-05:00","created_by":"daemon"}]} 4 | {"id":"contrast-ai-smartfix-action-wuv","title":"Update documentation for command allowlist","description":"Update README.md and related documentation:\n- Add new section explaining command allowlist security feature\n- Document all allowed commands organized by language/ecosystem\n- Document allowed operators for chaining (\u0026\u0026, ||, ;, |)\n- Document allowed redirect patterns (relative paths only)\n- Provide examples of valid commands for each language\n- Provide examples of blocked commands and why they fail\n- Document error messages users might encounter\n- Add troubleshooting section for common validation failures\n- Update security.md if present to mention this security feature\n\nRelated to JIRA: 
AIML-337","status":"closed","priority":3,"issue_type":"task","created_at":"2025-12-23T09:52:22.925762-05:00","updated_at":"2025-12-23T10:40:22.502328-05:00","closed_at":"2025-12-23T10:40:22.502328-05:00","dependencies":[{"issue_id":"contrast-ai-smartfix-action-wuv","depends_on_id":"contrast-ai-smartfix-action-q9y","type":"blocks","created_at":"2025-12-23T09:52:26.450395-05:00","created_by":"daemon"}]} 5 | -------------------------------------------------------------------------------- /src/smartfix/domains/workflow/credit_tracking.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security's commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 
17 | # #L% 18 | # 19 | 20 | """Credit tracking data models and utilities for Contrast LLM usage.""" 21 | 22 | from dataclasses import dataclass 23 | from datetime import datetime 24 | 25 | 26 | @dataclass 27 | class CreditTrackingResponse: 28 | """Credit-tracking API response for one organization, with helpers for log, warning and PR-body formatting.""" 29 | organization_id: str 30 | enabled: bool 31 | max_credits: int 32 | credits_used: int 33 | start_date: str 34 | end_date: str 35 | 36 | @property 37 | def credits_remaining(self) -> int: 38 | """Calculate remaining credits.""" 39 | return self.max_credits - self.credits_used 40 | 41 | @property 42 | def is_exhausted(self) -> bool: 43 | """Check if credits are exhausted.""" 44 | return self.credits_remaining <= 0 45 | 46 | @property 47 | def is_low(self) -> bool: 48 | """Check if credits are running low (5 or fewer remaining).""" 49 | return self.credits_remaining <= 5 and self.credits_remaining > 0 50 | 51 | def _format_timestamp(self, iso_timestamp: str) -> str: 52 | """Format an ISO-8601 timestamp as e.g. 'Oct 30, 2025'; return 'Unknown' when empty and the input unchanged when it cannot be parsed.""" 53 | if not iso_timestamp: 54 | return "Unknown" 55 | 56 | try: 57 | # Parse ISO format timestamp (e.g., "2025-10-30T01:00:00Z") 58 | dt = datetime.fromisoformat(iso_timestamp.replace('Z', '+00:00')) 59 | # Format as "Oct 30, 2025" 60 | return dt.strftime("%b %d, %Y") 61 | except (ValueError, AttributeError): 62 | # If parsing fails, return the original timestamp 63 | return iso_timestamp 64 | 65 | def to_log_message(self) -> str: 66 | """Format credit information for log output (end_date is shown as the raw API string, unlike the formatted dates in to_pr_body_section).""" 67 | if not self.enabled: 68 | return "Credit tracking is disabled for this organization" 69 | 70 | return (f"Credits: {self.credits_used}/{self.max_credits} used " 71 | f"({self.credits_remaining} remaining). 
Trial expires {self.end_date}") 72 | 73 | def get_credit_warning_message(self) -> str: 74 | """Return the exhausted or low-credit warning (the low warning is wrapped in ANSI yellow), or an empty string when neither applies.""" 75 | if self.is_exhausted: 76 | return "Credits have been exhausted.
Contact your CSM to request additional credits." 77 | elif self.is_low: 78 | # Yellow text formatting for low credits warning 79 | return f"\033[0;33m{self.credits_remaining} credits remaining \033[0m" 80 | return "" 81 | 82 | def should_log_warning(self) -> bool: 83 | """Check if a warning should be logged.""" 84 | return self.is_exhausted or self.is_low 85 | 86 | def to_pr_body_section(self) -> str: 87 | """Format a markdown credit summary to append to a PR body; empty string when tracking is disabled.""" 88 | if not self.enabled: 89 | return "" 90 | 91 | start_formatted = self._format_timestamp(self.start_date) 92 | end_formatted = self._format_timestamp(self.end_date) 93 | 94 | return f""" 95 | --- 96 | ### Contrast LLM Credits 97 | - **Used:** {self.credits_used}/{self.max_credits} 98 | - **Remaining:** {self.credits_remaining} 99 | - **Trial Period:** {start_formatted} to {end_formatted} 100 | """ 101 | 102 | @classmethod 103 | def from_api_response(cls, response_data: dict) -> 'CreditTrackingResponse': 104 | """Create an instance from camelCase API response data; raises KeyError if any expected key is missing.""" 105 | return cls( 106 | organization_id=response_data['organizationId'], 107 | enabled=response_data['enabled'], 108 | max_credits=response_data['maxCredits'], 109 | credits_used=response_data['creditsUsed'], 110 | start_date=response_data['startDate'], 111 | end_date=response_data['endDate'] 112 | ) 113 | 114 | def with_incremented_usage(self) -> 'CreditTrackingResponse': 115 | """Return a copy with credits_used incremented by 1.""" 116 | return CreditTrackingResponse( 117 | organization_id=self.organization_id, 118 | enabled=self.enabled, 119 | max_credits=self.max_credits, 120 | credits_used=self.credits_used + 1, 121 | start_date=self.start_date, 122 | end_date=self.end_date 123 | ) 124 | -------------------------------------------------------------------------------- /test/test_vulnerability_models.py: 
-------------------------------------------------------------------------------- 1 | """Tests for vulnerability models without VulnerabilityStatus.
"""Unit tests for the vulnerability domain models (post-VulnerabilityStatus).

Covers the ``VulnerabilitySeverity`` enum and the ``Vulnerability`` aggregate:
construction, attribute access, ``to_dict`` and ``from_api_data``.
"""

import unittest

# Test setup imports (path is set up by conftest.py)
from src.smartfix.domains.vulnerability.models import (
    Vulnerability, VulnerabilitySeverity
)


class TestVulnerabilitySeverity(unittest.TestCase):
    """Checks on the VulnerabilitySeverity enumeration."""

    def test_enum_values(self):
        """The enum exposes exactly the five expected severity values."""
        wanted = {'CRITICAL', 'HIGH', 'MEDIUM', 'LOW', 'NOTE'}
        self.assertEqual(wanted, {level.value for level in VulnerabilitySeverity})

    def test_severity_creation(self):
        """Individual members carry their upper-case string value."""
        pairs = (
            (VulnerabilitySeverity.CRITICAL, 'CRITICAL'),
            (VulnerabilitySeverity.LOW, 'LOW'),
        )
        for member, text in pairs:
            self.assertEqual(member.value, text)


class TestVulnerability(unittest.TestCase):
    """Checks on the Vulnerability aggregate root."""

    def setUp(self):
        """Build a fully populated sample vulnerability shared by the tests."""
        self.sample_vulnerability = Vulnerability(
            uuid="vuln-123-456-789",
            title="SQL Injection in login function",
            rule_name="sql-injection",
            severity=VulnerabilitySeverity.HIGH,
            description="Potential SQL injection vulnerability detected",
            cwe_id="CWE-89"
        )

    def _assert_identity(self, vuln, uuid, title, rule_name, severity):
        """Assert the four core identifying attributes of ``vuln``."""
        self.assertEqual(vuln.uuid, uuid)
        self.assertEqual(vuln.title, title)
        self.assertEqual(vuln.rule_name, rule_name)
        self.assertEqual(vuln.severity, severity)

    def test_vulnerability_creation(self):
        """A Vulnerability built from keyword arguments keeps them."""
        created = Vulnerability(
            uuid="test-uuid",
            title="Test Vulnerability",
            rule_name="test-rule",
            severity=VulnerabilitySeverity.MEDIUM
        )
        self._assert_identity(created, "test-uuid", "Test Vulnerability", "test-rule",
                              VulnerabilitySeverity.MEDIUM)

    def test_vulnerability_properties(self):
        """The fixture's core attributes are readable."""
        self._assert_identity(self.sample_vulnerability, "vuln-123-456-789",
                              "SQL Injection in login function", "sql-injection",
                              VulnerabilitySeverity.HIGH)

    def test_to_dict(self):
        """to_dict exposes core fields, with severity as its string value."""
        as_dict = Vulnerability(
            uuid="dict-test",
            title="Dict Test Vulnerability",
            rule_name="dict-rule",
            severity=VulnerabilitySeverity.LOW
        ).to_dict()

        expected = {
            'uuid': "dict-test",
            'title': "Dict Test Vulnerability",
            'rule_name': "dict-rule",
            'severity': VulnerabilitySeverity.LOW.value,
        }
        for key, value in expected.items():
            self.assertEqual(as_dict[key], value)

    def test_from_api_data(self):
        """from_api_data maps every supported API field."""
        payload = {
            'vulnerabilityUuid': 'api-uuid-123',
            'vulnerabilityTitle': 'API Test Vulnerability',
            'vulnerabilityRuleName': 'api-test-rule',
            'vulnerabilitySeverity': 'HIGH',
            'remediationId': 'remediation-123',
            'description': 'Test description',
            'cweId': 'CWE-123',
            'metadata': {'test': 'data'}
        }

        parsed = Vulnerability.from_api_data(payload)

        self._assert_identity(parsed, 'api-uuid-123', 'API Test Vulnerability', 'api-test-rule',
                              VulnerabilitySeverity.HIGH)
        self.assertEqual(parsed.description, 'Test description')
        self.assertEqual(parsed.cwe_id, 'CWE-123')
        self.assertEqual(parsed.metadata, {'test': 'data'})

    def test_from_api_data_minimal(self):
        """Optional fields fall back to None / {} when absent from the payload."""
        payload = {
            'vulnerabilityUuid': 'minimal-uuid',
            'vulnerabilityTitle': 'Minimal Vulnerability',
            'vulnerabilityRuleName': 'minimal-rule',
            'vulnerabilitySeverity': 'MEDIUM'
        }

        parsed = Vulnerability.from_api_data(payload)

        self._assert_identity(parsed, 'minimal-uuid', 'Minimal Vulnerability', 'minimal-rule',
                              VulnerabilitySeverity.MEDIUM)
        self.assertIsNone(parsed.description)
        self.assertIsNone(parsed.cwe_id)
        self.assertEqual(parsed.metadata, {})


if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------------------
# /test/test_utils_error_exit.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -
# #%L
# Contrast AI SmartFix
# %%
# Copyright (C) 2025 Contrast Security, Inc.
# %%
# Contact: support@contrastsecurity.com
# License: Commercial
# NOTICE: This Software and the patented inventions embodied within may only be
# used as part of Contrast Security's commercial offerings. Even though it is
# made available through public repositories, use of this Software is subject to
# the applicable End User Licensing Agreement found at
# https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed
# between Contrast Security and the End User. The Software may not be reverse
# engineered, modified, repackaged, sold, redistributed or otherwise used in a
# way not consistent with the End User License Agreement.
18 | # #L% 19 | # 20 | 21 | import unittest 22 | from unittest.mock import patch 23 | from contextlib import contextmanager 24 | 25 | # Test setup imports (path is set up by conftest.py) 26 | from src import utils 27 | from src.config import get_config, reset_config 28 | from src.smartfix.shared.failure_categories import FailureCategory 29 | 30 | 31 | class TestErrorExit(unittest.TestCase): 32 | """Tests for the error_exit function in utils.py""" 33 | 34 | def setUp(self): 35 | """Set up test environment before each test""" 36 | reset_config() # Reset the config singleton 37 | 38 | def tearDown(self): 39 | """Clean up after each test""" 40 | reset_config() 41 | 42 | @contextmanager 43 | def assert_system_exit(self, expected_code=1): 44 | """Context manager to assert that sys.exit was called with the expected code""" 45 | with self.assertRaises(SystemExit) as cm: 46 | yield 47 | self.assertEqual(cm.exception.code, expected_code) 48 | 49 | @patch('sys.exit') 50 | @patch('src.utils.log') # Directly patch the module function 51 | @patch('src.git_handler.cleanup_branch') 52 | @patch('src.git_handler.get_branch_name') 53 | @patch('src.contrast_api.send_telemetry_data') 54 | @patch('src.contrast_api.notify_remediation_failed') 55 | def test_error_exit_with_failure_code(self, mock_notify, mock_send_telemetry, mock_get_branch, 56 | mock_cleanup, mock_log, mock_exit): 57 | """Test error_exit when a specific failure code is provided""" 58 | # Setup 59 | remediation_id = "test-remediation-id" 60 | failure_code = FailureCategory.AGENT_FAILURE.value 61 | mock_notify.return_value = True # Notification succeeds 62 | mock_get_branch.return_value = f"smartfix/remediation-{remediation_id}" 63 | config = get_config(testing=True) 64 | 65 | # Execute the function 66 | utils.error_exit(remediation_id, failure_code) 67 | 68 | # Assert 69 | mock_notify.assert_called_once_with( 70 | remediation_id=remediation_id, 71 | failure_category=failure_code, 72 | contrast_host=config.CONTRAST_HOST, 
73 | contrast_org_id=config.CONTRAST_ORG_ID, 74 | contrast_app_id=config.CONTRAST_APP_ID, 75 | contrast_auth_key=config.CONTRAST_AUTHORIZATION_KEY, 76 | contrast_api_key=config.CONTRAST_API_KEY 77 | ) 78 | 79 | # Verify other function calls 80 | mock_get_branch.assert_called_once_with(remediation_id) 81 | mock_cleanup.assert_called_once_with(f"smartfix/remediation-{remediation_id}") 82 | mock_send_telemetry.assert_called_once() 83 | # Verify sys.exit was called with code 1 84 | mock_exit.assert_called_once_with(1) 85 | 86 | @patch('sys.exit') 87 | @patch('src.utils.log') 88 | @patch('src.git_handler.cleanup_branch') 89 | @patch('src.git_handler.get_branch_name') 90 | @patch('src.contrast_api.send_telemetry_data') 91 | @patch('src.contrast_api.notify_remediation_failed') 92 | def test_error_exit_default_failure_code(self, mock_notify, mock_send_telemetry, mock_get_branch, 93 | mock_cleanup, mock_log, mock_exit): 94 | """Test error_exit when no failure code is provided (uses default)""" 95 | # Setup 96 | remediation_id = "test-remediation-id" 97 | default_failure_code = FailureCategory.GENERAL_FAILURE.value 98 | mock_notify.return_value = True 99 | mock_get_branch.return_value = f"smartfix/remediation-{remediation_id}" 100 | config = get_config(testing=True) 101 | 102 | # Execute 103 | utils.error_exit(remediation_id) 104 | 105 | # Assert 106 | mock_notify.assert_called_once_with( 107 | remediation_id=remediation_id, 108 | failure_category=default_failure_code, 109 | contrast_host=config.CONTRAST_HOST, 110 | contrast_org_id=config.CONTRAST_ORG_ID, 111 | contrast_app_id=config.CONTRAST_APP_ID, 112 | contrast_auth_key=config.CONTRAST_AUTHORIZATION_KEY, 113 | contrast_api_key=config.CONTRAST_API_KEY 114 | ) 115 | 116 | # Verify other functions were called 117 | mock_get_branch.assert_called_once_with(remediation_id) 118 | mock_cleanup.assert_called_once() 119 | mock_send_telemetry.assert_called_once() 120 | # Verify sys.exit was called with code 1 121 | 
mock_exit.assert_called_once_with(1) 122 | 123 | 124 | if __name__ == '__main__': 125 | unittest.main() 126 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SmartFix 2 | 3 | Thank you for your interest in contributing to SmartFix! This guide will help you get set up for development. 4 | 5 | ## ⚖️ Contributor License Agreement 6 | 7 | **External contributors must have a signed Contributor License Agreement (CLA) on file with Contrast Security before we can accept code contributions.** 8 | 9 | If you or your company do not have a CLA on file with Contrast Security, please contact us before submitting pull requests. This requirement does not apply to Contrast Security employees. 10 | 11 | ## 🚀 Quick Start 12 | 13 | ### 1. Clone and Setup 14 | 15 | ```bash 16 | git clone <repository-url> 17 | cd contrast-ai-smartfix-action 18 | 19 | # Install git hooks for automatic linting 20 | ./setup-hooks.sh 21 | ``` 22 | 23 | ### 2. Install Dependencies 24 | 25 | ```bash 26 | # Install Python dependencies (creates .venv and installs packages), then run the tests 27 | ./test/run_tests.sh # Will prompt to install uv if needed 28 | 29 | # Or install manually 30 | pip install -r src/requirements.txt 31 | pip install flake8 # For linting (if not already installed by setup-hooks.sh) 32 | ``` 33 | 34 | ### 3.
Verify Setup 35 | 36 | ```bash 37 | # Run tests to ensure everything works 38 | ./test/run_tests.sh 39 | 40 | # Run linting to check code quality 41 | ./.git/hooks/pre-push 42 | ``` 43 | 44 | ## 🔧 Development Workflow 45 | 46 | ### Code Quality & Linting 47 | 48 | We use automated linting to maintain code quality with a **single source of truth** approach: 49 | 50 | - **Local Development**: Git hooks automatically run linting checks 51 | - **CI/CD**: Uses the same hook scripts to ensure consistency 52 | 53 | #### Git Hooks (Automatic) 54 | - **Pre-commit**: Cleans trailing whitespace 55 | - **Pre-push**: Runs Python linting with flake8 56 | 57 | #### Manual Linting 58 | ```bash 59 | # Run the same linting that CI uses 60 | ./.git/hooks/pre-push 61 | 62 | # Or run flake8 directly (but hook is preferred for consistency) 63 | flake8 src/ test/ 64 | ``` 65 | 66 | #### Bypassing Hooks (Not Recommended) 67 | ```bash 68 | git commit --no-verify # Skip pre-commit 69 | git push --no-verify # Skip pre-push 70 | ``` 71 | 72 | ### Running Tests 73 | 74 | ```bash 75 | # Run all tests (installs deps if needed) 76 | ./test/run_tests.sh 77 | 78 | # Run tests without installing deps 79 | ./test/run_tests.sh --skip-install 80 | 81 | # Run specific test file 82 | ./test/run_tests.sh test_main.py 83 | 84 | # Run multiple specific tests 85 | ./test/run_tests.sh test_main.py test_config_integration.py 86 | ``` 87 | 88 | ### Code Style Guidelines 89 | 90 | - **Linting**: Checked with flake8 using the repository's `.flake8` configuration 91 | - **Imports**: Group standard library, third-party, and local imports separately 92 | - **Whitespace**: No trailing whitespace (automatically cleaned by pre-commit hook) 93 | - **Comments**: Use clear, concise comments for complex logic 94 | 95 | ## 🧪 Testing Guidelines 96 | 97 | - Write tests for new functionality 98 | - Maintain existing test coverage 99 | - Use descriptive test names 100 | - Mock external dependencies (API calls, file system, etc.)
101 | 102 | ## 📋 Pull Request Process 103 | 104 | 1. **Create Feature Branch** 105 | ```bash 106 | git checkout -b feature/your-feature-name 107 | ``` 108 | 109 | 2. **Make Changes** 110 | - Write code following style guidelines 111 | - Add/update tests as needed 112 | - Run linting and tests locally: 113 | ```bash 114 | ./.git/hooks/pre-push # Linting 115 | ./test/run_tests.sh # Tests 116 | ``` 117 | 118 | 3. **Commit Changes** 119 | ```bash 120 | git add . 121 | git commit -m "Add feature: description of changes" 122 | # Pre-commit hook automatically cleans whitespace 123 | ``` 124 | 125 | 4. **Push and Create PR** 126 | ```bash 127 | git push origin feature/your-feature-name 128 | # Pre-push hook automatically runs linting 129 | # Create pull request on GitHub 130 | ``` 131 | 132 | 5. **PR Requirements** 133 | - ✅ All CI checks pass (CI uses the same hook scripts) 134 | - ✅ Code review approval 135 | - ✅ Branch is up-to-date with main 136 | 137 | ## 🔍 Debugging & Development 138 | 139 | ### Local Development Commands 140 | 141 | ```bash 142 | # Full development cycle 143 | ./setup-hooks.sh # One-time setup 144 | ./test/run_tests.sh # Run tests 145 | ./.git/hooks/pre-push # Run linting 146 | 147 | # During development 148 | git commit # Triggers pre-commit (whitespace cleanup) 149 | git push # Triggers pre-push (linting) 150 | ``` 151 | 152 | ### Environment Variables 153 | 154 | The test runner (`./test/run_tests.sh`) automatically sets up test environment variables.
For manual testing, you'll need: 155 | 156 | ```bash 157 | export BASE_BRANCH="main" 158 | export BUILD_COMMAND="echo 'Mock build'" 159 | export FORMATTING_COMMAND="echo 'Mock format'" 160 | export GITHUB_TOKEN="your-github-token" 161 | export CONTRAST_HOST="https://your.contrast.host" 162 | export CONTRAST_ORG_ID="your-org-id" 163 | export CONTRAST_APP_ID="your-app-id" 164 | export CONTRAST_AUTHORIZATION_KEY="your-auth-key" 165 | export CONTRAST_API_KEY="your-api-key" 166 | export DEBUG_MODE="true" 167 | ``` 168 | 169 | ### Debugging Tips 170 | 171 | - Use `DEBUG_MODE=true` for verbose logging 172 | 173 | ## 🐛 Troubleshooting 174 | 175 | ### Tests Not Working 176 | ```bash 177 | # Use the test script which handles all setup 178 | ./test/run_tests.sh 179 | 180 | # If uv is missing, install it: 181 | pip install uv 182 | ``` 183 | 184 | Thank you for contributing to SmartFix! 🙏 -------------------------------------------------------------------------------- /test/test_aws_bearer_token_bedrock.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # - 3 | # #%L 4 | # Contrast AI SmartFix 5 | # %% 6 | # Copyright (C) 2025 Contrast Security, Inc. 7 | # %% 8 | # Contact: support@contrastsecurity.com 9 | # License: Commercial 10 | # NOTICE: This Software and the patented inventions embodied within may only be 11 | # used as part of Contrast Security's commercial offerings. Even though it is 12 | # made available through public repositories, use of this Software is subject to 13 | # the applicable End User Licensing Agreement found at 14 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 15 | # between Contrast Security and the End User. The Software may not be reverse 16 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 17 | # way not consistent with the End User License Agreement.
# #L%
#

"""
Unit tests for AWS_BEARER_TOKEN_BEDROCK environment variable support.

This module tests that the AWS_BEARER_TOKEN_BEDROCK environment variable
is properly handled and available for LiteLLM to use for Bedrock authentication.
"""

import sys
import unittest
import os
from unittest.mock import patch

# Add project root to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# Import config module (need to initialize before other imports)
from src.config import get_config  # noqa: E402

# Initialize config with testing flag
_ = get_config(testing=True)


class TestAwsBearerTokenBedrock(unittest.TestCase):
    """Test cases for AWS_BEARER_TOKEN_BEDROCK environment variable support.

    Every test runs against a snapshot of ``os.environ`` taken in ``setUp``.
    Variables a test sets or removes are rolled back afterwards, and any
    AWS variables already present in the developer's or CI environment are
    restored rather than deleted. This holds even when an assertion fails.
    """

    def setUp(self):
        """Snapshot os.environ; the patcher restores it when the test finishes."""
        env_patcher = patch.dict(os.environ)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

    def test_environment_variable_can_be_set(self):
        """Test that AWS_BEARER_TOKEN_BEDROCK environment variable can be set and retrieved."""
        test_token = "test-bearer-token-12345"

        os.environ['AWS_BEARER_TOKEN_BEDROCK'] = test_token

        self.assertEqual(os.environ.get('AWS_BEARER_TOKEN_BEDROCK'), test_token)

    def test_environment_variable_not_set_returns_none(self):
        """Test that missing AWS_BEARER_TOKEN_BEDROCK returns None."""
        # Remove any existing value; setUp's patcher puts it back afterwards.
        os.environ.pop('AWS_BEARER_TOKEN_BEDROCK', None)

        self.assertIsNone(os.environ.get('AWS_BEARER_TOKEN_BEDROCK'))

    def test_bearer_token_precedence_over_iam(self):
        """Test that AWS_BEARER_TOKEN_BEDROCK can coexist with IAM credentials."""
        os.environ.update({
            'AWS_BEARER_TOKEN_BEDROCK': "test-bearer-token",
            'AWS_ACCESS_KEY_ID': "test-access-key",
            'AWS_SECRET_ACCESS_KEY': "test-secret-key",
        })

        # Verify both are set (LiteLLM will determine which to use based on its own logic)
        for key in ('AWS_BEARER_TOKEN_BEDROCK', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'):
            self.assertIsNotNone(os.environ.get(key))

    def test_empty_bearer_token_is_ignored(self):
        """Test that an empty AWS_BEARER_TOKEN_BEDROCK value is handled gracefully."""
        os.environ['AWS_BEARER_TOKEN_BEDROCK'] = ""

        # Verify it's set but empty
        self.assertEqual(os.environ.get('AWS_BEARER_TOKEN_BEDROCK'), "")
        # An empty token should be falsy
        self.assertFalse(os.environ.get('AWS_BEARER_TOKEN_BEDROCK'))

    def test_aws_region_name_can_be_set(self):
        """Test that AWS_REGION_NAME environment variable can be set and retrieved."""
        test_region = "us-east-1"

        os.environ['AWS_REGION_NAME'] = test_region

        self.assertEqual(os.environ.get('AWS_REGION_NAME'), test_region)

    def test_bearer_token_and_region_together(self):
        """Test that AWS_BEARER_TOKEN_BEDROCK and AWS_REGION_NAME can be used together."""
        test_token = "test-bearer-token-abc123"
        test_region = "us-west-2"

        os.environ['AWS_BEARER_TOKEN_BEDROCK'] = test_token
        os.environ['AWS_REGION_NAME'] = test_region

        self.assertEqual(os.environ.get('AWS_BEARER_TOKEN_BEDROCK'), test_token)
        self.assertEqual(os.environ.get('AWS_REGION_NAME'), test_region)


if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------------------
# /test/test_contrast_llm_config.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

import unittest
import os
from src.config import get_config, reset_config


class TestContrastLlmConfig(unittest.TestCase):
    """Test cases for the USE_CONTRAST_LLM configuration setting."""

    def setUp(self):
        """Set up test environment before each test."""
        # Store original environment to restore later
        self.original_env = os.environ.copy()

        # Set up minimal required environment variables for testing
        self.env_vars = {
            'GITHUB_WORKSPACE': '/tmp',
            'BUILD_COMMAND': 'echo "Mock build"',
            'GITHUB_TOKEN': 'mock-token',
            'GITHUB_REPOSITORY': 'mock/repo',
            'BASE_BRANCH': 'main',
            'CONTRAST_HOST': 'test.contrastsecurity.com',
            'CONTRAST_ORG_ID': 'test-org-id',
            'CONTRAST_APP_ID': 'test-app-id',
            'CONTRAST_AUTHORIZATION_KEY': 'test-auth-key',
            'CONTRAST_API_KEY': 'test-api-key'
        }

        os.environ.update(self.env_vars)
        reset_config()

    def tearDown(self):
        """Clean up after each test."""
        # Restore original environment
        os.environ.clear()
        os.environ.update(self.original_env)
        reset_config()

    def test_use_contrast_llm_default_value(self):
        """Test that USE_CONTRAST_LLM defaults to True when not set."""
        # Ensure USE_CONTRAST_LLM is not in environment
        if 'USE_CONTRAST_LLM' in os.environ:
            del os.environ['USE_CONTRAST_LLM']

        reset_config()
        config = get_config(testing=True)

        # Should default to True
        self.assertTrue(config.USE_CONTRAST_LLM)

    def test_use_contrast_llm_set_to_true(self):
        """Test that USE_CONTRAST_LLM can be set to true."""
        os.environ['USE_CONTRAST_LLM'] = 'true'
        reset_config()
        config = get_config(testing=True)

        self.assertTrue(config.USE_CONTRAST_LLM)

    def test_use_contrast_llm_set_to_false(self):
        """Test that USE_CONTRAST_LLM can be set to false."""
        os.environ['USE_CONTRAST_LLM'] = 'false'
        reset_config()
        config = get_config(testing=True)

        self.assertFalse(config.USE_CONTRAST_LLM)

    def test_use_contrast_llm_case_insensitive_true(self):
        """Test that USE_CONTRAST_LLM accepts 'TRUE' (case insensitive)."""
        os.environ['USE_CONTRAST_LLM'] = 'TRUE'
        reset_config()
        config = get_config(testing=True)

        self.assertTrue(config.USE_CONTRAST_LLM)

    def test_use_contrast_llm_case_insensitive_false(self):
        """Test that USE_CONTRAST_LLM accepts 'FALSE' (case insensitive)."""
        os.environ['USE_CONTRAST_LLM'] = 'FALSE'
        reset_config()
        config = get_config(testing=True)

        self.assertFalse(config.USE_CONTRAST_LLM)

    def test_use_contrast_llm_mixed_case(self):
        """Test that USE_CONTRAST_LLM accepts mixed case values."""
        test_cases = [
            ('True', True),
            ('False', False),
            ('tRuE', True),
            ('FaLsE', False)
        ]

        for env_value, expected in test_cases:
            with self.subTest(env_value=env_value):
                os.environ['USE_CONTRAST_LLM'] = env_value
                reset_config()
                config = get_config(testing=True)
                self.assertEqual(config.USE_CONTRAST_LLM, expected)

    def test_use_contrast_llm_invalid_values_default_to_false(self):
        """Test that invalid values for USE_CONTRAST_LLM default to False."""
        invalid_values = ['yes', 'no', '1', '0', 'invalid']

        for invalid_value in invalid_values:
            with self.subTest(invalid_value=invalid_value):
                os.environ['USE_CONTRAST_LLM'] = invalid_value
                reset_config()
                config = get_config(testing=True)
                # Invalid values should result in False (not 'true')
                self.assertFalse(config.USE_CONTRAST_LLM)

    def test_use_contrast_llm_empty_string_uses_default(self):
        """Test that empty string for USE_CONTRAST_LLM uses the default value (True)."""
        os.environ['USE_CONTRAST_LLM'] = ''
        reset_config()
        config = get_config(testing=True)

        # Empty string should fall back to default=True
        self.assertTrue(config.USE_CONTRAST_LLM)

    def test_use_contrast_llm_debug_logging(self):
        """Test that USE_CONTRAST_LLM appears in debug logging when DEBUG_MODE is enabled."""
        # NOTE(review): this only asserts the parsed config values; the debug log
        # output itself is never captured or checked, so the name overstates it.
        os.environ['USE_CONTRAST_LLM'] = 'false'
        os.environ['DEBUG_MODE'] = 'true'
        reset_config()

        # Create config to trigger debug logging
        config = get_config(testing=True)

        # Verify the setting is correct
        self.assertFalse(config.USE_CONTRAST_LLM)
        self.assertTrue(config.DEBUG_MODE)

    def test_environment_variable_can_be_set_and_retrieved(self):
        """Test that USE_CONTRAST_LLM environment variable can be set and retrieved."""
        # Set the environment variable
        os.environ['USE_CONTRAST_LLM'] = "false"

        # Verify it can be retrieved from os.environ
        self.assertEqual(os.environ.get('USE_CONTRAST_LLM'), "false")

        # Verify it works through the config system
        reset_config()
        config = get_config(testing=True)
        self.assertFalse(config.USE_CONTRAST_LLM)


if __name__ == '__main__':
    unittest.main()

# --------------------------------------------------------------------------------
# /src/smartfix/domains/workflow/session_handler.py:
# --------------------------------------------------------------------------------
# -
# #%L
# Contrast AI SmartFix
# %%
# Copyright (C) 2025 Contrast Security, Inc.
# %%
# Contact: support@contrastsecurity.com
# License: Commercial
# NOTICE: This Software and the patented inventions embodied within may only be
# used as part of Contrast Security's commercial offerings. Even though it is
# made available through public repositories, use of this Software is subject to
# the applicable End User Licensing Agreement found at
# https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed
# between Contrast Security and the End User. The Software may not be reverse
# engineered, modified, repackaged, sold, redistributed or otherwise used in a
# way not consistent with the End User License Agreement.
# #L%
#

"""
Session handling workflow for SmartFix agent results.

Turns a finished agent session into a go/no-go decision for PR creation and
renders the "Review" (QA) section that is appended to the PR body.
"""

from typing import Optional
from src.smartfix.shared.failure_categories import FailureCategory
from src.utils import log


class SessionResult:
    """
    Outcome of processing one agent session.

    Attributes:
        should_continue: True when the workflow should go on to create a PR
        failure_category: Failure category value when the session failed
        ai_fix_summary: Fix summary text when the session succeeded
    """

    def __init__(self, should_continue: bool, failure_category: Optional[str] = None, ai_fix_summary: Optional[str] = None):
        self.should_continue = should_continue
        self.failure_category = failure_category
        self.ai_fix_summary = ai_fix_summary


class QASectionConfig:
    """
    Settings that decide whether, and how, the QA section is rendered.

    Attributes:
        skip_qa_review: True when QA review was disabled by configuration
        has_build_command: Whether a build command is configured
        build_command: The build command shown in the PR body
    """

    def __init__(self, skip_qa_review: bool, has_build_command: bool, build_command: str):
        self.skip_qa_review = skip_qa_review
        self.has_build_command = has_build_command
        self.build_command = build_command


class SessionHandler:
    """
    Interprets SmartFix agent session results.

    Responsibilities:
    - decide whether a session succeeded and processing should continue
    - render the QA ("Review") section of the PR body
    """

    def __init__(self):
        """Create a handler; it keeps no state."""

    def handle_session_result(self, session) -> SessionResult:
        """
        Map a finished agent session onto a SessionResult.

        Args:
            session: AgentSession exposing ``success``, ``failure_category``
                (an object with ``.value``, or falsy) and ``pr_body``

        Returns:
            SessionResult: on success, ``should_continue=True`` with a fix
            summary; otherwise ``should_continue=False`` with the failure
            category value, defaulting to AGENT_FAILURE when none was recorded
        """
        if not session.success:
            category = session.failure_category
            category_value = category.value if category else FailureCategory.AGENT_FAILURE.value
            return SessionResult(should_continue=False, failure_category=category_value)

        # Fall back to a generic summary when the agent produced no PR body.
        summary = session.pr_body or "Fix completed successfully"
        return SessionResult(should_continue=True, ai_fix_summary=summary)

    def generate_qa_section(self, session, config: QASectionConfig) -> str:
        """
        Render the "Review" section appended to the PR body.

        Only reached for successful sessions; failures are routed away by
        ``handle_session_result`` beforehand.

        Args:
            session: AgentSession exposing ``qa_attempts``
            config: QA section configuration

        Returns:
            str: Markdown for the Review section, or "" when QA did not run
            (disabled, or no build command), in which case the reason is logged
        """
        qa_expected = not config.skip_qa_review and config.has_build_command
        if not qa_expected:
            self._log_qa_skip_reason(config)
            return ""

        # qa_attempts > 0: the QA loop had to run before the build passed.
        if session.qa_attempts > 0:
            status_line = "* **Final Build Status:** Success \n"
        else:
            status_line = "* **Final Build Status:** Success (passed on first attempt)\n"

        header = "\n\n---\n\n## Review \n\n"
        build_line = f"* **Build Run:** Yes (`{config.build_command}`)\n"
        return header + build_line + status_line

    def _log_qa_skip_reason(self, config: QASectionConfig) -> None:
        """
        Log why QA review did not run.

        Args:
            config: QA section configuration; the SKIP_QA_REVIEW reason wins
                when a build command is also missing
        """
        if config.skip_qa_review:
            message = "QA Review was skipped based on SKIP_QA_REVIEW setting."
        elif not config.has_build_command:
            message = "QA Review was skipped as no BUILD_COMMAND was provided."
        else:
            return
        log(message)


# Factory function for backward compatibility and easy instantiation
def create_session_handler() -> SessionHandler:
    """
    Build a new SessionHandler.

    Returns:
        SessionHandler: A fresh handler instance
    """
    return SessionHandler()

# --------------------------------------------------------------------------------
# /test/test_smartfix_llm_agent.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -
# #%L
# Contrast AI SmartFix
# %%
# Copyright (C) 2025 Contrast Security, Inc.
# %%
# Contact: support@contrastsecurity.com
# License: Commercial
# NOTICE: This Software and the patented inventions embodied within may only be
# used as part of Contrast Security's commercial offerings. Even though it is
# made available through public repositories, use of this Software is subject to
# the applicable End User Licensing Agreement found at
# https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed
# between Contrast Security and the End User. The Software may not be reverse
# engineered, modified, repackaged, sold, redistributed or otherwise used in a
# way not consistent with the End User License Agreement.
26 | """ 27 | 28 | import unittest 29 | import json 30 | from unittest.mock import Mock, patch, MagicMock 31 | 32 | # Test setup imports (path is set up by conftest.py) 33 | from src.smartfix.extensions.smartfix_llm_agent import SmartFixLlmAgent 34 | from src.smartfix.extensions.smartfix_litellm import SmartFixLiteLlm 35 | 36 | 37 | class TestSmartFixLlmAgentFunctionality(unittest.TestCase): 38 | """Test cases focusing on SmartFixLlmAgent specific functionality.""" 39 | 40 | def test_has_extended_model_true(self): 41 | """Test has_extended_model returns True when SmartFixLiteLlm reference exists.""" 42 | # Test the method logic directly without full object instantiation 43 | agent = MagicMock() 44 | agent.canonical_model = Mock(spec=SmartFixLiteLlm) 45 | 46 | # Apply the real method to our mock 47 | result = SmartFixLlmAgent.has_extended_model(agent) 48 | self.assertTrue(result) 49 | 50 | def test_has_extended_model_false(self): 51 | """Test has_extended_model returns False when no SmartFixLiteLlm reference exists.""" 52 | agent = MagicMock() 53 | agent.canonical_model = Mock() # Not a SmartFixLiteLlm 54 | 55 | result = SmartFixLlmAgent.has_extended_model(agent) 56 | self.assertFalse(result) 57 | 58 | def test_get_extended_model_with_reference(self): 59 | """Test get_extended_model returns the canonical model when it's SmartFixLiteLlm.""" 60 | agent = MagicMock() 61 | mock_extended_model = Mock(spec=SmartFixLiteLlm) 62 | agent.canonical_model = mock_extended_model 63 | 64 | # Mock has_extended_model to return True 65 | with patch.object(SmartFixLlmAgent, 'has_extended_model', return_value=True): 66 | result = SmartFixLlmAgent.get_extended_model(agent) 67 | self.assertIs(result, mock_extended_model) 68 | 69 | def test_reset_accumulated_stats_with_extended_model(self): 70 | """Test that reset delegates to SmartFixLiteLlm model.""" 71 | agent = MagicMock() 72 | mock_extended_model = Mock(spec=SmartFixLiteLlm) 73 | 74 | # The real method calls get_extended_model first 
75 | with patch.object(agent, 'get_extended_model', return_value=mock_extended_model): 76 | SmartFixLlmAgent.reset_accumulated_stats(agent) 77 | # Should delegate to the extended model's reset method 78 | mock_extended_model.reset_accumulated_stats.assert_called_once() 79 | 80 | 81 | class TestSmartFixLlmAgentIntegration(unittest.TestCase): 82 | """Integration tests for SmartFixLlmAgent with real SmartFixLiteLlm instances.""" 83 | 84 | @patch('litellm.completion') 85 | def test_extended_model_delegation_logic(self, mock_completion): 86 | """Test that SmartFixLlmAgent logic works with SmartFixLiteLlm.""" 87 | # Create a real SmartFixLiteLlm instance 88 | extended_model = SmartFixLiteLlm(model="test-model") 89 | 90 | # Add some usage to the accumulator to simulate usage 91 | extended_model.cost_accumulator.add_usage( 92 | input_tokens=150, 93 | output_tokens=75, 94 | cache_read_tokens=50, 95 | cache_write_tokens=25, 96 | new_input_cost=0.0015, 97 | cache_read_cost=0.0001, 98 | cache_write_cost=0.0008, 99 | output_cost=0.003 100 | ) 101 | 102 | # Test that the extended model has the expected methods and data 103 | self.assertTrue(hasattr(extended_model, 'gather_accumulated_stats_dict')) 104 | self.assertTrue(hasattr(extended_model, 'gather_accumulated_stats')) 105 | self.assertTrue(hasattr(extended_model, 'reset_accumulated_stats')) 106 | 107 | # Test that we can get statistics from the extended model 108 | stats_dict = extended_model.gather_accumulated_stats_dict() 109 | self.assertEqual(stats_dict['call_count'], 1) 110 | self.assertEqual(stats_dict['token_usage']['total_tokens'], 300) # 150 + 75 + 50 + 25 111 | 112 | # Test JSON export 113 | json_stats = extended_model.gather_accumulated_stats() 114 | self.assertIsInstance(json_stats, str) 115 | parsed_stats = json.loads(json_stats) 116 | self.assertEqual(parsed_stats['call_count'], 1) 117 | 118 | # Test reset functionality 119 | extended_model.reset_accumulated_stats() 120 | 
self.assertEqual(extended_model.cost_accumulator.call_count, 0) 121 | 122 | @patch('litellm.completion') 123 | def test_model_info_functionality(self, mock_completion): 124 | """Test get_model_info provides correct information.""" 125 | # Create a real SmartFixLiteLlm instance 126 | extended_model = SmartFixLiteLlm(model="test-model-id") 127 | 128 | # Create a mock agent 129 | agent = MagicMock() 130 | agent.name = "test-agent" 131 | agent.canonical_model = extended_model 132 | agent.original_extended_model = extended_model 133 | 134 | result = SmartFixLlmAgent.get_model_info(agent) 135 | 136 | self.assertEqual(result['agent_name'], "test-agent") 137 | self.assertEqual(result['model_name'], "test-model-id") 138 | self.assertEqual(result['model_type'], "SmartFixLiteLlm") 139 | self.assertTrue(result['is_extended']) 140 | self.assertTrue(result['has_stats']) 141 | self.assertIn('model_id', result) 142 | 143 | 144 | if __name__ == '__main__': 145 | unittest.main() 146 | -------------------------------------------------------------------------------- /docs/contrast-ai-smartfix.yml.template: -------------------------------------------------------------------------------- 1 | #- 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security’s commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 
17 | # #L% 18 | # 19 | 20 | name: Contrast AI SmartFix 21 | 22 | on: 23 | pull_request: 24 | types: 25 | - closed 26 | schedule: 27 | - cron: '0 0 * * *' # <-- Customer configured schedule 28 | workflow_dispatch: # Allows manual triggering 29 | 30 | permissions: 31 | contents: write 32 | pull-requests: write 33 | 34 | jobs: 35 | generate_fixes: 36 | name: Generate Fixes 37 | runs-on: ubuntu-latest 38 | if: github.event_name == 'workflow_dispatch' || github.event_name == 'schedule' 39 | steps: 40 | # --- Authenticating if using an LLM from AWS Bedrock --- 41 | # Option A: Configure AWS Credentials using IAM (Recommended for production) 42 | # - name: Configure AWS Credentials 43 | # uses: aws-actions/configure-aws-credentials@v1 44 | # with: 45 | # aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 46 | # aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 47 | # aws-session-token: ${{ secrets.AWS_SESSION_TOKEN }} 48 | # aws-region: ${{ vars.AWS_REGION }} 49 | 50 | # Option B: Use Bedrock API Keys (Simpler but less secure) 51 | # If using API keys, omit the "Configure AWS Credentials" step above 52 | # and uncomment the aws_bearer_token_bedrock and aws_region lines below in the action inputs 53 | 54 | - name: Checkout repository 55 | uses: actions/checkout@v4 56 | with: 57 | fetch-depth: 0 58 | 59 | - name: Run Contrast AI SmartFix - Generate Fixes Action 60 | uses: Contrast-Security-OSS/contrast-ai-smartfix-action@v1 61 | with: 62 | # --- Max Open PRs --- 63 | max_open_prs: 5 64 | # --- Base Branch --- 65 | base_branch: '${{ github.event.repository.default_branch }}' 66 | # --- Build Command --- 67 | build_command: 'mvn clean test' 68 | # --- Formatting Command --- 69 | formatting_command: 'mvn spotless:apply' 70 | # --- Max QA Intervention loop attempts --- 71 | max_qa_attempts: 6 72 | # --- GitHub Token --- 73 | github_token: ${{ secrets.GITHUB_TOKEN }} 74 | # --- Contrast API Credentials --- 75 | contrast_host: ${{ vars.CONTRAST_HOST }} 76 | 
contrast_org_id: ${{ vars.CONTRAST_ORG_ID }} 77 | contrast_app_id: ${{ vars.CONTRAST_APP_ID }} 78 | contrast_authorization_key: ${{ secrets.CONTRAST_AUTHORIZATION_KEY }} 79 | contrast_api_key: ${{ secrets.CONTRAST_API_KEY }} 80 | # --- Use Contrast LLM as your default LLM provider --- 81 | use_contrast_llm: 'true' # set to 'false' if using your own LLM 82 | # --- Google Gemini API Credentials --- 83 | # gemini_api_key: ${{ secrets.GEMINI_API_KEY }} 84 | # --- Anthropic API Credentials --- 85 | # anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} 86 | # --- AWS Bedrock API Key (Alternative to IAM credentials, less secure) --- 87 | # aws_bearer_token_bedrock: ${{ secrets.AWS_BEARER_TOKEN_BEDROCK }} 88 | # aws_region: ${{ vars.AWS_REGION }} 89 | # --- Azure OpenAI API Credentials --- 90 | # azure_api_key: ${{ secrets.AZURE_API_KEY }} 91 | # azure_api_base: ${{ secrets.AZURE_API_BASE }} 92 | # azure_api_version: ${{ secrets.AZURE_API_VERSION }} 93 | # --- Agent Configuration (Required if not using Contrast LLM) --- 94 | agent_model: ${{ vars.AGENT_MODEL || 'bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0' }} 95 | # Other Optional Inputs (see action.yml for defaults and more options) 96 | # formatting_command: 'mvn spotless:apply' # Or the command appropriate for your project to correct the formatting of SmartFix's changes. This ensures that SmartFix follows your coding standards.
97 | # max_open_prs: 5 # This is the maximum limit for the number of PRs that SmartFix will have open at a single time 98 | # enable_full_telemetry: 'false' # Set to false to disable full telemetry 99 | 100 | handle_pr_merge: 101 | name: Handle PR Merge 102 | runs-on: ubuntu-latest 103 | if: github.event.pull_request.merged == true && contains(join(github.event.pull_request.labels.*.name), 'contrast-vuln-id:VULN-') 104 | steps: 105 | - name: Checkout repository 106 | uses: actions/checkout@v4 107 | with: 108 | fetch-depth: 0 109 | 110 | - name: Notify Contrast on PR Merge 111 | uses: Contrast-Security-OSS/contrast-ai-smartfix-action@v1 112 | with: 113 | run_task: merge 114 | # --- GitHub Token --- 115 | github_token: ${{ secrets.GITHUB_TOKEN }} 116 | # --- Contrast API Credentials --- 117 | contrast_host: ${{ vars.CONTRAST_HOST }} 118 | contrast_org_id: ${{ vars.CONTRAST_ORG_ID }} 119 | contrast_app_id: ${{ vars.CONTRAST_APP_ID }} 120 | contrast_authorization_key: ${{ secrets.CONTRAST_AUTHORIZATION_KEY }} 121 | contrast_api_key: ${{ secrets.CONTRAST_API_KEY }} 122 | env: 123 | GITHUB_EVENT_PATH: ${{ github.event_path }} 124 | 125 | handle_pr_closed: 126 | name: Handle PR Close 127 | runs-on: ubuntu-latest 128 | if: github.event.pull_request.merged == false && contains(join(github.event.pull_request.labels.*.name), 'contrast-vuln-id:VULN-') 129 | steps: 130 | - name: Checkout repository 131 | uses: actions/checkout@v4 132 | with: 133 | fetch-depth: 0 134 | 135 | - name: Notify Contrast on PR Closed 136 | uses: Contrast-Security-OSS/contrast-ai-smartfix-action@v1 137 | with: 138 | run_task: closed 139 | # --- GitHub Token --- 140 | github_token: ${{ secrets.GITHUB_TOKEN }} 141 | # --- Contrast API Credentials --- 142 | contrast_host: ${{ vars.CONTRAST_HOST }} 143 | contrast_org_id: ${{ vars.CONTRAST_ORG_ID }} 144 | contrast_app_id: ${{ vars.CONTRAST_APP_ID }} 145 | contrast_authorization_key: ${{ secrets.CONTRAST_AUTHORIZATION_KEY }} 146 | contrast_api_key: ${{
secrets.CONTRAST_API_KEY }} 147 | env: 148 | GITHUB_EVENT_PATH: ${{ github.event_path }} 149 | -------------------------------------------------------------------------------- /test/test_contrast_api_failures.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # - 3 | # #%L 4 | # Contrast AI SmartFix 5 | # %% 6 | # Copyright (C) 2025 Contrast Security, Inc. 7 | # %% 8 | # Contact: support@contrastsecurity.com 9 | # License: Commercial 10 | # NOTICE: This Software and the patented inventions embodied within may only be 11 | # used as part of Contrast Security's commercial offerings. Even though it is 12 | # made available through public repositories, use of this Software is subject to 13 | # the applicable End User Licensing Agreement found at 14 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 15 | # between Contrast Security and the End User. The Software may not be reverse 16 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 17 | # way not consistent with the End User License Agreement. 
18 | # #L% 19 | # 20 | 21 | import unittest 22 | from unittest.mock import patch, MagicMock 23 | import requests 24 | 25 | from src.config import reset_config, get_config 26 | from src import contrast_api 27 | from src.smartfix.shared.failure_categories import FailureCategory 28 | 29 | 30 | class TestContrastApiFailureCategories(unittest.TestCase): 31 | """Tests for the contrast_api failure categories and notification functions""" 32 | 33 | def setUp(self): 34 | """Set up test environment before each test""" 35 | reset_config() 36 | self.config = get_config() 37 | 38 | def tearDown(self): 39 | """Clean up after each test""" 40 | reset_config() 41 | 42 | def test_failure_category_enum_all_values(self): 43 | """Test that all expected failure categories are present""" 44 | expected_categories = [ 45 | "INITIAL_BUILD_FAILURE", 46 | "EXCEEDED_QA_ATTEMPTS", 47 | "QA_AGENT_FAILURE", 48 | "GIT_COMMAND_FAILURE", 49 | "AGENT_FAILURE", 50 | "GENERATE_PR_FAILURE", 51 | "GENERAL_FAILURE", 52 | "EXCEEDED_TIMEOUT", 53 | "EXCEEDED_AGENT_EVENTS", 54 | "INVALID_LLM_CONFIG" 55 | ] 56 | 57 | actual_categories = [category.value for category in FailureCategory] 58 | self.assertEqual(set(expected_categories), set(actual_categories)) 59 | 60 | @patch('src.contrast_api.requests.put') 61 | def test_notify_remediation_failed_generate_pr_failure(self, mock_put): 62 | """Test notify_remediation_failed with GENERATE_PR_FAILURE category""" 63 | # Mock successful response 64 | mock_response = MagicMock() 65 | mock_response.status_code = 204 66 | mock_response.raise_for_status.return_value = None 67 | mock_put.return_value = mock_response 68 | 69 | result = contrast_api.notify_remediation_failed( 70 | remediation_id="test-remediation-123", 71 | failure_category="GENERATE_PR_FAILURE", 72 | contrast_host="test.contrastsecurity.com", 73 | contrast_org_id="test-org-id", 74 | contrast_app_id="test-app-id", 75 | contrast_auth_key="test-auth-key", 76 | contrast_api_key="test-api-key" 77 | ) 78 | 79 | 
self.assertTrue(result) 80 | 81 | # Verify the API call 82 | expected_url = "https://test.contrastsecurity.com/api/v4/aiml-remediation/organizations/test-org-id/applications/test-app-id/remediations/test-remediation-123/failed" 83 | expected_headers = { 84 | "Authorization": "test-auth-key", 85 | "API-Key": "test-api-key", 86 | "Content-Type": "application/json", 87 | "Accept": "application/json", 88 | "User-Agent": self.config.USER_AGENT 89 | } 90 | expected_payload = { 91 | "failureCategory": "GENERATE_PR_FAILURE" 92 | } 93 | 94 | mock_put.assert_called_once_with( 95 | expected_url, 96 | headers=expected_headers, 97 | json=expected_payload 98 | ) 99 | 100 | @patch('src.contrast_api.requests.put') 101 | def test_notify_remediation_failed_http_error(self, mock_put): 102 | """Test notify_remediation_failed when HTTP error occurs""" 103 | # Mock HTTP error response 104 | mock_response = MagicMock() 105 | mock_response.status_code = 500 106 | mock_response.text = "Internal Server Error" 107 | 108 | # Create a proper HTTPError with response attribute 109 | http_error = requests.exceptions.HTTPError("HTTP Error") 110 | http_error.response = mock_response 111 | mock_response.raise_for_status.side_effect = http_error 112 | mock_put.return_value = mock_response 113 | 114 | result = contrast_api.notify_remediation_failed( 115 | remediation_id="test-remediation-123", 116 | failure_category="GENERATE_PR_FAILURE", 117 | contrast_host="test.contrastsecurity.com", 118 | contrast_org_id="test-org-id", 119 | contrast_app_id="test-app-id", 120 | contrast_auth_key="test-auth-key", 121 | contrast_api_key="test-api-key" 122 | ) 123 | 124 | self.assertFalse(result) 125 | 126 | @patch('src.contrast_api.requests.put') 127 | def test_notify_remediation_failed_non_204_response(self, mock_put): 128 | """Test notify_remediation_failed when API returns non-204 status""" 129 | # Mock non-204 response 130 | mock_response = MagicMock() 131 | mock_response.status_code = 400 132 | 
mock_response.raise_for_status.return_value = None 133 | mock_response.json.return_value = {"messages": ["Bad request"]} 134 | mock_put.return_value = mock_response 135 | 136 | result = contrast_api.notify_remediation_failed( 137 | remediation_id="test-remediation-123", 138 | failure_category="GENERATE_PR_FAILURE", 139 | contrast_host="test.contrastsecurity.com", 140 | contrast_org_id="test-org-id", 141 | contrast_app_id="test-app-id", 142 | contrast_auth_key="test-auth-key", 143 | contrast_api_key="test-api-key" 144 | ) 145 | 146 | self.assertFalse(result) 147 | 148 | def test_normalize_host_removes_https(self): 149 | """Test that normalize_host properly removes https prefix""" 150 | result = contrast_api.normalize_host("https://test.contrastsecurity.com") 151 | self.assertEqual(result, "test.contrastsecurity.com") 152 | 153 | def test_normalize_host_removes_http(self): 154 | """Test that normalize_host properly removes http prefix""" 155 | result = contrast_api.normalize_host("http://test.contrastsecurity.com") 156 | self.assertEqual(result, "test.contrastsecurity.com") 157 | 158 | def test_normalize_host_no_prefix(self): 159 | """Test that normalize_host leaves host unchanged when no prefix""" 160 | result = contrast_api.normalize_host("test.contrastsecurity.com") 161 | self.assertEqual(result, "test.contrastsecurity.com") 162 | 163 | 164 | if __name__ == '__main__': 165 | unittest.main() 166 | -------------------------------------------------------------------------------- /test/test_config_integration.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import unittest 4 | import os 5 | from unittest.mock import patch 6 | from src.config import get_config, reset_config 7 | 8 | 9 | class TestConfigIntegration(unittest.TestCase): 10 | """Integration tests for configuration settings including USE_CONTRAST_LLM.""" 11 | 12 | def setUp(self): 13 | """Set up test environment before each test.""" 14 | 
# Store original environment to restore later 15 | self.original_env = os.environ.copy() 16 | 17 | # Set up minimal required environment variables for testing 18 | self.env_vars = { 19 | 'GITHUB_WORKSPACE': '/tmp', 20 | 'BUILD_COMMAND': 'echo "Mock build"', 21 | 'GITHUB_TOKEN': 'mock-token', 22 | 'GITHUB_REPOSITORY': 'mock/repo', 23 | 'BASE_BRANCH': 'main', 24 | 'CONTRAST_HOST': 'test.contrastsecurity.com', 25 | 'CONTRAST_ORG_ID': 'test-org-id', 26 | 'CONTRAST_APP_ID': 'test-app-id', 27 | 'CONTRAST_AUTHORIZATION_KEY': 'test-auth-key', 28 | 'CONTRAST_API_KEY': 'test-api-key' 29 | } 30 | 31 | os.environ.update(self.env_vars) 32 | reset_config() 33 | 34 | def tearDown(self): 35 | """Clean up after each test.""" 36 | # Restore original environment 37 | os.environ.clear() 38 | os.environ.update(self.original_env) 39 | reset_config() 40 | 41 | def test_contrast_llm_true_with_agent_model_config(self): 42 | """Test that USE_CONTRAST_LLM=True works with agent model configuration.""" 43 | os.environ['USE_CONTRAST_LLM'] = 'true' 44 | os.environ['AGENT_MODEL'] = 'bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0' 45 | reset_config() 46 | 47 | config = get_config(testing=True) 48 | 49 | # Both settings should be available 50 | self.assertTrue(config.USE_CONTRAST_LLM) 51 | self.assertEqual(config.AGENT_MODEL, 'bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0') 52 | 53 | def test_enable_anthropic_prompt_caching_default(self): 54 | """Test that ENABLE_ANTHROPIC_PROMPT_CACHING defaults to True.""" 55 | reset_config() 56 | config = get_config(testing=True) 57 | self.assertTrue(config.ENABLE_ANTHROPIC_PROMPT_CACHING) 58 | 59 | def test_enable_anthropic_prompt_caching_false(self): 60 | """Test that ENABLE_ANTHROPIC_PROMPT_CACHING can be set to False.""" 61 | os.environ['ENABLE_ANTHROPIC_PROMPT_CACHING'] = 'false' 62 | reset_config() 63 | config = get_config(testing=True) 64 | self.assertFalse(config.ENABLE_ANTHROPIC_PROMPT_CACHING) 65 | 66 | def 
test_enable_anthropic_prompt_caching_true(self): 67 | """Test that ENABLE_ANTHROPIC_PROMPT_CACHING can be explicitly set to True.""" 68 | os.environ['ENABLE_ANTHROPIC_PROMPT_CACHING'] = 'true' 69 | reset_config() 70 | config = get_config(testing=True) 71 | self.assertTrue(config.ENABLE_ANTHROPIC_PROMPT_CACHING) 72 | 73 | def test_contrast_llm_false_requires_agent_model(self): 74 | """Test that USE_CONTRAST_LLM=False works when AGENT_MODEL is configured.""" 75 | os.environ['USE_CONTRAST_LLM'] = 'false' 76 | os.environ['AGENT_MODEL'] = 'anthropic/claude-sonnet-4-5-20250929' 77 | os.environ['ANTHROPIC_API_KEY'] = 'test-key' 78 | reset_config() 79 | 80 | config = get_config(testing=True) 81 | 82 | # Should be configured for BYOLLM 83 | self.assertFalse(config.USE_CONTRAST_LLM) 84 | self.assertEqual(config.AGENT_MODEL, 'anthropic/claude-sonnet-4-5-20250929') 85 | 86 | def test_coding_agent_smartfix_with_contrast_llm(self): 87 | """Test that SMARTFIX coding agent works with USE_CONTRAST_LLM=True.""" 88 | os.environ['CODING_AGENT'] = 'SMARTFIX' 89 | os.environ['USE_CONTRAST_LLM'] = 'true' 90 | reset_config() 91 | 92 | config = get_config(testing=True) 93 | 94 | # Should have SMARTFIX agent with Contrast LLM enabled 95 | self.assertEqual(config.CODING_AGENT, 'SMARTFIX') 96 | self.assertTrue(config.USE_CONTRAST_LLM) 97 | # Default agent model should use Contrast LLM constant 98 | self.assertEqual(config.AGENT_MODEL, 'contrast/claude-sonnet-4-5') 99 | 100 | def test_coding_agent_smartfix_with_byollm(self): 101 | """Test that SMARTFIX coding agent works with USE_CONTRAST_LLM=False (BYOLLM).""" 102 | os.environ['CODING_AGENT'] = 'SMARTFIX' 103 | os.environ['USE_CONTRAST_LLM'] = 'false' 104 | os.environ['AGENT_MODEL'] = 'anthropic/claude-sonnet-4-5-20250929' 105 | reset_config() 106 | 107 | config = get_config(testing=True) 108 | 109 | # Should have SMARTFIX agent with BYOLLM configured 110 | self.assertEqual(config.CODING_AGENT, 'SMARTFIX') 111 | 
self.assertFalse(config.USE_CONTRAST_LLM) 112 | self.assertEqual(config.AGENT_MODEL, 'anthropic/claude-sonnet-4-5-20250929') 113 | 114 | def test_debug_mode_shows_contrast_llm_setting(self): 115 | """Test that DEBUG_MODE includes USE_CONTRAST_LLM in logging.""" 116 | os.environ['DEBUG_MODE'] = 'true' 117 | os.environ['USE_CONTRAST_LLM'] = 'false' 118 | reset_config() 119 | 120 | with patch('src.config._log_config_message') as mock_log: 121 | config = get_config(testing=False) # Use testing=False to trigger debug logging 122 | 123 | # Verify the config values are correct 124 | self.assertTrue(config.DEBUG_MODE) 125 | self.assertFalse(config.USE_CONTRAST_LLM) 126 | 127 | # Check that logging was called with our setting 128 | log_calls = [call.args[0] for call in mock_log.call_args_list] 129 | contrast_llm_logged = any('Use Contrast LLM: False' in call for call in log_calls) 130 | self.assertTrue(contrast_llm_logged, 131 | f"Expected 'Use Contrast LLM: False' in debug logs. Got: {log_calls}") 132 | 133 | def test_config_singleton_behavior_with_contrast_llm(self): 134 | """Test that config singleton properly handles USE_CONTRAST_LLM changes.""" 135 | # First config instance 136 | os.environ['USE_CONTRAST_LLM'] = 'true' 137 | reset_config() 138 | config1 = get_config(testing=True) 139 | self.assertTrue(config1.USE_CONTRAST_LLM) 140 | 141 | # Reset and create new config with different value 142 | os.environ['USE_CONTRAST_LLM'] = 'false' 143 | reset_config() 144 | config2 = get_config(testing=True) 145 | self.assertFalse(config2.USE_CONTRAST_LLM) 146 | 147 | # Verify they're different instances due to reset 148 | self.assertNotEqual(id(config1), id(config2)) 149 | 150 | def test_all_feature_flags_work_together(self): 151 | """Test that USE_CONTRAST_LLM works alongside other feature flags.""" 152 | os.environ['USE_CONTRAST_LLM'] = 'true' 153 | os.environ['SKIP_QA_REVIEW'] = 'false' 154 | os.environ['SKIP_WRITING_SECURITY_TEST'] = 'true' 155 | 
os.environ['ENABLE_FULL_TELEMETRY'] = 'false' 156 | reset_config() 157 | 158 | config = get_config(testing=True) 159 | 160 | # Verify all feature flags are set correctly 161 | self.assertTrue(config.USE_CONTRAST_LLM) 162 | self.assertFalse(config.SKIP_QA_REVIEW) 163 | self.assertTrue(config.SKIP_WRITING_SECURITY_TEST) 164 | self.assertFalse(config.ENABLE_FULL_TELEMETRY) 165 | 166 | 167 | if __name__ == '__main__': 168 | unittest.main() 169 | -------------------------------------------------------------------------------- /src/merge_handler.py: -------------------------------------------------------------------------------- 1 | # - 2 | # #%L 3 | # Contrast AI SmartFix 4 | # %% 5 | # Copyright (C) 2025 Contrast Security, Inc. 6 | # %% 7 | # Contact: support@contrastsecurity.com 8 | # License: Commercial 9 | # NOTICE: This Software and the patented inventions embodied within may only be 10 | # used as part of Contrast Security’s commercial offerings. Even though it is 11 | # made available through public repositories, use of this Software is subject to 12 | # the applicable End User Licensing Agreement found at 13 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 14 | # between Contrast Security and the End User. The Software may not be reverse 15 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 16 | # way not consistent with the End User License Agreement. 
17 | # #L% 18 | # 19 | 20 | import os 21 | import json 22 | import sys 23 | 24 | # Import from src package to ensure correct module resolution 25 | from src import contrast_api 26 | from src.config import get_config # Using get_config function instead of direct import 27 | from src.utils import debug_log, extract_remediation_id_from_branch, extract_remediation_id_from_labels, log 28 | from src.git_handler import extract_issue_number_from_branch 29 | import src.telemetry_handler as telemetry_handler 30 | 31 | 32 | def _load_github_event() -> dict: 33 | """Load and parse the GitHub event data.""" 34 | event_path = os.getenv("GITHUB_EVENT_PATH") 35 | if not event_path: 36 | log("Error: GITHUB_EVENT_PATH not set. Cannot process PR event.", is_error=True) 37 | sys.exit(1) 38 | 39 | try: 40 | with open(event_path, 'r') as f: 41 | return json.load(f) 42 | except Exception as e: 43 | log(f"Error reading or parsing GITHUB_EVENT_PATH file: {e}", is_error=True) 44 | sys.exit(1) 45 | 46 | 47 | def _validate_pr_event(event_data: dict) -> dict: 48 | """Validate the PR event and return PR data.""" 49 | if event_data.get("action") != "closed": 50 | log("PR action is not 'closed'. Skipping.") 51 | sys.exit(0) 52 | 53 | pull_request = event_data.get("pull_request", {}) 54 | if not pull_request.get("merged"): 55 | log("PR was closed but not merged. 
Skipping.") 56 | sys.exit(0) 57 | 58 | debug_log("Pull request was merged.") 59 | return pull_request 60 | 61 | 62 | def _extract_remediation_info(pull_request: dict) -> tuple: 63 | """Extract remediation ID and other info from PR data.""" 64 | branch_name = pull_request.get("head", {}).get("ref") 65 | if not branch_name: 66 | log("Error: Could not determine branch name from PR.", is_error=True) 67 | sys.exit(1) 68 | 69 | debug_log(f"Branch name: {branch_name}") 70 | labels = pull_request.get("labels", []) 71 | 72 | # Extract remediation ID from branch name or PR labels 73 | remediation_id = None 74 | 75 | # Check if this is a branch created by external agent (e.g., GitHub Copilot or Claude Code) 76 | if branch_name.startswith("copilot/fix") or branch_name.startswith("claude/issue-"): 77 | debug_log("Branch appears to be created by external agent. Extracting remediation ID from PR labels.") 78 | remediation_id = extract_remediation_id_from_labels(labels) 79 | # Extract GitHub issue number from branch name 80 | issue_number = extract_issue_number_from_branch(branch_name) 81 | if issue_number: 82 | telemetry_handler.update_telemetry("additionalAttributes.externalIssueNumber", issue_number) 83 | debug_log(f"Extracted external issue number from branch name: {issue_number}") 84 | else: 85 | debug_log(f"Could not extract issue number from branch name: {branch_name}") 86 | 87 | # Set the external coding agent in telemetry based on branch prefix 88 | coding_agent = "EXTERNAL-CLAUDE-CODE" if branch_name.startswith("claude/") else "EXTERNAL-COPILOT" 89 | debug_log(f"Determined external coding agent to be: {coding_agent}") 90 | telemetry_handler.update_telemetry("additionalAttributes.codingAgent", coding_agent) 91 | else: 92 | # Use original method for branches created by SmartFix 93 | remediation_id = extract_remediation_id_from_branch(branch_name) 94 | telemetry_handler.update_telemetry("additionalAttributes.codingAgent", "INTERNAL-SMARTFIX") 95 | 96 | if not 
remediation_id: 97 | if branch_name.startswith("copilot/fix") or branch_name.startswith("claude/issue-"): 98 | log(f"Error: Could not extract remediation ID from PR labels for external agent branch: {branch_name}", is_error=True) 99 | else: 100 | log(f"Error: Could not extract remediation ID from branch name: {branch_name}", is_error=True) 101 | sys.exit(1) 102 | 103 | return remediation_id, labels 104 | 105 | 106 | def _extract_vulnerability_info(labels: list) -> str: 107 | """Extract vulnerability UUID from PR labels.""" 108 | vuln_uuid = "unknown" 109 | 110 | for label in labels: 111 | label_name = label.get("name", "") 112 | if label_name.startswith("contrast-vuln-id:VULN-"): 113 | # Extract UUID from label format "contrast-vuln-id:VULN-{vuln_uuid}" 114 | label_name_parts = label_name.split("VULN-") 115 | vuln_uuid = label_name_parts[1] if len(label_name_parts) > 1 else "unknown" 116 | if vuln_uuid and vuln_uuid != "unknown": 117 | debug_log(f"Extracted Vulnerability UUID from PR label: {vuln_uuid}") 118 | break 119 | 120 | if vuln_uuid == "unknown": 121 | debug_log("Could not extract vulnerability UUID from PR labels. 
Telemetry may be incomplete.") 122 | 123 | return vuln_uuid 124 | 125 | 126 | def _notify_remediation_service(remediation_id: str): 127 | """Notify the Remediation backend service about the merged PR.""" 128 | log(f"Notifying Remediation service about merged PR for remediation {remediation_id}...") 129 | config = get_config() 130 | remediation_notified = contrast_api.notify_remediation_pr_merged( 131 | remediation_id=remediation_id, 132 | contrast_host=config.CONTRAST_HOST, 133 | contrast_org_id=config.CONTRAST_ORG_ID, 134 | contrast_app_id=config.CONTRAST_APP_ID, 135 | contrast_auth_key=config.CONTRAST_AUTHORIZATION_KEY, 136 | contrast_api_key=config.CONTRAST_API_KEY 137 | ) 138 | 139 | if remediation_notified: 140 | log(f"Successfully notified Remediation service about merged PR for remediation {remediation_id}.") 141 | else: 142 | log(f"Failed to notify Remediation service about merged PR for remediation {remediation_id}.", is_error=True) 143 | 144 | 145 | def handle_merged_pr(): 146 | """Handles the logic when a pull request is merged.""" 147 | telemetry_handler.initialize_telemetry() 148 | 149 | log("--- Handling Merged Contrast AI SmartFix Pull Request ---") 150 | 151 | # Load and validate GitHub event data 152 | event_data = _load_github_event() 153 | pull_request = _validate_pr_event(event_data) 154 | 155 | # Extract remediation and vulnerability information 156 | remediation_id, labels = _extract_remediation_info(pull_request) 157 | vuln_uuid = _extract_vulnerability_info(labels) 158 | 159 | # Update telemetry with extracted information 160 | debug_log(f"Extracted Remediation ID: {remediation_id}") 161 | telemetry_handler.update_telemetry("additionalAttributes.remediationId", remediation_id) 162 | telemetry_handler.update_telemetry("vulnInfo.vulnId", vuln_uuid) 163 | telemetry_handler.update_telemetry("vulnInfo.vulnRule", "unknown") 164 | 165 | # Notify the Remediation backend service 166 | _notify_remediation_service(remediation_id) 167 | 168 | # Complete 
telemetry and finish 169 | telemetry_handler.update_telemetry("additionalAttributes.prStatus", "MERGED") 170 | contrast_api.send_telemetry_data() 171 | 172 | log("--- Merged Contrast AI SmartFix Pull Request Handling Complete ---") 173 | 174 | 175 | if __name__ == "__main__": 176 | handle_merged_pr() 177 | -------------------------------------------------------------------------------- /test/test_main.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import io 4 | import contextlib 5 | from unittest.mock import patch, MagicMock 6 | 7 | # Test setup imports (path is set up by conftest.py) 8 | from setup_test_env import create_temp_repo_dir 9 | from src.config import reset_config 10 | from src.main import main 11 | 12 | 13 | class TestMain(unittest.TestCase): 14 | """Test the main functionality of the application.""" 15 | 16 | def setUp(self): 17 | """Set up test environment before each test.""" 18 | # Use helper for temp directory creation 19 | self.temp_dir = str(create_temp_repo_dir()) 20 | 21 | # Setup standard env vars, then override paths for this test 22 | # Override paths specific to this test 23 | import os 24 | self.env_vars = { 25 | 'HOME': self.temp_dir, 26 | 'GITHUB_WORKSPACE': self.temp_dir, 27 | 'BUILD_COMMAND': 'echo "Mock build"', 28 | 'FORMATTING_COMMAND': 'echo "Mock format"', 29 | 'GITHUB_TOKEN': 'mock-token', 30 | 'GITHUB_REPOSITORY': 'mock/repo', 31 | 'GITHUB_SERVER_URL': 'https://mockhub.com', 32 | 'CONTRAST_HOST': 'mock.contrastsecurity.com', # No https:// prefix 33 | 'CONTRAST_ORG_ID': 'mock-org', 34 | 'CONTRAST_APP_ID': 'mock-app', 35 | 'CONTRAST_AUTHORIZATION_KEY': 'mock-auth', 36 | 'CONTRAST_API_KEY': 'mock-api', 37 | 'BASE_BRANCH': 'main', 38 | 'DEBUG_MODE': 'true', 39 | 'RUN_TASK': 'generate_fix' 40 | } 41 | 42 | # Apply additional environment variables to what the mixin already set up 43 | os.environ.update(self.env_vars) 44 | 45 | # Reset config for clean test 
state 46 | reset_config() 47 | 48 | # Mock subprocess calls 49 | self.subproc_patcher = patch('subprocess.run') 50 | self.mock_subprocess = self.subproc_patcher.start() 51 | mock_process = MagicMock() 52 | mock_process.returncode = 0 53 | mock_process.stdout = "Mock output" 54 | mock_process.communicate.return_value = (b"Mock stdout", b"Mock stderr") 55 | self.mock_subprocess.return_value = mock_process 56 | 57 | # Mock git configuration 58 | self.git_patcher = patch('src.git_handler.configure_git_user') 59 | self.mock_git = self.git_patcher.start() 60 | 61 | # Mock API calls 62 | self.api_patcher = patch('src.contrast_api.get_vulnerability_with_prompts') 63 | self.mock_api = self.api_patcher.start() 64 | self.mock_api.return_value = None 65 | 66 | # Mock requests for version checking 67 | self.requests_patcher = patch('src.version_check.requests.get') 68 | self.mock_requests_get = self.requests_patcher.start() 69 | mock_response = MagicMock() 70 | mock_response.json.return_value = [{'name': 'v1.0.0'}] 71 | mock_response.raise_for_status.return_value = None 72 | self.mock_requests_get.return_value = mock_response 73 | 74 | # Mock sys.exit to prevent test termination 75 | self.exit_patcher = patch('sys.exit') 76 | self.mock_exit = self.exit_patcher.start() 77 | 78 | def tearDown(self): 79 | """Clean up after each test.""" 80 | # Stop all patches 81 | self.subproc_patcher.stop() 82 | self.git_patcher.stop() 83 | self.api_patcher.stop() 84 | self.requests_patcher.stop() 85 | self.exit_patcher.stop() 86 | reset_config() 87 | 88 | # Clean up temp directory 89 | if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir): 90 | import shutil 91 | shutil.rmtree(self.temp_dir) 92 | 93 | def test_main_with_version_check(self): 94 | """Test main function with version check.""" 95 | # Add version ref to environment 96 | updated_env = self.env_vars.copy() 97 | updated_env['GITHUB_ACTION_REF'] = 'refs/tags/v1.0.0' 98 | 99 | # Create a proper patch for the function as imported 
in main.py 100 | # Note: main.py imports from version_check directly, not src.version_check 101 | with patch('src.version_check.get_latest_repo_version') as mock_get_latest: 102 | # Setup version check mocks 103 | mock_get_latest.return_value = "v1.0.0" 104 | 105 | with patch.dict('os.environ', updated_env, clear=True): 106 | # Run main and capture output 107 | with io.StringIO() as buf, contextlib.redirect_stdout(buf): 108 | main() 109 | output = buf.getvalue() 110 | 111 | # Verify main function and version check ran 112 | self.assertIn("--- Starting Contrast AI SmartFix Script ---", output) 113 | self.assertIn("Current action version", output) 114 | mock_get_latest.assert_called_once() 115 | 116 | def test_main_without_action_ref(self): 117 | """Test main function without GITHUB_ACTION_REF.""" 118 | # Ensure no GITHUB_ACTION_REF is set 119 | test_env = self.env_vars.copy() 120 | if 'GITHUB_ACTION_REF' in test_env: 121 | del test_env['GITHUB_ACTION_REF'] 122 | if 'GITHUB_REF' in test_env: 123 | del test_env['GITHUB_REF'] 124 | 125 | with patch.dict('os.environ', test_env, clear=True): 126 | # Run main and capture output 127 | with io.StringIO() as buf, contextlib.redirect_stdout(buf): 128 | main() 129 | output = buf.getvalue() 130 | 131 | # Verify warning about missing environment variables is present (updated for new message format) 132 | self.assertIn("Warning: Neither GITHUB_ACTION_REF nor GITHUB_REF environment variables are set", output) 133 | 134 | def test_duplicate_vuln_with_open_pr_skips_cleanly(self): 135 | """Test that duplicate vulnerability UUID with open PR skips cleanly without error_exit. 136 | 137 | Regression test for AIML-241: When the API returns the same vulnerability twice 138 | (common when a PR is already open), the code should skip it cleanly using the 139 | skipped_vulns logic rather than triggering the duplicate guard error. 
140 | """ 141 | # Setup: Mock API to return same vulnerability twice, then None 142 | vuln_data = { 143 | 'vulnerabilityUuid': 'TEST-VULN-UUID-123', 144 | 'vulnerabilityTitle': 'Test SQL Injection', 145 | 'vulnerabilityRuleName': 'sql-injection', 146 | 'remediationId': 'REM-TEST-123', 147 | 'sessionId': 'session-123', 148 | 'fixSystemPrompt': 'Fix the vulnerability', 149 | 'fixUserPrompt': 'Please fix', 150 | 'qaSystemPrompt': 'Review the fix', 151 | 'qaUserPrompt': 'Is it good?' 152 | } 153 | 154 | # Return same vuln twice, then None to stop loop 155 | self.mock_api.side_effect = [vuln_data, vuln_data, None] 156 | 157 | # Mock PR status check to return OPEN (simulating existing PR) 158 | with patch('src.git_handler.check_pr_status_for_label') as mock_pr_check: 159 | mock_pr_check.return_value = "OPEN" 160 | 161 | # Mock generate_label_details 162 | with patch('src.git_handler.generate_label_details') as mock_label: 163 | mock_label.return_value = ('contrast-vuln-id:TEST-VULN-UUID-123', 'color', 'desc') 164 | 165 | with patch.dict('os.environ', self.env_vars, clear=True): 166 | # Run main and capture output 167 | with io.StringIO() as buf, contextlib.redirect_stdout(buf): 168 | main() 169 | output = buf.getvalue() 170 | 171 | # Verify the vulnerability was skipped both times 172 | self.assertIn("Skipping vulnerability TEST-VULN-UUID-123", output) 173 | self.assertIn("Already skipped TEST-VULN-UUID-123 before, breaking loop", output) 174 | 175 | # Verify the loop broke cleanly 176 | self.assertIn("No vulnerabilities were processed in this run", output) 177 | 178 | 179 | if __name__ == '__main__': 180 | unittest.main() 181 | -------------------------------------------------------------------------------- /src/version_check.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from packaging.version import parse as parse_version, Version 4 | from src.utils import debug_log, log 5 | from src.config 
import get_config 6 | config = get_config() 7 | 8 | HEX_CHARS = "0123456789abcdef" 9 | ACTION_REPO_URL = "https://github.com/Contrast-Security-OSS/contrast-ai-smartfix-action" 10 | 11 | 12 | def normalize_version(version_str: str) -> str: 13 | """Normalize a version string for comparison by removing 'v' prefix.""" 14 | if version_str and version_str.startswith('v'): 15 | return version_str[1:] 16 | return version_str 17 | 18 | 19 | def safe_parse_version(version_str: str) -> Version: 20 | """Safely parse a version string, handling exceptions.""" 21 | try: 22 | return parse_version(normalize_version(version_str)) 23 | except Exception: 24 | return None 25 | 26 | 27 | def get_latest_repo_version(repo_url: str): 28 | """Fetches the latest release tag from a GitHub repository.""" 29 | try: 30 | # Clean the repo URL to avoid double https:// issues 31 | cleaned_repo_path = repo_url.replace('https://github.com/', '') 32 | if cleaned_repo_path == repo_url: 33 | # If no substitution happened, ensure we're using the correct format 34 | cleaned_repo_path = repo_url.replace('github.com/', '') 35 | 36 | # Construct the API URL for tags 37 | api_url = f"https://api.github.com/repos/{cleaned_repo_path}/tags" 38 | debug_log(f"Fetching tags from: {api_url}") 39 | response = requests.get(api_url) 40 | response.raise_for_status() # Raise an exception for HTTP errors 41 | tags = response.json() 42 | 43 | if not tags: 44 | debug_log("No tags found in the repository.") 45 | return None 46 | 47 | valid_tags = [] 48 | for tag in tags: 49 | try: 50 | version_str = tag['name'] 51 | if version_str.startswith('v'): 52 | version_str = version_str[1:] 53 | parse_version(version_str) # Check if it's a valid version 54 | valid_tags.append(tag['name']) 55 | except Exception: 56 | # Ignore tags that are not valid versions 57 | debug_log(f"Ignoring invalid version tag: {tag.get('name', 'unknown')}") 58 | pass 59 | 60 | if not valid_tags: 61 | debug_log("No valid version tags found in the 
repository.") 62 | return None 63 | 64 | # Sort valid tags to find the latest 65 | valid_tags.sort(key=lambda v: parse_version(v.lstrip('v')), reverse=True) 66 | debug_log(f"Latest version found: {valid_tags[0]}") 67 | return valid_tags[0] 68 | except requests.exceptions.RequestException as e: 69 | debug_log(f"Error fetching tags: {e}") 70 | return None 71 | except Exception as e: 72 | debug_log(f"An unexpected error occurred while fetching tags: {e}") 73 | return None 74 | 75 | 76 | def check_for_newer_version(current_version, latest_version_str: str): 77 | """Compares the current version with the latest version. 78 | Returns the latest_version_str if it's newer, otherwise None. 79 | 80 | Args: 81 | current_version: Either a string version or a Version object 82 | latest_version_str: String representation of the latest version 83 | 84 | Returns: 85 | The latest_version_str if newer, otherwise None 86 | """ 87 | original_latest_version_str = latest_version_str # Store the original 88 | 89 | try: 90 | # Handle the case where current_version is already a Version object 91 | if hasattr(current_version, 'release'): 92 | current_v = current_version 93 | else: 94 | # Handle string version 95 | current_v = parse_version(normalize_version(current_version)) 96 | 97 | latest_v = parse_version(normalize_version(latest_version_str)) 98 | 99 | debug_log(f"Comparing versions: current={current_v} latest={latest_v}") 100 | if latest_v > current_v: 101 | debug_log("Newer version detected") 102 | return original_latest_version_str # Return the original string 103 | debug_log("No newer version found") 104 | return None 105 | except Exception as e: 106 | debug_log(f"Error parsing versions for comparison: {current_version}, {latest_version_str} - {e}") 107 | return None 108 | 109 | 110 | def do_version_check(): 111 | """ 112 | Orchestrates the version check: 113 | 1. Determines current version from either environment variables or hardcoded constant. 114 | 2. 
Fetches the latest version from the repository. 115 | 3. Compares versions and prints a message if a newer version is available. 116 | """ 117 | debug_log("Starting version check") 118 | 119 | # Get environment variables for version checking 120 | github_ref = os.environ.get("GITHUB_REF") 121 | github_action_ref = os.environ.get("GITHUB_ACTION_REF") 122 | github_sha = os.environ.get("GITHUB_SHA") 123 | 124 | debug_log("Available GitHub environment variables for version checking:") 125 | if github_ref: 126 | debug_log(f" GITHUB_REF: {github_ref}") 127 | if github_action_ref: 128 | debug_log(f" GITHUB_ACTION_REF: {github_action_ref}") 129 | if github_sha: 130 | debug_log(f" GITHUB_SHA: {github_sha}") 131 | 132 | # In production, use the hardcoded version constant 133 | current_action_version = config.VERSION 134 | debug_log(f"Using hardcoded action version: {current_action_version}") 135 | 136 | # For test compatibility: 137 | 138 | # No reference found - log appropriate message for tests 139 | if not github_action_ref and not github_ref: 140 | debug_log("Warning: Neither GITHUB_ACTION_REF nor GITHUB_REF environment variables are set. Version checking is skipped.") 141 | 142 | # SHA reference only - log appropriate message for tests 143 | if not github_action_ref and not github_ref and github_sha: 144 | debug_log(f"Running from SHA: {github_sha}. No ref found for version check, using SHA.") 145 | 146 | # For SHA references - log appropriate message for tests 147 | if github_action_ref and all(c in HEX_CHARS for c in github_action_ref.lower()): 148 | debug_log(f"Running action from SHA: {github_action_ref}. Skipping version comparison against tags.") 149 | return 150 | 151 | # For branch references - log appropriate message for tests 152 | if github_ref and github_ref.startswith("refs/heads/"): 153 | branch_name = github_ref.replace("refs/heads/", "") 154 | debug_log(f"Running from branch '{branch_name}'. 
Version checking is only meaningful when using release tags.") 155 | return 156 | 157 | # Support version detection from refs for tests 158 | # Use ref_version for the actual version from tags when available 159 | ref_version = None 160 | if github_action_ref and github_action_ref.startswith("refs/tags/v"): 161 | ref_version = github_action_ref.replace("refs/tags/", "") 162 | debug_log(f"Current action version: {ref_version}") 163 | # Use this instead of the hardcoded version for comparison 164 | current_action_version = ref_version 165 | elif github_ref and github_ref.startswith("refs/tags/v"): 166 | ref_version = github_ref.replace("refs/tags/", "") 167 | debug_log(f"Current action version: {ref_version}") 168 | # Use this instead of the hardcoded version for comparison 169 | current_action_version = ref_version 170 | 171 | # Parse the current version 172 | parsed_version = safe_parse_version(current_action_version) 173 | if not parsed_version: 174 | debug_log(f"Warning: Could not parse current action version '{current_action_version}' as a semantic version. 
Skipping version check.") 175 | return 176 | 177 | # Use original version string for display 178 | parsed_version_str_for_logging = current_action_version 179 | debug_log(f"Current action version: {parsed_version_str_for_logging}") 180 | 181 | # Fetch the latest version from the repository 182 | latest_repo_version = get_latest_repo_version(ACTION_REPO_URL) 183 | 184 | if latest_repo_version: 185 | debug_log(f"Latest version available in repo: {latest_repo_version}") 186 | newer_version = check_for_newer_version(parsed_version, latest_repo_version) 187 | if newer_version: 188 | # Use utils.log for the new version message 189 | log(f"INFO: A newer version of this action is available ({newer_version}).") 190 | log(f"INFO: You are running version {parsed_version_str_for_logging}.") 191 | log(f"INFO: Please update your workflow to use the latest version of the action like this: Contrast-Security-OSS/contrast-ai-smartfix-action@{newer_version}") 192 | else: 193 | debug_log("Could not determine the latest version from the repository.") 194 | -------------------------------------------------------------------------------- /test/test_contrast_message_handling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # - 3 | # #%L 4 | # Contrast AI SmartFix 5 | # %% 6 | # Copyright (C) 2025 Contrast Security, Inc. 7 | # %% 8 | # Contact: support@contrastsecurity.com 9 | # License: Commercial 10 | # NOTICE: This Software and the patented inventions embodied within may only be 11 | # used as part of Contrast Security's commercial offerings. Even though it is 12 | # made available through public repositories, use of this Software is subject to 13 | # the applicable End User License Agreement found at 14 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 15 | # between Contrast Security and the End User. 
The Software may not be reverse 16 | # engineered, modified, repackage, sold, redistributed or otherwise used in a 17 | # way not consistent with the End User License Agreement. 18 | # #L% 19 | # 20 | 21 | import unittest 22 | from unittest.mock import MagicMock 23 | 24 | # Import just the function we need to test, avoiding full class initialization 25 | import sys 26 | from pathlib import Path 27 | project_root = Path(__file__).parent.parent 28 | sys.path.insert(0, str(project_root)) 29 | 30 | 31 | class MockSmartFixLiteLlm: 32 | """Mock version of SmartFixLiteLlm for testing message handling logic""" 33 | 34 | def __init__(self, system_prompt=None): 35 | self._system_prompt = system_prompt 36 | 37 | def _ensure_system_message_for_contrast(self, messages): 38 | """Copy of the actual method for testing""" 39 | system_prompt = self._system_prompt 40 | if not system_prompt: 41 | return messages 42 | 43 | # Check if we have any system message 44 | has_system = False 45 | has_developer = False 46 | 47 | for msg in messages: 48 | if isinstance(msg, dict): 49 | role = msg.get('role') 50 | elif hasattr(msg, 'role'): 51 | role = getattr(msg, 'role') 52 | else: 53 | continue 54 | 55 | if role == 'system': 56 | has_system = True 57 | elif role == 'developer': 58 | has_developer = True 59 | 60 | # For Contrast models, ensure we have system message and remove any developer messages 61 | if not has_system and not has_developer: 62 | system_message = { 63 | 'role': 'system', 64 | 'content': system_prompt 65 | } 66 | messages = [system_message] + list(messages) 67 | elif not has_system and has_developer: 68 | # Add system message with actual prompt 69 | system_message = { 70 | 'role': 'system', 71 | 'content': system_prompt 72 | } 73 | # Add decoy developer message to prevent LiteLLM from moving system message 74 | decoy_developer = { 75 | 'role': 'developer', 76 | 'content': [{'type': 'text', 'text': ''}] 77 | } 78 | 79 | # Filter out original developer messages to avoid 
duplicates 80 | filtered_messages = [] 81 | for msg in messages: 82 | if isinstance(msg, dict): 83 | role = msg.get('role') 84 | elif hasattr(msg, 'role'): 85 | role = getattr(msg, 'role') 86 | else: 87 | role = None 88 | 89 | # Skip developer messages - we'll use our decoy instead 90 | if role != 'developer': 91 | filtered_messages.append(msg) 92 | 93 | messages = [system_message, decoy_developer] + filtered_messages 94 | 95 | return messages 96 | 97 | 98 | class TestContrastMessageHandling(unittest.TestCase): 99 | """Tests for Contrast-specific message handling logic""" 100 | 101 | def setUp(self): 102 | """Set up test fixtures before each test method.""" 103 | self.system_prompt = "You are a security assistant." 104 | self.model = MockSmartFixLiteLlm(system_prompt=self.system_prompt) 105 | 106 | def test_ensure_system_message_no_system_no_developer(self): 107 | """Test adding system message when no system or developer messages exist""" 108 | messages = [ 109 | {'role': 'user', 'content': 'Hello'} 110 | ] 111 | 112 | result = self.model._ensure_system_message_for_contrast(messages) 113 | 114 | # Should have: system message, original user message 115 | self.assertEqual(len(result), 2) 116 | self.assertEqual(result[0]['role'], 'system') 117 | self.assertEqual(result[0]['content'], self.system_prompt) 118 | self.assertEqual(result[1]['role'], 'user') 119 | self.assertEqual(result[1]['content'], 'Hello') 120 | 121 | def test_ensure_system_message_has_developer_no_system(self): 122 | """Test adding system message when developer exists but no system message""" 123 | messages = [ 124 | {'role': 'developer', 'content': 'Original developer message'}, 125 | {'role': 'user', 'content': 'Hello'} 126 | ] 127 | 128 | result = self.model._ensure_system_message_for_contrast(messages) 129 | 130 | # Should have: system message, decoy developer, user message (original developer filtered out) 131 | self.assertEqual(len(result), 3) 132 | self.assertEqual(result[0]['role'], 'system') 
133 | self.assertEqual(result[0]['content'], self.system_prompt) 134 | self.assertEqual(result[1]['role'], 'developer') 135 | self.assertEqual(result[1]['content'], [{'type': 'text', 'text': ''}]) 136 | self.assertEqual(result[2]['role'], 'user') 137 | self.assertEqual(result[2]['content'], 'Hello') 138 | 139 | def test_ensure_system_message_has_system(self): 140 | """Test that existing system message is preserved""" 141 | messages = [ 142 | {'role': 'system', 'content': 'Existing system'}, 143 | {'role': 'user', 'content': 'Hello'} 144 | ] 145 | 146 | result = self.model._ensure_system_message_for_contrast(messages) 147 | 148 | # Should return unchanged messages 149 | self.assertEqual(len(result), 2) 150 | self.assertEqual(result[0]['role'], 'system') 151 | self.assertEqual(result[0]['content'], 'Existing system') 152 | self.assertEqual(result[1]['role'], 'user') 153 | self.assertEqual(result[1]['content'], 'Hello') 154 | 155 | def test_ensure_system_message_filters_multiple_developers(self): 156 | """Test that multiple developer messages are filtered out""" 157 | messages = [ 158 | {'role': 'developer', 'content': 'Dev message 1'}, 159 | {'role': 'developer', 'content': 'Dev message 2'}, 160 | {'role': 'user', 'content': 'Hello'}, 161 | {'role': 'assistant', 'content': 'Response'} 162 | ] 163 | 164 | result = self.model._ensure_system_message_for_contrast(messages) 165 | 166 | # Should have: system message, decoy developer, user message, assistant message 167 | self.assertEqual(len(result), 4) 168 | self.assertEqual(result[0]['role'], 'system') 169 | self.assertEqual(result[1]['role'], 'developer') 170 | self.assertEqual(result[1]['content'], [{'type': 'text', 'text': ''}]) 171 | self.assertEqual(result[2]['role'], 'user') 172 | self.assertEqual(result[3]['role'], 'assistant') 173 | 174 | def test_ensure_system_message_no_system_prompt(self): 175 | """Test behavior when no system prompt is available""" 176 | model = MockSmartFixLiteLlm() # No system prompt 177 | 
messages = [{'role': 'user', 'content': 'Hello'}] 178 | 179 | result = model._ensure_system_message_for_contrast(messages) 180 | # Should return unchanged messages 181 | self.assertEqual(len(result), 1) 182 | self.assertEqual(result[0]['role'], 'user') 183 | 184 | def test_message_object_handling(self): 185 | """Test handling of message objects (not just dicts)""" 186 | # Create mock message objects 187 | user_message = MagicMock() 188 | user_message.role = 'user' 189 | messages = [user_message] 190 | 191 | result = self.model._ensure_system_message_for_contrast(messages) 192 | 193 | # Should add system message 194 | self.assertEqual(len(result), 2) 195 | self.assertEqual(result[0]['role'], 'system') 196 | # Original object should be preserved 197 | self.assertEqual(result[1], user_message) 198 | 199 | def test_decoy_developer_format(self): 200 | """Test that decoy developer message has correct format""" 201 | messages = [ 202 | {'role': 'developer', 'content': 'Some content'}, 203 | {'role': 'user', 'content': 'Hello'} 204 | ] 205 | 206 | result = self.model._ensure_system_message_for_contrast(messages) 207 | 208 | # Check decoy developer message format 209 | decoy = result[1] 210 | self.assertEqual(decoy['role'], 'developer') 211 | self.assertEqual(decoy['content'], [{'type': 'text', 'text': ''}]) 212 | self.assertIsInstance(decoy['content'], list) 213 | self.assertEqual(len(decoy['content']), 1) 214 | self.assertEqual(decoy['content'][0]['type'], 'text') 215 | self.assertEqual(decoy['content'][0]['text'], '') 216 | 217 | 218 | if __name__ == '__main__': 219 | unittest.main() 220 | -------------------------------------------------------------------------------- /test/test_telemetry_attributes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # - 3 | # #%L 4 | # Contrast AI SmartFix 5 | # %% 6 | # Copyright (C) 2025 Contrast Security, Inc. 
7 | # %% 8 | # Contact: support@contrastsecurity.com 9 | # License: Commercial 10 | # NOTICE: This Software and the patented inventions embodied within may only be 11 | # used as part of Contrast Security's commercial offerings. Even though it is 12 | # made available through public repositories, use of this Software is subject to 13 | # the applicable End User Licensing Agreement found at 14 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 15 | # between Contrast Security and the End User. The Software may not be reverse 16 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 17 | # way not consistent with the End User License Agreement. 18 | # #L% 19 | # 20 | 21 | import unittest 22 | import os 23 | from unittest.mock import patch 24 | 25 | # Test setup imports (path is set up by conftest.py) 26 | from src.config import reset_config 27 | from src.telemetry_handler import initialize_telemetry, get_telemetry_data 28 | from src.smartfix.shared.llm_providers import LlmProvider 29 | from src.smartfix.shared.coding_agents import CodingAgents 30 | 31 | 32 | class TestTelemetryAttributes(unittest.TestCase): 33 | """Tests for the new telemetry attributes: llmProvider, agentType, and fullTelemetryEnabled""" 34 | 35 | def setUp(self): 36 | """Set up test environment before each test""" 37 | reset_config() 38 | 39 | # Set required environment variables for testing 40 | self.env_vars = { 41 | 'BASE_BRANCH': 'main', 42 | 'GITHUB_TOKEN': 'test-token', 43 | 'GITHUB_REPOSITORY': 'test/repo', 44 | 'GITHUB_SERVER_URL': 'https://mockhub.com', 45 | 'CONTRAST_HOST': 'test.contrastsecurity.com', 46 | 'CONTRAST_ORG_ID': 'test-org-id', 47 | 'CONTRAST_APP_ID': 'test-app-id', 48 | 'CONTRAST_AUTHORIZATION_KEY': 'test-auth-key', 49 | 'CONTRAST_API_KEY': 'test-api-key', 50 | 'BUILD_COMMAND': 'echo test', 51 | 'GITHUB_WORKSPACE': '/tmp/test-workspace', 52 | } 53 | 54 | def tearDown(self): 55 | """Clean up after each test""" 56 | 
reset_config() 57 | 58 | @patch.dict(os.environ, clear=True) 59 | def test_llm_provider_contrast(self): 60 | """Test that llmProvider is set to CONTRAST when USE_CONTRAST_LLM is true""" 61 | test_env = {**self.env_vars, 'USE_CONTRAST_LLM': 'true', 'CODING_AGENT': 'SMARTFIX'} 62 | 63 | with patch.dict(os.environ, test_env): 64 | reset_config() 65 | initialize_telemetry() 66 | data = get_telemetry_data() 67 | 68 | self.assertEqual(data['configInfo']['llmProvider'], LlmProvider.CONTRAST.value) 69 | 70 | @patch.dict(os.environ, clear=True) 71 | def test_llm_provider_byollm(self): 72 | """Test that llmProvider is set to BYOLLM when USE_CONTRAST_LLM is false""" 73 | test_env = {**self.env_vars, 'USE_CONTRAST_LLM': 'false', 'CODING_AGENT': 'SMARTFIX'} 74 | 75 | with patch.dict(os.environ, test_env): 76 | reset_config() 77 | initialize_telemetry() 78 | data = get_telemetry_data() 79 | 80 | self.assertEqual(data['configInfo']['llmProvider'], LlmProvider.BYOLLM.value) 81 | 82 | @patch.dict(os.environ, clear=True) 83 | def test_agent_type_smartfix(self): 84 | """Test that agentType is set correctly for SMARTFIX agent""" 85 | test_env = {**self.env_vars, 'CODING_AGENT': 'SMARTFIX'} 86 | 87 | with patch.dict(os.environ, test_env): 88 | reset_config() 89 | initialize_telemetry() 90 | data = get_telemetry_data() 91 | 92 | self.assertEqual(data['configInfo']['agentType'], CodingAgents.SMARTFIX.value) 93 | 94 | @patch.dict(os.environ, clear=True) 95 | def test_agent_type_github_copilot(self): 96 | """Test that agentType is set correctly for GITHUB_COPILOT agent""" 97 | test_env = {**self.env_vars, 'CODING_AGENT': 'GITHUB_COPILOT'} 98 | 99 | with patch.dict(os.environ, test_env): 100 | reset_config() 101 | initialize_telemetry() 102 | data = get_telemetry_data() 103 | 104 | self.assertEqual(data['configInfo']['agentType'], CodingAgents.GITHUB_COPILOT.value) 105 | 106 | @patch.dict(os.environ, clear=True) 107 | def test_agent_type_claude_code(self): 108 | """Test that agentType is 
set correctly for CLAUDE_CODE agent""" 109 | test_env = {**self.env_vars, 'CODING_AGENT': 'CLAUDE_CODE'} 110 | 111 | with patch.dict(os.environ, test_env): 112 | reset_config() 113 | initialize_telemetry() 114 | data = get_telemetry_data() 115 | 116 | self.assertEqual(data['configInfo']['agentType'], CodingAgents.CLAUDE_CODE.value) 117 | 118 | @patch.dict(os.environ, clear=True) 119 | def test_full_telemetry_enabled_true(self): 120 | """Test that fullTelemetryEnabled is true when ENABLE_FULL_TELEMETRY is true""" 121 | test_env = {**self.env_vars, 'ENABLE_FULL_TELEMETRY': 'true', 'CODING_AGENT': 'SMARTFIX'} 122 | 123 | with patch.dict(os.environ, test_env): 124 | reset_config() 125 | initialize_telemetry() 126 | data = get_telemetry_data() 127 | 128 | self.assertEqual(data['configInfo']['fullTelemetryEnabled'], True) 129 | 130 | @patch.dict(os.environ, clear=True) 131 | def test_full_telemetry_enabled_false(self): 132 | """Test that fullTelemetryEnabled is false when ENABLE_FULL_TELEMETRY is false""" 133 | test_env = {**self.env_vars, 'ENABLE_FULL_TELEMETRY': 'false', 'CODING_AGENT': 'SMARTFIX'} 134 | 135 | with patch.dict(os.environ, test_env): 136 | reset_config() 137 | initialize_telemetry() 138 | data = get_telemetry_data() 139 | 140 | self.assertEqual(data['configInfo']['fullTelemetryEnabled'], False) 141 | 142 | @patch.dict(os.environ, clear=True) 143 | def test_all_attributes_present(self): 144 | """Test that all three new attributes are present in configInfo""" 145 | test_env = {**self.env_vars, 'CODING_AGENT': 'SMARTFIX'} 146 | 147 | with patch.dict(os.environ, test_env): 148 | reset_config() 149 | initialize_telemetry() 150 | data = get_telemetry_data() 151 | 152 | config_info = data['configInfo'] 153 | self.assertIn('llmProvider', config_info) 154 | self.assertIn('agentType', config_info) 155 | self.assertIn('fullTelemetryEnabled', config_info) 156 | 157 | @patch.dict(os.environ, clear=True) 158 | def 
test_combined_scenario_contrast_smartfix_full_telemetry(self): 159 | """Test a complete scenario: CONTRAST + SMARTFIX + Full Telemetry""" 160 | test_env = { 161 | **self.env_vars, 162 | 'USE_CONTRAST_LLM': 'true', 163 | 'CODING_AGENT': 'SMARTFIX', 164 | 'ENABLE_FULL_TELEMETRY': 'true' 165 | } 166 | 167 | with patch.dict(os.environ, test_env): 168 | reset_config() 169 | initialize_telemetry() 170 | data = get_telemetry_data() 171 | 172 | config_info = data['configInfo'] 173 | self.assertEqual(config_info['llmProvider'], LlmProvider.CONTRAST.value) 174 | self.assertEqual(config_info['agentType'], CodingAgents.SMARTFIX.value) 175 | self.assertEqual(config_info['fullTelemetryEnabled'], True) 176 | 177 | @patch.dict(os.environ, clear=True) 178 | def test_combined_scenario_byollm_copilot_no_telemetry(self): 179 | """Test a complete scenario: BYOLLM + COPILOT + No Full Telemetry""" 180 | test_env = { 181 | **self.env_vars, 182 | 'USE_CONTRAST_LLM': 'false', 183 | 'CODING_AGENT': 'GITHUB_COPILOT', 184 | 'ENABLE_FULL_TELEMETRY': 'false' 185 | } 186 | 187 | with patch.dict(os.environ, test_env): 188 | reset_config() 189 | initialize_telemetry() 190 | data = get_telemetry_data() 191 | 192 | config_info = data['configInfo'] 193 | self.assertEqual(config_info['llmProvider'], LlmProvider.BYOLLM.value) 194 | self.assertEqual(config_info['agentType'], CodingAgents.GITHUB_COPILOT.value) 195 | self.assertEqual(config_info['fullTelemetryEnabled'], False) 196 | 197 | @patch.dict(os.environ, clear=True) 198 | def test_default_values(self): 199 | """Test that default values are set correctly when env vars are not provided""" 200 | test_env = {**self.env_vars} 201 | # Not setting USE_CONTRAST_LLM, CODING_AGENT, or ENABLE_FULL_TELEMETRY 202 | 203 | with patch.dict(os.environ, test_env): 204 | reset_config() 205 | initialize_telemetry() 206 | data = get_telemetry_data() 207 | 208 | config_info = data['configInfo'] 209 | # Defaults: USE_CONTRAST_LLM=true, CODING_AGENT=SMARTFIX, 
ENABLE_FULL_TELEMETRY=true 210 | self.assertEqual(config_info['llmProvider'], LlmProvider.CONTRAST.value) 211 | self.assertEqual(config_info['agentType'], CodingAgents.SMARTFIX.value) 212 | self.assertEqual(config_info['fullTelemetryEnabled'], True) 213 | 214 | 215 | if __name__ == '__main__': 216 | unittest.main() 217 | -------------------------------------------------------------------------------- /test/test_session_handler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # - 3 | # #%L 4 | # Contrast AI SmartFix 5 | # %% 6 | # Copyright (C) 2025 Contrast Security, Inc. 7 | # %% 8 | # Contact: support@contrastsecurity.com 9 | # License: Commercial 10 | # NOTICE: This Software and the patented inventions embodied within may only be 11 | # used as part of Contrast Security's commercial offerings. Even though it is 12 | # made available through public repositories, use of this Software is subject to 13 | # the applicable End User Licensing Agreement found at 14 | # https://www.contrastsecurity.com/enduser-terms-0317a or as otherwise agreed 15 | # between Contrast Security and the End User. The Software may not be reverse 16 | # engineered, modified, repackaged, sold, redistributed or otherwise used in a 17 | # way not consistent with the End User License Agreement. 18 | # #L% 19 | # 20 | 21 | import unittest 22 | from unittest.mock import MagicMock, patch 23 | 24 | from src.smartfix.domains.workflow.session_handler import SessionHandler, QASectionConfig 25 | from src.smartfix.shared.failure_categories import FailureCategory 26 | 27 | 28 | class TestSessionHandler(unittest.TestCase): 29 | """ 30 | Unit tests for the object-oriented session handling logic. 31 | 32 | These tests validate the core business logic that was causing the original bug 33 | where failed sessions could generate false positive success messages. 
34 | """ 35 | 36 | def setUp(self): 37 | """Set up test fixtures.""" 38 | self.session_handler = SessionHandler() 39 | 40 | def create_mock_session(self, success=True, qa_attempts=0, failure_category=None, pr_body="Test PR body"): 41 | """Helper to create a mock session object.""" 42 | session = MagicMock() 43 | session.success = success 44 | session.qa_attempts = qa_attempts 45 | session.failure_category = failure_category 46 | session.pr_body = pr_body 47 | return session 48 | 49 | def test_generate_qa_section_success_no_attempts(self): 50 | """Test QA section generation for successful session with no QA attempts.""" 51 | session = self.create_mock_session(success=True, qa_attempts=0) 52 | config = QASectionConfig(skip_qa_review=False, has_build_command=True, build_command="pytest") 53 | 54 | result = self.session_handler.generate_qa_section(session, config) 55 | 56 | self.assertIn("Success (passed on first attempt)", result) 57 | self.assertIn("Build Run:** Yes (`pytest`)", result) 58 | 59 | def test_generate_qa_section_success_with_attempts(self): 60 | """Test QA section generation for successful session with QA attempts.""" 61 | session = self.create_mock_session(success=True, qa_attempts=2) 62 | config = QASectionConfig(skip_qa_review=False, has_build_command=True, build_command="npm test") 63 | 64 | result = self.session_handler.generate_qa_section(session, config) 65 | 66 | self.assertIn("Final Build Status:** Success", result) 67 | self.assertNotIn("passed on first attempt", result) 68 | self.assertIn("Build Run:** Yes (`npm test`)", result) 69 | 70 | def test_generate_qa_section_qa_skipped_no_build_command(self): 71 | """Test QA section when QA is skipped due to no build command.""" 72 | session = self.create_mock_session(success=True, qa_attempts=0) 73 | config = QASectionConfig(skip_qa_review=False, has_build_command=False, build_command="") 74 | 75 | with patch('src.smartfix.domains.workflow.session_handler.log') as mock_log: 76 | result = 
self.session_handler.generate_qa_section(session, config) 77 | 78 | self.assertEqual(result, "") # Empty section when QA skipped 79 | mock_log.assert_called_with("QA Review was skipped as no BUILD_COMMAND was provided.") 80 | 81 | def test_generate_qa_section_qa_skipped_by_config(self): 82 | """Test QA section when QA is skipped by configuration.""" 83 | session = self.create_mock_session(success=True, qa_attempts=0) 84 | config = QASectionConfig(skip_qa_review=True, has_build_command=True, build_command="make test") 85 | 86 | with patch('src.smartfix.domains.workflow.session_handler.log') as mock_log: 87 | result = self.session_handler.generate_qa_section(session, config) 88 | 89 | self.assertEqual(result, "") # Empty section when QA skipped 90 | mock_log.assert_called_with("QA Review was skipped based on SKIP_QA_REVIEW setting.") 91 | 92 | def test_handle_session_result_success(self): 93 | """Test session result handling for successful session.""" 94 | session = self.create_mock_session(success=True, pr_body="Custom PR body") 95 | 96 | result = self.session_handler.handle_session_result(session) 97 | 98 | self.assertTrue(result.should_continue) 99 | self.assertIsNone(result.failure_category) 100 | self.assertEqual(result.ai_fix_summary, "Custom PR body") 101 | 102 | def test_handle_session_result_success_no_pr_body(self): 103 | """Test session result handling for successful session without PR body.""" 104 | session = self.create_mock_session(success=True, pr_body=None) 105 | 106 | result = self.session_handler.handle_session_result(session) 107 | 108 | self.assertTrue(result.should_continue) 109 | self.assertIsNone(result.failure_category) 110 | self.assertEqual(result.ai_fix_summary, "Fix completed successfully") 111 | 112 | def test_handle_session_result_failure_with_category(self): 113 | """Test session result handling for failed session with failure category.""" 114 | mock_failure_category = MagicMock() 115 | mock_failure_category.value = 
"INITIAL_BUILD_FAILURE" 116 | session = self.create_mock_session( 117 | success=False, 118 | failure_category=mock_failure_category 119 | ) 120 | 121 | result = self.session_handler.handle_session_result(session) 122 | 123 | self.assertFalse(result.should_continue) 124 | self.assertEqual(result.failure_category, "INITIAL_BUILD_FAILURE") 125 | self.assertIsNone(result.ai_fix_summary) 126 | 127 | def test_handle_session_result_failure_no_category(self): 128 | """Test session result handling for failed session without failure category.""" 129 | session = self.create_mock_session(success=False, failure_category=None) 130 | 131 | result = self.session_handler.handle_session_result(session) 132 | 133 | self.assertFalse(result.should_continue) 134 | self.assertEqual(result.failure_category, FailureCategory.AGENT_FAILURE.value) 135 | self.assertIsNone(result.ai_fix_summary) 136 | 137 | def test_bug_fix_validation(self): 138 | """ 139 | Integration test validating the original bug is fixed. 140 | 141 | This test simulates the exact scenario that was causing false positive 142 | "Success (passed on first attempt)" messages. 
143 | """ 144 | # Bug scenario: qa_attempts == 0 AND session failed 145 | failed_session = self.create_mock_session( 146 | success=False, 147 | qa_attempts=0, 148 | failure_category=MagicMock(value="INITIAL_BUILD_FAILURE") 149 | ) 150 | 151 | # This should result in error_exit, not PR generation 152 | result = self.session_handler.handle_session_result(failed_session) 153 | 154 | self.assertFalse(result.should_continue) # Should NOT continue to PR generation 155 | self.assertEqual(result.failure_category, "INITIAL_BUILD_FAILURE") 156 | 157 | # Legitimate success scenario should still work 158 | success_session = self.create_mock_session(success=True, qa_attempts=0) 159 | result = self.session_handler.handle_session_result(success_session) 160 | 161 | self.assertTrue(result.should_continue) # Should continue to PR generation 162 | self.assertIsNone(result.failure_category) 163 | 164 | # Generate QA section for legitimate success 165 | config = QASectionConfig(skip_qa_review=False, has_build_command=True, build_command="pytest") 166 | qa_section = self.session_handler.generate_qa_section(success_session, config) 167 | 168 | self.assertIn("Success (passed on first attempt)", qa_section) 169 | 170 | def test_session_failure_with_qa_attempts(self): 171 | """ 172 | Test that session failure returns should_continue=False even when qa_attempts > 0. 173 | 174 | This specifically tests the scenario where session success is false 175 | but qa_attempts is greater than 0, ensuring proper failure handling. 
176 | """ 177 | # Failure scenario: session failed but QA attempts were made 178 | mock_failure_category = MagicMock() 179 | mock_failure_category.value = "QA_BUILD_FAILURE" 180 | failed_session_with_qa = self.create_mock_session( 181 | success=False, 182 | qa_attempts=3, # QA was attempted multiple times 183 | failure_category=mock_failure_category 184 | ) 185 | 186 | result = self.session_handler.handle_session_result(failed_session_with_qa) 187 | 188 | # Should NOT continue regardless of qa_attempts when session failed 189 | self.assertFalse(result.should_continue) 190 | self.assertEqual(result.failure_category, "QA_BUILD_FAILURE") 191 | self.assertIsNone(result.ai_fix_summary) 192 | 193 | 194 | if __name__ == '__main__': 195 | unittest.main() 196 | -------------------------------------------------------------------------------- /test/test_version_check.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from unittest.mock import patch, MagicMock 4 | from packaging.version import Version 5 | 6 | # Test setup imports (path is set up by conftest.py) 7 | from src.config import reset_config 8 | from src.version_check import get_latest_repo_version, check_for_newer_version, do_version_check, normalize_version, safe_parse_version 9 | 10 | 11 | class TestVersionCheck(unittest.TestCase): 12 | """Test the version checking functionality.""" 13 | 14 | def setUp(self): 15 | # Common setup for all tests 16 | reset_config() # Reset config before each test 17 | 18 | self.requests_patcher = patch('src.version_check.requests') 19 | self.mock_requests = self.requests_patcher.start() 20 | 21 | # Set up mock exceptions 22 | self.mock_requests.exceptions = MagicMock() 23 | self.mock_requests.exceptions.RequestException = Exception 24 | 25 | # Set up a default mock response 26 | self.mock_response = MagicMock() 27 | self.mock_response.json.return_value = [ 28 | {'name': 'v1.0.0'}, 29 | {'name': 'v1.1.0'}, 30 | 
{'name': '2.0.0'} 31 | ] 32 | self.mock_response.raise_for_status.return_value = None 33 | self.mock_requests.get.return_value = self.mock_response 34 | 35 | # Reset debug_log and log mocks before each test 36 | self.log_patcher = patch('src.version_check.log') 37 | self.mock_log = self.log_patcher.start() 38 | 39 | self.debug_log_patcher = patch('src.version_check.debug_log') 40 | self.mock_debug_log = self.debug_log_patcher.start() 41 | 42 | def tearDown(self): 43 | # Clean up all patches 44 | self.requests_patcher.stop() 45 | self.log_patcher.stop() 46 | self.debug_log_patcher.stop() 47 | reset_config() # Reset config after each test 48 | 49 | def test_normalize_version(self): 50 | """Test the normalize_version function.""" 51 | self.assertEqual(normalize_version("v1.0.0"), "1.0.0") 52 | self.assertEqual(normalize_version("1.0.0"), "1.0.0") 53 | self.assertEqual(normalize_version(""), "") 54 | self.assertEqual(normalize_version(None), None) 55 | 56 | def test_safe_parse_version(self): 57 | """Test the safe_parse_version function.""" 58 | self.assertIsInstance(safe_parse_version("1.0.0"), Version) 59 | self.assertIsInstance(safe_parse_version("v1.0.0"), Version) 60 | self.assertIsNone(safe_parse_version("invalid")) 61 | self.assertIsNone(safe_parse_version("")) 62 | self.assertIsNone(safe_parse_version(None)) 63 | 64 | def test_get_latest_repo_version_success(self): 65 | """Test getting the latest version from a repo with valid tags.""" 66 | result = get_latest_repo_version("https://github.com/user/repo") 67 | self.assertEqual(result, "2.0.0") 68 | self.mock_requests.get.assert_called_once_with("https://api.github.com/repos/user/repo/tags") 69 | 70 | def test_get_latest_repo_version_with_v_prefix(self): 71 | """Test that versions with 'v' prefix are handled correctly.""" 72 | # Change mock response to only have v-prefixed tags 73 | self.mock_response.json.return_value = [ 74 | {'name': 'v1.0.0'}, 75 | {'name': 'v2.0.0'}, 76 | {'name': 'v1.5.2'} 77 | ] 78 | result 
= get_latest_repo_version("https://github.com/user/repo") 79 | self.assertEqual(result, "v2.0.0") 80 | 81 | def test_get_latest_repo_version_no_tags(self): 82 | """Test behavior when no tags are found.""" 83 | self.mock_response.json.return_value = [] 84 | result = get_latest_repo_version("https://github.com/user/repo") 85 | self.assertIsNone(result) 86 | 87 | def test_get_latest_repo_version_request_error(self): 88 | """Test handling of request exceptions.""" 89 | # Create a proper RequestException instance 90 | request_exception = self.mock_requests.exceptions.RequestException("Connection error") 91 | self.mock_requests.get.side_effect = request_exception 92 | result = get_latest_repo_version("https://github.com/user/repo") 93 | self.assertIsNone(result) 94 | 95 | def test_check_for_newer_version_newer_available(self): 96 | """Test when a newer version is available.""" 97 | result = check_for_newer_version("v1.0.0", "v1.1.0") 98 | self.assertEqual(result, "v1.1.0") 99 | 100 | def test_check_for_newer_version_same_version(self): 101 | """Test when versions are the same.""" 102 | result = check_for_newer_version("v1.0.0", "v1.0.0") 103 | self.assertIsNone(result) 104 | 105 | def test_check_for_newer_version_older_version(self): 106 | """Test when the latest version is older.""" 107 | result = check_for_newer_version("v2.0.0", "v1.9.0") 108 | self.assertIsNone(result) 109 | 110 | def test_check_for_newer_version_mixed_formats(self): 111 | """Test version comparison with mixed format (with/without v prefix).""" 112 | result = check_for_newer_version("1.0.0", "v1.1.0") 113 | self.assertEqual(result, "v1.1.0") 114 | 115 | def test_check_for_newer_version_with_version_object(self): 116 | """Test check_for_newer_version with a Version object as first parameter.""" 117 | from packaging.version import parse 118 | version_obj = parse("1.0.0") 119 | result = check_for_newer_version(version_obj, "v1.1.0") 120 | self.assertEqual(result, "v1.1.0") 121 | 122 | 
@patch.dict('os.environ', {}, clear=True) 123 | def test_do_version_check_no_refs(self): 124 | """Test do_version_check when no reference environment variables are set.""" 125 | # Using patch.dict to ensure environment variables are cleared 126 | do_version_check() 127 | # Check that the appropriate debug_log message was called 128 | self.mock_debug_log.assert_any_call("Warning: Neither GITHUB_ACTION_REF nor GITHUB_REF environment variables are set. Version checking is skipped.") 129 | 130 | @patch.dict('os.environ', {'GITHUB_SHA': 'abcdef1234567890abcdef1234567890abcdef12'}, clear=True) 131 | def test_do_version_check_with_sha_only(self): 132 | """Test do_version_check when only GITHUB_SHA is available.""" 133 | # Using patch.dict to mock only GITHUB_SHA 134 | do_version_check() 135 | # Check that the appropriate debug_log message was called 136 | self.mock_debug_log.assert_any_call("Running from SHA: abcdef1234567890abcdef1234567890abcdef12. No ref found for version check, using SHA.") 137 | 138 | @patch.dict('os.environ', {'GITHUB_REF': 'refs/tags/v1.0.0'}, clear=True) 139 | @patch('src.version_check.get_latest_repo_version') 140 | def test_do_version_check_with_github_ref(self, mock_get_latest): 141 | """Test when GITHUB_REF is set but not GITHUB_ACTION_REF.""" 142 | # Setup environment and mocks 143 | mock_get_latest.return_value = "v2.0.0" 144 | 145 | # Reset mocks before test 146 | self.mock_log.reset_mock() 147 | 148 | do_version_check() 149 | 150 | # Check debug print calls for messages that use debug_log 151 | self.mock_debug_log.assert_any_call("Current action version: v1.0.0") 152 | self.mock_debug_log.assert_any_call("Latest version available in repo: v2.0.0") 153 | 154 | # Check that the log function was called with the newer version message 155 | self.mock_log.assert_any_call("INFO: A newer version of this action is available (v2.0.0).") 156 | 157 | @patch('src.version_check.get_latest_repo_version') 158 | def 
test_do_version_check_prefers_action_ref(self, mock_get_latest): 159 | """Test that GITHUB_ACTION_REF is preferred over GITHUB_REF.""" 160 | # Setup environment with both variables 161 | os.environ["GITHUB_ACTION_REF"] = "refs/tags/v2.0.0" 162 | os.environ["GITHUB_REF"] = "refs/tags/v1.0.0" # This should be ignored 163 | mock_get_latest.return_value = "v2.0.0" 164 | 165 | do_version_check() 166 | 167 | # Check that the debug_log was called with the correct version from GITHUB_ACTION_REF 168 | self.mock_debug_log.assert_any_call("Current action version: v2.0.0") 169 | 170 | @patch('src.version_check.get_latest_repo_version') 171 | def test_do_version_check_sha_ref(self, mock_get_latest): 172 | """Test with a SHA reference instead of a version tag.""" 173 | os.environ["GITHUB_ACTION_REF"] = "abcdef1234567890abcdef1234567890abcdef12" 174 | 175 | do_version_check() 176 | 177 | # Check debug_log calls 178 | self.mock_debug_log.assert_any_call("Running action from SHA: abcdef1234567890abcdef1234567890abcdef12. Skipping version comparison against tags.") 179 | mock_get_latest.assert_not_called() 180 | 181 | @patch.dict('os.environ', {'GITHUB_REF': 'refs/heads/main'}, clear=True) 182 | @patch('src.version_check.get_latest_repo_version') 183 | def test_do_version_check_unparseable_version(self, mock_get_latest): 184 | """Test with a reference that can't be parsed as a version.""" 185 | do_version_check() 186 | 187 | # Check debug_log calls 188 | self.mock_debug_log.assert_any_call("Running from branch 'main'. Version checking is only meaningful when using release tags.") 189 | mock_get_latest.assert_not_called() 190 | 191 | 192 | if __name__ == '__main__': 193 | unittest.main() 194 | --------------------------------------------------------------------------------