├── .gitignore ├── LICENSE ├── README.md ├── after ├── main.py └── pay │ ├── __init__.py │ ├── order.py │ ├── payment.py │ ├── processor.py │ └── tests │ ├── __init__.py │ ├── test_line_item.py │ ├── test_order.py │ ├── test_payment.py │ └── test_processor.py ├── after_refactor ├── .env ├── main.py └── pay │ ├── __init__.py │ ├── card.py │ ├── order.py │ ├── payment.py │ ├── processor.py │ └── tests │ ├── __init__.py │ ├── test_line_item.py │ ├── test_order.py │ ├── test_payment.py │ └── test_processor.py └── before ├── main.py └── pay ├── __init__.py ├── order.py ├── payment.py └── processor.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.pyc 2 | .coverage 3 | htmlcov 4 | **/.DS_Store 5 | **/node_modules 6 | **/*.log -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 ArjanCodes 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # How To Write Unit Tests For Existing Python Code 2 | 3 | In this two-part miniseries, I show you a practical example of adding unit tests to existing code. This first part focuses on adding tests while not changing the original code. There are also a few things in the test setup that are not ideal, like how dates are used in the test code, using a real API key and doing actual card charges. I address these things in part 2, where I also show how refactoring the code simplifies test writing while improving the design as well. 4 | 5 | - Part 1: https://youtu.be/ULxMQ57engo. 6 | - Part 2: https://youtu.be/NI5IGAim8XU. 
# /after/main.py
# ------------------------------------------------------------------------------
from pay.order import LineItem, Order
from pay.payment import pay_order


def main():
    """Build a sample order (two pairs of shoes and a hat) and pay for it."""
    # Test card number: 1249190007575069
    order = Order()
    order.line_items.extend(
        [
            LineItem(name="Shoes", price=100_00, quantity=2),
            LineItem(name="Hat", price=50_00),
        ]
    )
    pay_order(order)


if __name__ == "__main__":
    main()


# /after/pay/__init__.py
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ArjanCodes/2022-test-existing-code/8c9a8df8c63a7df161b67fa54ba433828a489439/after/pay/__init__.py


# /after/pay/order.py
# ------------------------------------------------------------------------------
from dataclasses import dataclass, field
from enum import Enum


class OrderStatus(Enum):
    """Payment state of an order."""

    OPEN = "open"
    PAID = "paid"


@dataclass
class LineItem:
    """One product line of an order: a unit price times a quantity.

    Prices are plain integers; the sample order in main uses cents
    (e.g. 100_00 for $100.00).
    """

    name: str
    price: int
    quantity: int = 1

    @property
    def total(self) -> int:
        """Price of the whole line (unit price times quantity)."""
        return self.price * self.quantity


@dataclass
class Order:
    """A list of line items plus a status that starts as OPEN."""

    line_items: list[LineItem] = field(default_factory=list)
    status: OrderStatus = OrderStatus.OPEN

    @property
    def total(self) -> int:
        """Sum of all line totals; 0 for an order without items."""
        running_total = 0
        for line in self.line_items:
            running_total += line.total
        return running_total

    def pay(self) -> None:
        """Mark the order as paid."""
        self.status = OrderStatus.PAID


# /after/pay/payment.py
# ------------------------------------------------------------------------------
from pay.order import Order
from pay.processor import PaymentProcessor
def pay_order(order: Order):
    """Prompt for card details on stdin and charge the order total.

    Raises:
        ValueError: if the order total is 0 (checked before any prompt), if
            the expiry month/year input is not an integer, or if the processor
            rejects the card or the API key.
    """
    if order.total == 0:
        raise ValueError("Can't pay an order with total 0.")
    card = input("Please enter your card number: ")
    month = int(input("Please enter the card expiry month: "))
    year = int(input("Please enter the card expiry year: "))
    # NOTE(review): hard-coded API key; consider injecting it (e.g. from the
    # environment) instead of embedding the secret in source.
    payment_processor = PaymentProcessor("6cfb67f3-6281-4031-b893-ea85db0dce20")
    payment_processor.charge(card, month, year, amount=order.total)
    order.pay()


# /after/pay/processor.py
# ------------------------------------------------------------------------------
from datetime import datetime


def _mask_card_number(number: str) -> str:
    """Hide all but the last four characters of *number* behind '*'.

    Numbers of four characters or fewer are returned unchanged.
    """
    hidden = max(len(number) - 4, 0)
    return "*" * hidden + number[hidden:]


class PaymentProcessor:
    """Simulated payment processor.

    It validates the card and the API key, then "charges" the card by
    printing a message; no real payment is made.
    """

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    def _check_api_key(self) -> bool:
        """Return True if ``self.api_key`` is the one key this processor accepts."""
        return self.api_key == "6cfb67f3-6281-4031-b893-ea85db0dce20"

    def charge(self, card: str, month: int, year: int, amount: int) -> None:
        """Charge *amount* (in cents) to *card*.

        Raises:
            ValueError: "Invalid card" if the card fails validation, or
                "Invalid API key" if the processor holds the wrong key.
        """
        if not self.validate_card(card, month, year):
            raise ValueError("Invalid card")
        if not self._check_api_key():
            raise ValueError("Invalid API key")
        # Print only the last four digits: the full card number must not end
        # up in terminal output or logs.
        masked = _mask_card_number(card)
        print(f"Charging card number {masked} for ${amount/100:.2f}")

    def validate_card(self, card: str, month: int, year: int) -> bool:
        """Return True if *card* passes the Luhn check and has not expired.

        A card stays valid through the end of its expiry month, so a card
        expiring in the current month is still accepted. A month outside
        1-12 makes the card invalid.
        """
        if not self.luhn_checksum(card):
            return False
        if not 1 <= month <= 12:
            return False
        now = datetime.now()
        # Compare (year, month) pairs: expired only once the expiry month is over.
        return (year, month) >= (now.year, now.month)

    def luhn_checksum(self, card_number: str) -> bool:
        """Return True if *card_number* is a non-empty ASCII digit string
        that passes the Luhn (mod 10) check."""
        # Without this guard "" would give checksum 0 and be accepted, and
        # any non-digit (e.g. a space) would make int() raise.
        if not (card_number.isascii() and card_number.isdigit()):
            return False

        digits = [int(d) for d in card_number]
        # From the right: every other digit counts as-is; the digits in
        # between are doubled and the digits of the product are summed.
        checksum = sum(digits[-1::-2])
        for digit in digits[-2::-2]:
            checksum += sum(divmod(digit * 2, 10))
        return checksum % 10 == 0


# /after/pay/tests/__init__.py
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ArjanCodes/2022-test-existing-code/8c9a8df8c63a7df161b67fa54ba433828a489439/after/pay/tests/__init__.py


# /after/pay/tests/test_line_item.py
# ------------------------------------------------------------------------------
from pay.order import LineItem


def test_line_item_total() -> None:
    line_item = LineItem(name="Test", price=100)
    assert line_item.total == 100


def test_line_item_total_quantity() -> None:
    line_item = LineItem(name="Test", price=100, quantity=2)
    assert line_item.total == 200


# /after/pay/tests/test_order.py
# ------------------------------------------------------------------------------
from pay.order import LineItem, Order


def test_empty_order_total() -> None:
    order = Order()
    assert order.total == 0


def test_order_total() -> None:
    order = Order()
    order.line_items.append(LineItem(name="Test", price=100))
    assert order.total == 100


# /after/pay/tests/test_payment.py
# ------------------------------------------------------------------------------
from datetime import date

import pytest
from pay.order import LineItem, Order, OrderStatus
from pay.payment import pay_order
from pay.processor import PaymentProcessor
from pytest import MonkeyPatch

# Expiry year relative to today: a fixed year (e.g. 2024) silently turns the
# "valid card" tests into failures once that year has passed.
CC_YEAR = str(date.today().year + 2)


def test_pay_order(monkeypatch: MonkeyPatch) -> None:
    inputs = ["1249190007575069", "12", CC_YEAR]
    monkeypatch.setattr("builtins.input", lambda _: inputs.pop(0))
    monkeypatch.setattr(PaymentProcessor, "_check_api_key", lambda _: True)
    order = Order()
    order.line_items.append(LineItem(name="Shoes", price=100_00, quantity=2))
    pay_order(order)
    assert order.status == OrderStatus.PAID


def test_pay_order_invalid(monkeypatch: MonkeyPatch) -> None:
    inputs = ["1249190007575069", "12", CC_YEAR]
    monkeypatch.setattr("builtins.input", lambda _: inputs.pop(0))
    monkeypatch.setattr(PaymentProcessor, "_check_api_key", lambda _: True)
    order = Order()
    # Only the call under test goes inside raises(); match= pins the reason.
    with pytest.raises(ValueError, match="total 0"):
        pay_order(order)


# /after/pay/tests/test_processor.py
# ------------------------------------------------------------------------------
from datetime import date

import pytest
from pay.processor import PaymentProcessor

API_KEY = "6cfb67f3-6281-4031-b893-ea85db0dce20"

# Expiry year relative to today so "valid" cards stay valid as time passes.
CC_YEAR = date.today().year + 2


def test_invalid_api_key() -> None:
    payment_processor = PaymentProcessor("")
    # match= ensures the API-key check, not card validation, is what failed.
    with pytest.raises(ValueError, match="Invalid API key"):
        payment_processor.charge("1249190007575069", 12, CC_YEAR, 100)


def test_card_number_valid_date():
    payment_processor = PaymentProcessor(API_KEY)
    assert payment_processor.validate_card("1249190007575069", 12, CC_YEAR)


def test_card_number_invalid_date():
    payment_processor = PaymentProcessor(API_KEY)
    assert not payment_processor.validate_card("1249190007575069", 12, 1900)


def test_card_number_invalid_luhn():
    payment_processor = PaymentProcessor(API_KEY)
    assert not payment_processor.luhn_checksum("1249190007575068")


def test_card_number_valid_luhn():
    payment_processor = PaymentProcessor(API_KEY)
    assert payment_processor.luhn_checksum("1249190007575069")


def test_charge_card_valid():
    payment_processor = PaymentProcessor(API_KEY)
    payment_processor.charge("1249190007575069", 12, CC_YEAR, 100)


def test_charge_card_invalid():
    payment_processor = PaymentProcessor(API_KEY)
    with pytest.raises(ValueError, match="Invalid card"):
        payment_processor.charge("1249190007575068", 12, CC_YEAR, 100)


# /after_refactor/.env
# ------------------------------------------------------------------------------
# API_KEY=6cfb67f3-6281-4031-b893-ea85db0dce20


# /after_refactor/main.py
# ------------------------------------------------------------------------------
import os

from dotenv import load_dotenv

from pay.card import CreditCard
from pay.order import LineItem, Order
from pay.payment import pay_order
from pay.processor import PaymentProcessor


def read_card_info() -> CreditCard:
    """Prompt on stdin for the card number, expiry month and expiry year."""
    number = input("Please enter your card number: ")
    expiry_month = int(input("Please enter the card expiry month: "))
    expiry_year = int(input("Please enter the card expiry year: "))
    return CreditCard(
        number=number, expiry_month=expiry_month, expiry_year=expiry_year
    )


def main():
    """Load the API key from .env, build a sample order and pay for it."""
    load_dotenv()
    # A missing API_KEY falls back to "", which the processor will reject.
    payment_processor = PaymentProcessor(os.getenv("API_KEY") or "")
    # Test card number: 1249190007575069
    order = Order()
    order.line_items.extend(
        [
            LineItem(name="Shoes", price=100_00, quantity=2),
            LineItem(name="Hat", price=50_00),
        ]
    )

    # Read card info from user
    card = read_card_info()
    pay_order(order, payment_processor, card)


if __name__ == "__main__":
    main()


# /after_refactor/pay/__init__.py
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ArjanCodes/2022-test-existing-code/8c9a8df8c63a7df161b67fa54ba433828a489439/after_refactor/pay/__init__.py


# /after_refactor/pay/card.py
# ------------------------------------------------------------------------------
from dataclasses import dataclass


@dataclass
class CreditCard:
    """Card details as entered by the user (no validation happens here)."""

    number: str  # card number as typed by the user
    expiry_month: int  # presumably 1-12; the processor checks it
    expiry_year: int  # calendar year, e.g. 2026
# /after_refactor/pay/order.py
# ------------------------------------------------------------------------------
from dataclasses import dataclass, field
from enum import Enum


class OrderStatus(Enum):
    """Payment state of an order."""

    OPEN = "open"
    PAID = "paid"


@dataclass
class LineItem:
    """One product line of an order: a unit price times a quantity."""

    name: str
    price: int
    quantity: int = 1

    @property
    def total(self) -> int:
        """Price of the whole line (unit price times quantity)."""
        return self.price * self.quantity


@dataclass
class Order:
    """A list of line items plus a status that starts as OPEN."""

    line_items: list[LineItem] = field(default_factory=list)
    status: OrderStatus = OrderStatus.OPEN

    @property
    def total(self) -> int:
        """Sum of all line totals; 0 for an order without items."""
        return sum(item.total for item in self.line_items)

    def pay(self) -> None:
        """Mark the order as paid."""
        self.status = OrderStatus.PAID


# /after_refactor/pay/payment.py
# ------------------------------------------------------------------------------
from typing import Protocol

from pay.card import CreditCard
from pay.order import Order


class PaymentProcessor(Protocol):
    """Structural interface: anything with a matching ``charge`` method."""

    def charge(self, card: CreditCard, amount: int) -> None:
        """Charge the card."""


def pay_order(
    order: Order, payment_processor: PaymentProcessor, card: CreditCard
) -> None:
    """Charge *card* for the order total, then mark the order as paid.

    Raises:
        ValueError: if the order total is 0 (nothing is charged then).
        Any exception from ``payment_processor.charge`` propagates and leaves
        the order unpaid.
    """
    if order.total == 0:
        raise ValueError("Can't pay an order with total 0.")
    payment_processor.charge(card, amount=order.total)
    order.pay()


# /after_refactor/pay/processor.py
# ------------------------------------------------------------------------------
from datetime import datetime

from pay.card import CreditCard


def _mask_card_number(number: str) -> str:
    """Hide all but the last four characters of *number* behind '*'.

    Numbers of four characters or fewer are returned unchanged.
    """
    hidden = max(len(number) - 4, 0)
    return "*" * hidden + number[hidden:]


class PaymentProcessor:
    """Simulated payment processor.

    It validates the card and the API key, then "charges" the card by
    printing a message; no real payment is made.
    """

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    def _check_api_key(self) -> bool:
        """Return True if ``self.api_key`` is the one key this processor accepts."""
        return self.api_key == "6cfb67f3-6281-4031-b893-ea85db0dce20"

    def charge(self, card: CreditCard, amount: int) -> None:
        """Charge *amount* (in cents) to *card*.

        Raises:
            ValueError: "Invalid card" if the card fails validation, or
                "Invalid API key" if the processor holds the wrong key.
        """
        if not self.validate_card(card):
            raise ValueError("Invalid card")
        if not self._check_api_key():
            # Never echo the rejected key: exception messages end up in logs.
            raise ValueError("Invalid API key")
        # Print only the last four digits of the card number.
        masked = _mask_card_number(card.number)
        print(f"Charging card number {masked} for ${amount/100:.2f}")

    def validate_card(self, card: CreditCard) -> bool:
        """Return True if the card passes the Luhn check and has not expired.

        A card stays valid through the end of its expiry month, so a card
        expiring in the current month is still accepted. A month outside
        1-12 makes the card invalid.
        """
        if not luhn_checksum(card.number):
            return False
        if not 1 <= card.expiry_month <= 12:
            return False
        now = datetime.now()
        # Compare (year, month) pairs: expired only once the expiry month is over.
        return (card.expiry_year, card.expiry_month) >= (now.year, now.month)


def luhn_checksum(card_number: str) -> bool:
    """Return True if *card_number* is a non-empty ASCII digit string that
    passes the Luhn (mod 10) check."""
    # Without this guard "" would give checksum 0 and be accepted, and any
    # non-digit (e.g. a space) would make int() raise.
    if not (card_number.isascii() and card_number.isdigit()):
        return False

    digits = [int(d) for d in card_number]
    # From the right: every other digit counts as-is; the digits in between
    # are doubled and the digits of the product are summed.
    checksum = sum(digits[-1::-2])
    for digit in digits[-2::-2]:
        checksum += sum(divmod(digit * 2, 10))
    return checksum % 10 == 0


# /after_refactor/pay/tests/__init__.py
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ArjanCodes/2022-test-existing-code/8c9a8df8c63a7df161b67fa54ba433828a489439/after_refactor/pay/tests/__init__.py


# /after_refactor/pay/tests/test_line_item.py
# ------------------------------------------------------------------------------
from pay.order import LineItem


def test_line_item_total() -> None:
    line_item = LineItem(name="Test", price=100)
    assert line_item.total == 100


def test_line_item_total_quantity() -> None:
    line_item = LineItem(name="Test", price=100, quantity=2)
    assert line_item.total == 200


# /after_refactor/pay/tests/test_order.py
# ------------------------------------------------------------------------------
from pay.order import LineItem, Order


def test_empty_order_total() -> None:
    order = Order()
    assert order.total == 0


def test_order_total() -> None:
    order = Order()
    order.line_items.append(LineItem(name="Test", price=100))
    assert order.total == 100


# /after_refactor/pay/tests/test_payment.py
# ------------------------------------------------------------------------------
from datetime import date

import pytest
from pay.card import CreditCard
from pay.order import LineItem, Order, OrderStatus
from pay.payment import pay_order


@pytest.fixture
def card() -> CreditCard:
    """A Luhn-valid test card that expires two years from now."""
    year = date.today().year + 2
    return CreditCard("1249190007575069", 12, year)


class PaymentProcessorMock:
    """Test double for the PaymentProcessor protocol that records each charge."""

    def __init__(self) -> None:
        self.charges: list[tuple[CreditCard, int]] = []

    def charge(self, card: CreditCard, amount: int) -> None:
        self.charges.append((card, amount))


def test_pay_order(card: CreditCard) -> None:
    processor = PaymentProcessorMock()
    order = Order()
    order.line_items.append(LineItem(name="Shoes", price=100_00, quantity=2))
    pay_order(order, processor, card)
    assert order.status == OrderStatus.PAID
    # The full order total must be charged exactly once, to the given card.
    assert processor.charges == [(card, 200_00)]


def test_pay_order_invalid(card: CreditCard) -> None:
    processor = PaymentProcessorMock()
    order = Order()
    # Only the call under test goes inside raises(); match= pins the reason.
    with pytest.raises(ValueError, match="total 0"):
        pay_order(order, processor, card)
    assert order.status == OrderStatus.OPEN
    assert processor.charges == []


# /after_refactor/pay/tests/test_processor.py
# ------------------------------------------------------------------------------
import os
from datetime import date

import pytest
from dotenv import load_dotenv
from pay.card import CreditCard
from pay.processor import PaymentProcessor, luhn_checksum

load_dotenv()

API_KEY = os.getenv("API_KEY") or ""

CC_YEAR = date.today().year + 2


@pytest.fixture
def payment_processor() -> PaymentProcessor:
    return PaymentProcessor(API_KEY)


def test_invalid_api_key() -> None:
    card = CreditCard("1249190007575069", 12, CC_YEAR)
    # match= ensures the API-key check, not card validation, is what failed.
    with pytest.raises(ValueError, match="Invalid API key"):
        PaymentProcessor("").charge(card, 100)


def test_card_number_valid_date(payment_processor: PaymentProcessor) -> None:
    card = CreditCard("1249190007575069", 12, CC_YEAR)
    assert payment_processor.validate_card(card)


def test_card_number_invalid_date(payment_processor: PaymentProcessor) -> None:
    card = CreditCard("1249190007575069", 12, 1900)
    assert not payment_processor.validate_card(card)


def test_card_number_invalid_luhn() -> None:
    assert not luhn_checksum("1249190007575068")


def test_card_number_valid_luhn() -> None:
    assert luhn_checksum("1249190007575069")


def test_charge_card_valid(payment_processor: PaymentProcessor) -> None:
    card = CreditCard("1249190007575069", 12, CC_YEAR)
    payment_processor.charge(card, 100)


def test_charge_card_invalid(payment_processor: PaymentProcessor) -> None:
    card = CreditCard("1249190007575068", 12, CC_YEAR)
    with pytest.raises(ValueError, match="Invalid card"):
        payment_processor.charge(card, 100)


# /before/main.py
# ------------------------------------------------------------------------------
from pay.order import LineItem, Order
from pay.payment import pay_order


def main():
    # Test card number: 1249190007575069
    order = Order()
    order.line_items.append(LineItem(name="Shoes", price=100_00, quantity=2))
    order.line_items.append(LineItem(name="Hat", price=50_00))
    pay_order(order)


if __name__ == "__main__":
    main()
# /before/pay/__init__.py
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ArjanCodes/2022-test-existing-code/8c9a8df8c63a7df161b67fa54ba433828a489439/before/pay/__init__.py


# /before/pay/order.py
# ------------------------------------------------------------------------------
from dataclasses import dataclass, field
from enum import Enum


class OrderStatus(Enum):
    """Payment state of an order."""

    OPEN = "open"
    PAID = "paid"


@dataclass
class LineItem:
    """One product line of an order: a unit price times a quantity."""

    name: str
    price: int
    quantity: int = 1

    @property
    def total(self) -> int:
        """Price of the whole line (unit price times quantity)."""
        return self.price * self.quantity


@dataclass
class Order:
    """A list of line items plus a status that starts as OPEN."""

    line_items: list[LineItem] = field(default_factory=list)
    status: OrderStatus = OrderStatus.OPEN

    @property
    def total(self) -> int:
        """Sum of all line totals; 0 for an order without items."""
        running_total = 0
        for line in self.line_items:
            running_total += line.total
        return running_total

    def pay(self) -> None:
        """Mark the order as paid."""
        self.status = OrderStatus.PAID


# /before/pay/payment.py
# ------------------------------------------------------------------------------
from pay.order import Order
from pay.processor import PaymentProcessor


def pay_order(order: Order):
    """Ask for card details on stdin and charge the order total.

    Raises:
        ValueError: if the order total is 0 (checked before any prompt), if
            the month/year input is not an integer, or if the processor
            rejects the card or the API key.
    """
    if not order.total:
        raise ValueError("Can't pay an order with total 0.")
    card_number = input("Please enter your card number: ")
    expiry_month = int(input("Please enter the card expiry month: "))
    expiry_year = int(input("Please enter the card expiry year: "))
    processor = PaymentProcessor("6cfb67f3-6281-4031-b893-ea85db0dce20")
    processor.charge(card_number, expiry_month, expiry_year, amount=order.total)
    order.pay()
# ------------------------------------------------------------------------------
# /before/pay/processor.py
# ------------------------------------------------------------------------------
from datetime import datetime


def _mask_card_number(number: str) -> str:
    """Hide all but the last four characters of *number* behind '*'.

    Numbers of four characters or fewer are returned unchanged.
    """
    hidden = max(len(number) - 4, 0)
    return "*" * hidden + number[hidden:]


class PaymentProcessor:
    """Simulated payment processor.

    It validates the card and the API key, then "charges" the card by
    printing a message; no real payment is made.
    """

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    def _check_api_key(self) -> bool:
        """Return True if ``self.api_key`` is the one key this processor accepts."""
        return self.api_key == "6cfb67f3-6281-4031-b893-ea85db0dce20"

    def charge(self, card: str, month: int, year: int, amount: int) -> None:
        """Charge *amount* (in cents) to *card*.

        Raises:
            ValueError: "Invalid card" if the card fails validation, or
                "Invalid API key" if the processor holds the wrong key.
        """
        if not self.validate_card(card, month, year):
            raise ValueError("Invalid card")
        if not self._check_api_key():
            raise ValueError("Invalid API key")
        # Print only the last four digits: the full card number must not end
        # up in terminal output or logs.
        masked = _mask_card_number(card)
        print(f"Charging card number {masked} for ${amount/100:.2f}")

    def validate_card(self, card: str, month: int, year: int) -> bool:
        """Return True if *card* passes the Luhn check and has not expired.

        A card stays valid through the end of its expiry month, so a card
        expiring in the current month is still accepted. A month outside
        1-12 makes the card invalid.
        """
        if not self.luhn_checksum(card):
            return False
        if not 1 <= month <= 12:
            return False
        now = datetime.now()
        # Compare (year, month) pairs: expired only once the expiry month is over.
        return (year, month) >= (now.year, now.month)

    def luhn_checksum(self, card_number: str) -> bool:
        """Return True if *card_number* is a non-empty ASCII digit string
        that passes the Luhn (mod 10) check."""
        # Without this guard "" would give checksum 0 and be accepted, and
        # any non-digit (e.g. a space) would make int() raise.
        if not (card_number.isascii() and card_number.isdigit()):
            return False

        digits = [int(d) for d in card_number]
        # From the right: every other digit counts as-is; the digits in
        # between are doubled and the digits of the product are summed.
        checksum = sum(digits[-1::-2])
        for digit in digits[-2::-2]:
            checksum += sum(divmod(digit * 2, 10))
        return checksum % 10 == 0