├── .gitignore ├── README.md ├── docker-compose.yml └── src ├── Dockerfile ├── app ├── __init__.py ├── api │ ├── __init__.py │ ├── dependencies.py │ └── endpoints │ │ ├── __init__.py │ │ ├── auth.py │ │ └── url.py ├── core │ ├── __init__.py │ ├── config.py │ └── security.py ├── db │ ├── __init__.py │ ├── url_crud.py │ ├── url_mongodb.py │ └── url_redis.py ├── main.py ├── models │ ├── __init__.py │ ├── domain.py │ └── schemas.py ├── services │ ├── __init__.py │ ├── auth_service.py │ └── url_service.py └── utils │ ├── __init__.py │ ├── rate_limiter.py │ ├── snowflake_utils.py │ └── url_utils.py ├── pyproject.toml ├── requirements.txt └── tests ├── __init__.py ├── test_security.py ├── test_url.py └── test_url_service.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea* 2 | .DS_Store 3 | *.pyc 4 | **/__pycache__ 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # URL Shortener Service with FastAPI、MongoDB and Pytest 2 | 3 | ## Build the image and run the container: 4 | 5 | ```sh 6 | docker-compose up -d --build 7 | ``` 8 | 9 | ## Update the image and run unit test: 10 | 11 | ```sh 12 | docker-compose exec web pytest . 13 | ``` 14 | 15 | Test out the following routes: 16 | 17 | 1. [http://localhost:8002/](http://localhost:8002/) 18 | 1. [http://localhost:8002/docs](http://localhost:8002/docs) 19 | 1. [http://localhost:8002/api/v1/shorten](http://localhost:8002/api/v1/shorten) 20 | 1. [http://localhost:8002/api/v1/shorten/batch](http://localhost:8002/api/v1/shorten/batch) 21 | 1. [http://localhost:8002/api/v1/shorten/batch_conc](http://localhost:8002/api/v1/shorten/batch_conc) 22 | 1. [http://localhost:8002/api/v1/{short_url}](http://localhost:8002/api/v1/{short_url}) 23 | 1. [http://localhost:8002/api/v1/urls/delete/{short_url}](http://localhost:8002/api/v1/urls/delete/{short_url}) 24 | 1. [http://localhost:8002/test-delete](http://localhost:8002/test-delete) 25 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | web: 3 | build: ./src 4 | command: | 5 | bash -c " 6 | timeout=30 7 | while ! nc -z db 27017; do 8 | sleep 1 9 | timeout=$$((timeout - 1)) 10 | if [ $$timeout -le 0 ]; then 11 | echo 'MongoDB did not become ready in time.' 12 | exit 1 13 | fi 14 | done 15 | uvicorn app.main:app --workers 1 --host 0.0.0.0 --port 8000 16 | " 17 | volumes: 18 | - ./src/:/usr/src/app/ 19 | ports: 20 | - "8002:8000" 21 | depends_on: 22 | - db 23 | 24 | db: 25 | image: mongo:7.0.16-jammy 26 | volumes: 27 | - mongo_data:/var/lib/mongo 28 | ports: 29 | - "27017:27017" 30 | healthcheck: 31 | test: [ "CMD", "mongosh", "--eval", "db.adminCommand('ping')" ] 32 | interval: 300s 33 | timeout: 10s 34 | retries: 3 35 | 36 | volumes: 37 | mongo_data: 38 | -------------------------------------------------------------------------------- /src/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.13.2-alpine 2 | 3 | WORKDIR /usr/src/app 4 | 5 | ENV PYTHONDONTWRITEBYTECODE 1 6 | ENV PYTHONUNBUFFERED 1 7 | 8 | COPY requirements.txt /usr/src/app/requirements.txt 9 | 10 | RUN set -eux \ 11 | && apk add --no-cache --virtual .build-deps build-base \ 12 | libffi-dev gcc musl-dev python3-dev bash \ 13 | && pip install --upgrade pip setuptools wheel \ 14 | && pip install -r /usr/src/app/requirements.txt \ 15 | && rm -rf /root/.cache/pip 16 | 17 | COPY . /usr/src/app/ 18 | -------------------------------------------------------------------------------- /src/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/__init__.py -------------------------------------------------------------------------------- /src/app/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/api/__init__.py -------------------------------------------------------------------------------- /src/app/api/dependencies.py: -------------------------------------------------------------------------------- 1 | from app.db.url_mongodb import get_db 2 | from motor.motor_asyncio import AsyncIOMotorClient 3 | 4 | 5 | async def get_database_client() -> AsyncIOMotorClient: 6 | db = get_db() 7 | try: 8 | yield db 9 | finally: 10 | pass 11 | -------------------------------------------------------------------------------- /src/app/api/endpoints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/api/endpoints/__init__.py -------------------------------------------------------------------------------- /src/app/api/endpoints/auth.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/api/endpoints/auth.py -------------------------------------------------------------------------------- /src/app/api/endpoints/url.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, Depends 2 | from fastapi.responses import RedirectResponse 3 | from motor.motor_asyncio import AsyncIOMotorClient 4 | from app.models import schemas 5 | from app.api.dependencies import get_database_client 6 | from app.services.url_service import URLShortener 7 | from app.db.url_crud import (create_url, 8 | get_url_by_short, 9 | update_access_count, 10 | delete_url, 11 | create_urls_concurrently) 12 | 13 | router = APIRouter() 14 | 15 | url_shortener = URLShortener() 16 | 17 | 18 | @router.post("/shorten/", response_model=schemas.URLResponse) 19 | async def create_short_url(url_request: schemas.URLCreate, db: AsyncIOMotorClient = Depends(get_database_client)): 20 | short_url = url_shortener.add_url(url_request.long_url) 21 | 22 | url_model = await create_url(db, url_request.long_url, short_url) 23 | return schemas.URLResponse(long_url=url_model.long_url, short_url=url_model.short_url) 24 | 25 | 26 | @router.post("/shorten/batch/", response_model=list[schemas.URLResponse]) 27 | async def create_short_urls(urls: list[schemas.URLCreate], db: AsyncIOMotorClient = Depends(get_database_client)): 28 | responses = [] 29 | for url_request in urls: 30 | short_url = url_shortener.add_url(url_request.long_url) 31 | url_model = await create_url(db, url_request.long_url, short_url) 32 | responses.append(schemas.URLResponse(long_url=url_model.long_url, short_url=url_model.short_url)) 33 | return responses 34 | 35 | 36 | @router.post("/shorten/batch_conc/", response_model=list[schemas.URLResponse]) 37 | async def create_short_urls_concurrently(urls: list[schemas.URLCreate], 38 | db: AsyncIOMotorClient = Depends(get_database_client)): 39 | if not urls: 40 | raise HTTPException(status_code=400, detail="URL list is empty") 41 | 42 | url_pairs = [ 43 | { 44 | "long_url": str(url_request.long_url), 45 | "short_url": url_shortener.add_url(str(url_request.long_url)) 46 | } 47 | for url_request in urls 48 | ] 49 | 50 | await create_urls_concurrently(db, url_pairs) 51 | 52 | responses = [ 53 | schemas.URLResponse(long_url=pair["long_url"], short_url=pair["short_url"]) 54 | for pair in url_pairs 55 | ] 56 | return responses 57 | 58 | 59 | @router.get("/{short_url:path}") 60 | async def redirect_to_long_url(short_url: str, db: AsyncIOMotorClient = Depends(get_database_client)): 61 | url_model = await get_url_by_short(db, short_url) 62 | if url_model is None: 63 | raise HTTPException(status_code=404, detail="URL not found") 64 | await update_access_count(db, short_url) 65 | response_model = schemas.URLResponse.from_raw(short_url=short_url, long_url=url_model.long_url) 66 | return RedirectResponse(url=response_model.long_url) 67 | 68 | 69 | @router.delete("/urls/delete/{short_url:path}") 70 | async def delete_short_url(short_url: str, db: AsyncIOMotorClient = Depends(get_database_client)): 71 | success = await delete_url(db, short_url) 72 | if success: 73 | return {"detail": "URL deleted successfully."} 74 | else: 75 | raise HTTPException(status_code=404, detail="URL not found or already deleted.") 76 | -------------------------------------------------------------------------------- /src/app/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/core/__init__.py -------------------------------------------------------------------------------- /src/app/core/config.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings, SettingsConfigDict 2 | 3 | 4 | class Settings(BaseSettings): 5 | database_url: str = "mongodb://db:27017" 6 | database_name: str = "url_shortener" 7 | database_collection_name: str = "urls" 8 | 9 | secret_key: str = "bin_secret_key" 10 | algorithm: str = "HS256" 11 | access_token_expire_minutes: int = 30 12 | enable_rate_limit: bool = False 13 | ENABLE_DOCS: bool = True 14 | 15 | model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") 16 | 17 | 18 | settings = Settings() 19 | -------------------------------------------------------------------------------- /src/app/core/security.py: -------------------------------------------------------------------------------- 1 | import jwt 2 | import bcrypt 3 | from datetime import datetime, timedelta, timezone 4 | from app.core.config import settings 5 | 6 | 7 | def encrypt_password(password: str) -> str: 8 | hashed = bcrypt.hashpw(password.encode(), bcrypt.gensalt()) 9 | return hashed.decode("utf-8") 10 | 11 | 12 | def verify_password(plain_password: str, hashed_password: str) -> bool: 13 | return bcrypt.checkpw(plain_password.encode(), hashed_password.encode()) 14 | 15 | 16 | def create_access_token(data: dict): 17 | to_encode = data.copy() 18 | expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes) 19 | to_encode.update({"exp": expire}) 20 | 21 | encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) 22 | 23 | return encoded_jwt 24 | 25 | 26 | def verify_access_token(token: str, credentials_exception): 27 | try: 28 | payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm], 29 | options={"verify_aud": False}) 30 | return payload 31 | except jwt.PyJWTError: 32 | raise credentials_exception 33 | -------------------------------------------------------------------------------- /src/app/db/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/db/__init__.py -------------------------------------------------------------------------------- /src/app/db/url_crud.py: -------------------------------------------------------------------------------- 1 | import time 2 | import asyncio 3 | from typing import Optional 4 | from motor.motor_asyncio import AsyncIOMotorClient 5 | from app.core.config import settings 6 | from app.models.domain import URLModel 7 | from app.models.schemas import URLInDB 8 | 9 | COLLECTION_NAME = settings.database_collection_name 10 | 11 | 12 | async def create_url(db: AsyncIOMotorClient, long_url: str, short_url: str) -> URLModel: 13 | existing_url = await db[COLLECTION_NAME].find_one({"long_url": str(long_url)}) 14 | if existing_url: 15 | return URLModel(**existing_url) 16 | 17 | document = { 18 | "long_url": str(long_url), 19 | "short_url": short_url, 20 | "created_at": time.time_ns(), 21 | } 22 | await db[COLLECTION_NAME].insert_one(document) 23 | return URLModel(**document) 24 | 25 | 26 | async def create_urls_concurrently(db: AsyncIOMotorClient, url_pairs: list[dict]) -> list[URLModel]: 27 | tasks = [] 28 | for pair in url_pairs: 29 | pair["long_url"] = str(pair["long_url"]) 30 | pair["short_url"] = str(pair["short_url"]) 31 | 32 | pair["created_at"] = pair.get("created_at", time.time_ns()) 33 | 34 | task = db[COLLECTION_NAME].update_one( 35 | {"long_url": pair["long_url"]}, 36 | {"$setOnInsert": pair}, 37 | upsert=True 38 | ) 39 | tasks.append(task) 40 | 41 | await asyncio.gather(*tasks) 42 | 43 | return [URLModel(**pair) for pair in url_pairs] 44 | 45 | 46 | async def get_url_by_short(db: AsyncIOMotorClient, short_url: str) -> Optional[URLModel]: 47 | document = await db[COLLECTION_NAME].find_one({"short_url": short_url}) 48 | 49 | return URLModel(**{**document, "id": str(document.pop("_id"))}) if document else None 50 | 51 | 52 | async def update_access_count_bak(db: AsyncIOMotorClient, short_url: str) -> bool: 53 | result = await db[COLLECTION_NAME].update_one({"short_url": short_url}, 54 | {"$inc": {"access_count": 1}}) 55 | return result.modified_count > 0 56 | 57 | 58 | async def update_access_count(db: AsyncIOMotorClient, short_url: str) -> Optional[URLInDB]: 59 | result = await db[COLLECTION_NAME].update_one({"short_url": short_url}, 60 | {"$inc": {"access_count": 1}}) 61 | 62 | if result.modified_count > 0: 63 | updated_record = await db[COLLECTION_NAME].find_one({"short_url": short_url}) 64 | if updated_record: 65 | updated_record["id"] = str(updated_record.pop("_id")) 66 | return URLInDB(**updated_record) 67 | 68 | return None 69 | 70 | 71 | async def delete_url(db: AsyncIOMotorClient, short_url: str) -> bool: 72 | result = await db[COLLECTION_NAME].delete_one({"short_url": short_url}) 73 | return result.deleted_count > 0 74 | 75 | -------------------------------------------------------------------------------- /src/app/db/url_mongodb.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from motor.motor_asyncio import AsyncIOMotorClient 3 | from app.core.config import settings 4 | 5 | 6 | class MongoDB: 7 | client: Optional[AsyncIOMotorClient] = None 8 | 9 | 10 | def get_db() -> AsyncIOMotorClient: 11 | assert MongoDB.client is not None, "MongoDB client has not been initialized." 12 | 13 | return MongoDB.client[settings.database_name] 14 | -------------------------------------------------------------------------------- /src/app/db/url_redis.py: -------------------------------------------------------------------------------- 1 | import redis 2 | from typing import Optional 3 | 4 | 5 | class Redis: 6 | connection: Optional[redis.Redis] = None 7 | 8 | 9 | def get_redis_connection(host: str, port: int, db: int) -> redis.Redis: 10 | if Redis.connection is None: 11 | Redis.connection = redis.Redis(host=host, port=port, db=db, decode_responses=True) 12 | 13 | return Redis.connection 14 | 15 | 16 | def get_redis_instance_pool(host: str, port: int, db: int, password: Optional[str] = None) -> redis.Redis: 17 | redis_pool = redis.ConnectionPool(host=host, port=port, db=db, password=password) 18 | redis_instance = redis.StrictRedis(connection_pool=redis_pool) 19 | 20 | return redis_instance 21 | -------------------------------------------------------------------------------- /src/app/main.py: -------------------------------------------------------------------------------- 1 | from contextlib import asynccontextmanager 2 | from fastapi import FastAPI 3 | from fastapi.middleware.cors import CORSMiddleware 4 | from motor.motor_asyncio import AsyncIOMotorClient 5 | from app.api.endpoints import url as url_router 6 | from app.core.config import settings 7 | from app.db.url_mongodb import MongoDB 8 | from app.utils.rate_limiter import RateLimiter 9 | 10 | 11 | @asynccontextmanager 12 | async def app_lifespan(app_instance: FastAPI): 13 | print("Application Shortener startup.") 14 | MongoDB.client = AsyncIOMotorClient(settings.database_url) 15 | print("MongoDB client has been created.") 16 | yield 17 | MongoDB.client.close() 18 | print("MongoDB client has been shut down.") 19 | 20 | 21 | def create_app() -> FastAPI: 22 | ENABLE_DOCS = settings.ENABLE_DOCS 23 | _app = FastAPI(docs_url="/docs" if ENABLE_DOCS else None, 24 | redoc_url="/redoc" if ENABLE_DOCS else None, 25 | lifespan=app_lifespan) 26 | 27 | _app.add_middleware( 28 | CORSMiddleware, 29 | allow_origins=["*"], 30 | allow_credentials=True, 31 | allow_methods=["*"], 32 | allow_headers=["*"], 33 | ) 34 | 35 | if settings.enable_rate_limit: 36 | _app.add_middleware(RateLimiter, max_requests=2, period=60) 37 | 38 | _app.include_router(url_router.router, prefix="/api/v1") 39 | 40 | @_app.get("/") 41 | async def root(): 42 | return {"message": "Welcome to the URL Shortener API"} 43 | 44 | @_app.delete("/test-delete") 45 | async def test_delete(): 46 | return {"message": "DELETE method is allowed"} 47 | 48 | return _app 49 | 50 | 51 | app = create_app() 52 | -------------------------------------------------------------------------------- /src/app/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/models/__init__.py -------------------------------------------------------------------------------- /src/app/models/domain.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pydantic import BaseModel, HttpUrl 3 | 4 | 5 | class URLModel(BaseModel): 6 | id: Optional[str] = None 7 | long_url: HttpUrl 8 | short_url: Optional[str] = None 9 | created_at: Optional[int] = None 10 | expires_at: Optional[int] = None 11 | access_count: int = 0 12 | # user_id: Optional[int] = None 13 | 14 | class Config: 15 | json_encoders = { 16 | HttpUrl: lambda v: str(v), 17 | } 18 | -------------------------------------------------------------------------------- /src/app/models/schemas.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pydantic import BaseModel, HttpUrl, Field, field_validator 3 | 4 | 5 | class URLCreate(BaseModel): 6 | long_url: HttpUrl | str 7 | 8 | @classmethod 9 | @field_validator("long_url", mode="before") 10 | def convert_long_url_to_str(cls, values): 11 | if "long_url" in values and isinstance(values["long_url"], HttpUrl): 12 | values["long_url"] = str(values["long_url"]) 13 | return values 14 | 15 | 16 | class URLResponse(BaseModel): 17 | short_url: str 18 | long_url: HttpUrl | str 19 | created_at: Optional[int] = None 20 | 21 | @classmethod 22 | def from_raw(cls, short_url: str, long_url: HttpUrl | str, 23 | created_at: Optional[int] = None): 24 | long_url_str = str(long_url) if isinstance(long_url, HttpUrl) else long_url 25 | return cls(short_url=short_url, long_url=long_url_str, created_at=created_at) 26 | 27 | 28 | class URLInDB(URLResponse): 29 | id: str 30 | expires_at: Optional[int] = None 31 | access_count: int = Field(default=0, description="URL被访问的次数。") 32 | 33 | -------------------------------------------------------------------------------- /src/app/services/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/services/__init__.py -------------------------------------------------------------------------------- /src/app/services/auth_service.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/services/auth_service.py -------------------------------------------------------------------------------- /src/app/services/url_service.py: -------------------------------------------------------------------------------- 1 | import string 2 | import uuid 3 | import concurrent 4 | import threading 5 | from functools import lru_cache 6 | from concurrent.futures import ThreadPoolExecutor 7 | from app.utils.snowflake_utils import SnowflakeIDGenerator 8 | 9 | 10 | class URLShortener: 11 | def __init__(self): 12 | self.id_to_url = {} 13 | self.url_to_id = {} 14 | self.alphabet = string.digits + string.ascii_letters 15 | self.base = len(self.alphabet) 16 | self.lock = threading.Lock() 17 | 18 | @lru_cache(maxsize=1024) 19 | def encode(self, unique_id: int) -> str: 20 | if unique_id <= 0: 21 | raise ValueError("unique_id must be a positive integer") 22 | 23 | parts = [] 24 | while unique_id: 25 | unique_id, rem = divmod(unique_id, self.base) 26 | parts.append(self.alphabet[rem]) 27 | 28 | return "".join(reversed(parts)) 29 | 30 | def add_url(self, long_url: str) -> str: 31 | with self.lock: 32 | if not long_url: 33 | raise ValueError("The input long_url cannot be empty.") 34 | 35 | if long_url in self.url_to_id: 36 | return f"http://short.url/{self.url_to_id[long_url]}" 37 | unique_id = uuid.uuid4().int & ((1 << 64) - 1) 38 | short_path = self.encode(unique_id) 39 | self.url_to_id[long_url] = short_path 40 | self.id_to_url[short_path] = long_url 41 | 42 | return f"http://short.url/{short_path}" 43 | 44 | @lru_cache(maxsize=1024) 45 | def get_long_url(self, short_url: str) -> str | None: 46 | if not short_url.startswith("http://short.url/"): 47 | return None 48 | 49 | short_path = short_url.split("/")[-1] 50 | 51 | if not short_path.isalnum(): 52 | return None 53 | 54 | return self.id_to_url.get(short_path) 55 | 56 | def add_urls(self, urls: str | list[str]) -> list[str]: 57 | if isinstance(urls, str): 58 | urls = [urls] 59 | 60 | return [self.add_url(url) for url in urls] 61 | 62 | def add_urls_concurrently(self, urls: list[str], max_workers: int = 10) -> list[str]: 63 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 64 | future_to_url = {executor.submit(self.add_url, url): url for url in urls} 65 | results = [] 66 | for future in concurrent.futures.as_completed(future_to_url): 67 | results.append(future.result()) 68 | 69 | return results 70 | 71 | 72 | if __name__ == "__main__": 73 | shortener = URLShortener() 74 | _long_url = "https://example.bin.com/" 75 | _short_url = shortener.add_url(_long_url) 76 | print(f"Short URL: {_short_url}") 77 | 78 | retrieved_long_url = shortener.get_long_url(_short_url) 79 | print(f"Retrieved Long URL: {retrieved_long_url}") 80 | print() 81 | 82 | _long_urls = ["https://example.fedora.com/very/long/url", 83 | "https://example.fedora.com/very2/long/url", 84 | "https://example.fedora.com/very3/long/url"] 85 | for _url in shortener.add_urls_concurrently(_long_urls): 86 | print(f"Short URL: {_url}") 87 | print() 88 | 89 | _unique_id = uuid.uuid4().int & (1 << 64) - 1 90 | print(f"unique_id: {_unique_id}") 91 | print(f"Encode base62: {shortener.encode(_unique_id)}") 92 | print() 93 | 94 | id_generator = SnowflakeIDGenerator(data_center_id=0, worker_id=0) 95 | distributed_id = id_generator.get_id() 96 | print(f"Distributed_id: {distributed_id}") 97 | print(f"Encode base62: {shortener.encode(distributed_id)}") 98 | -------------------------------------------------------------------------------- /src/app/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/app/utils/__init__.py -------------------------------------------------------------------------------- /src/app/utils/rate_limiter.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint 3 | from starlette.responses import Response, JSONResponse 4 | from starlette.requests import Request 5 | 6 | 7 | class RateLimiter(BaseHTTPMiddleware): 8 | 9 | def __init__(self, app, max_requests: int, period: int): 10 | super().__init__(app) 11 | self.max_requests = max_requests 12 | self.period = period 13 | self.requests: dict[str, tuple[int, float]] = {} 14 | 15 | async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: 16 | client_ip = request.client.host 17 | current_time = time() 18 | if client_ip not in self.requests: 19 | self.requests[client_ip] = (1, current_time) 20 | else: 21 | count, start_time = self.requests[client_ip] 22 | if current_time - start_time < self.period: 23 | if count >= self.max_requests: 24 | return JSONResponse( 25 | status_code=429, 26 | content={"detail": "请求过于频繁,请稍后再试。"} 27 | ) 28 | self.requests[client_ip] = (count + 1, start_time) 29 | else: 30 | self.requests[client_ip] = (1, current_time) 31 | 32 | return await call_next(request) 33 | -------------------------------------------------------------------------------- /src/app/utils/snowflake_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import threading 3 | 4 | 5 | class SnowflakeIDGenerator: 6 | def __init__(self, data_center_id, worker_id): 7 | self.epoch = 1585644268888 8 | self.data_center_id_bits = 5 9 | self.worker_id_bits = 5 10 | self.sequence_bits = 12 11 | self.max_data_center_id = -1 ^ (-1 << self.data_center_id_bits) 12 | self.max_worker_id = -1 ^ (-1 << self.worker_id_bits) 13 | self.sequence_mask = -1 ^ (-1 << self.sequence_bits) 14 | 15 | self.worker_id_shift = self.sequence_bits 16 | self.data_center_id_shift = self.sequence_bits + self.worker_id_bits 17 | self.timestamp_left_shift = self.sequence_bits + self.worker_id_bits + self.data_center_id_bits 18 | 19 | if data_center_id > self.max_data_center_id or data_center_id < 0: 20 | raise ValueError("data_center_id超出范围") 21 | if worker_id > self.max_worker_id or worker_id < 0: 22 | raise ValueError("worker_id超出范围") 23 | 24 | self.data_center_id = data_center_id 25 | self.worker_id = worker_id 26 | self.sequence = 0 27 | self.last_timestamp = -1 28 | self.lock = threading.Lock() 29 | 30 | def _gen_timestamp(self): 31 | return int(time.time() * 1000) 32 | 33 | def _next_millis(self, last_timestamp: int): 34 | timestamp = self._gen_timestamp() 35 | while timestamp <= last_timestamp: 36 | timestamp = self._gen_timestamp() 37 | return timestamp 38 | 39 | def get_id(self): 40 | with self.lock: 41 | timestamp = self._gen_timestamp() 42 | if timestamp < self.last_timestamp: 43 | raise Exception("时钟回拨,无法生成ID") 44 | if timestamp == self.last_timestamp: 45 | self.sequence = (self.sequence + 1) & self.sequence_mask 46 | if self.sequence == 0: 47 | timestamp = self._next_millis(self.last_timestamp) 48 | else: 49 | self.sequence = 0 50 | 51 | self.last_timestamp = timestamp 52 | 53 | new_id = ((timestamp - self.epoch) << self.timestamp_left_shift) | \ 54 | (self.data_center_id << self.data_center_id_shift) | \ 55 | (self.worker_id << self.worker_id_shift) | self.sequence 56 | return new_id 57 | 58 | 59 | if __name__ == "__main__": 60 | id_generator = SnowflakeIDGenerator(data_center_id=0, worker_id=0) 61 | 62 | for _ in range(6): 63 | print(id_generator.get_id()) 64 | -------------------------------------------------------------------------------- /src/app/utils/url_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import wraps 3 | from collections.abc import Callable 4 | from typing import TypeVar, ParamSpec 5 | from pymongo.errors import ServerSelectionTimeoutError 6 | 7 | P = ParamSpec("P") 8 | R = TypeVar("R") 9 | 10 | 11 | def retry(retry_count: int = 3, wait_seconds: int = 2): 12 | def decorator(func: Callable[P, R]): 13 | @wraps(func) 14 | def wrapper(*args, **kwargs): 15 | nonlocal retry_count 16 | attempts = 0 17 | while attempts < retry_count: 18 | try: 19 | return func(*args, **kwargs) 20 | except ServerSelectionTimeoutError as e: 21 | print(f"Attempt {attempts + 1} failed with error: {e}") 22 | time.sleep(wait_seconds) 23 | attempts += 1 24 | return func(*args, **kwargs) 25 | return wrapper 26 | return decorator 27 | -------------------------------------------------------------------------------- /src/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = "-ra" 3 | asyncio_mode = "auto" 4 | asyncio_default_fixture_loop_scope = "function" 5 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | bcrypt==4.2.1 2 | fastapi==0.115.8 3 | hiredis==3.1.0 4 | httptools==0.6.4 5 | httpx==0.28.1 6 | motor==3.7.0 7 | pip==25.0.1 8 | pydantic-settings==2.8.0 9 | PyJWT==2.10.1 10 | pytest-asyncio==0.25.3 11 | PyYAML==6.0.2 12 | redis==5.2.1 13 | uvicorn==0.34.0 14 | uvloop==0.21.0 15 | watchfiles==1.0.4 16 | websockets==15.0 17 | wheel==0.45.1 18 | -------------------------------------------------------------------------------- /src/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/binbjz/url_shortener/d3d46a6e4655905750e6470b5984f45a493b4454/src/tests/__init__.py -------------------------------------------------------------------------------- /src/tests/test_security.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import jwt 3 | from datetime import datetime, timedelta, timezone 4 | from app.core.config import settings 5 | from app.core.security import (encrypt_password, 6 | verify_password, 7 | create_access_token, 8 | verify_access_token) 9 | 10 | 11 | def test_encrypt_password(): 12 | password = "test_password" 13 | encrypted_password = encrypt_password(password) 14 | assert password != encrypted_password 15 | assert len(encrypted_password) > len(password) 16 | 17 | 18 | def test_verify_password(): 19 | password = "test_password" 20 | wrong_password = "wrong_password" 21 | encrypted_password = encrypt_password(password) 22 | assert verify_password(password, encrypted_password) is True 23 | assert verify_password(wrong_password, encrypted_password) is False 24 | 25 | 26 | def test_verify_access_token(): 27 | user_data = {"sub": "test_user"} 28 | token = create_access_token(user_data) 29 | decoded_data = verify_access_token(token, Exception("Token verification failed")) 30 | assert decoded_data["sub"] == user_data["sub"] 31 | with pytest.raises(Exception) as e: 32 | verify_access_token("wrong_token", Exception("Token verification failed")) 33 | assert str(e.value) == "Token verification failed" 34 | 35 | 36 | def test_create_access_token_positive(): 37 | user_data = {"sub": "test_user"} 38 | token = create_access_token(user_data) 39 | 40 | assert isinstance(token, str) 41 | 42 | decoded_data = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm], 43 | options={"verify_signature": False}) 44 | assert decoded_data["sub"] == user_data["sub"] 45 | 46 | now = datetime.now(timezone.utc) 47 | exp = datetime.fromtimestamp(decoded_data["exp"], timezone.utc) 48 | delta = exp - now 49 | assert delta < timedelta(minutes=settings.access_token_expire_minutes + 1) 50 | assert delta > timedelta(minutes=settings.access_token_expire_minutes - 1) 51 | 52 | 53 | def test_create_access_token_negative(): 54 | with pytest.raises(Exception): 55 | invalid_data = "not_a_dict" 56 | create_access_token(invalid_data) 57 | 58 | -------------------------------------------------------------------------------- /src/tests/test_url.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytest_asyncio 3 | from httpx import AsyncClient 4 | 5 | 6 | @pytest.fixture(scope="function", autouse=True) 7 | def setup_mongodb(): 8 | from app.db.url_mongodb import MongoDB 9 | from app.core.config import settings 10 | from motor.motor_asyncio import AsyncIOMotorClient 11 | 12 | MongoDB.client = AsyncIOMotorClient(settings.database_url) 13 | yield 14 | MongoDB.client.close() 15 | 16 | 17 | @pytest_asyncio.fixture 18 | async def client() -> AsyncClient: 19 | async with AsyncClient(base_url="http://127.0.0.1:8000/api/v1/") as async_client: 20 | yield async_client 21 | 22 | 23 | async def create_short_url(client: AsyncClient, url: str = "https://example.com") -> str: 24 | response = await client.post("/shorten/", json={"long_url": url}) 25 | assert response.status_code == 200, "Unexpected status code" 26 | assert "short_url" in response.json(), "short_url not in response" 27 | return response.json()["short_url"] 28 | 29 | 30 | @pytest.mark.asyncio 31 | async def test_create_short_url(client: AsyncClient): 32 | await create_short_url(client) 33 | 34 | 35 | @pytest.mark.asyncio 36 | async def test_create_short_urls_batch(client: AsyncClient): 37 | response = await client.post("/shorten/batch/", 38 | json=[{"long_url": "https://example.com"}, 39 | {"long_url": "https://example.org"}]) 40 | assert response.status_code == 200, "Unexpected status code" 41 | assert isinstance(response.json(), list), "Response is not a list" 42 | assert "short_url" in response.json()[0], "short_url not in first item of response" 43 | 44 | 45 | @pytest.mark.asyncio 46 | async def test_create_short_urls_concurrently(client: AsyncClient): 47 | response = await client.post("/shorten/batch_conc/", json=[ 48 | {"long_url": "https://example.com"}, 49 | {"long_url": "https://example.org"} 50 | ]) 51 | assert response.status_code == 200, f"Unexpected status code: {response.status_code}" 52 | data = response.json() 53 | assert isinstance(data, list), "Response is not a list" 54 | assert all("short_url" in item and "long_url" in item for item in data), "Invalid response structure" 55 | 56 | 57 | @pytest.mark.asyncio 58 | async def test_redirect_to_long_url(client: AsyncClient): 59 | short_url = await create_short_url(client) 60 | response = await client.get(f"/{short_url}") 61 | assert response.status_code == 307, "Unexpected status code" 62 | 63 | 64 | @pytest.mark.asyncio 65 | async def test_delete_short_url(client: AsyncClient): 66 | short_url = await create_short_url(client) 67 | response = await client.delete(f"/urls/delete/{short_url}") 68 | assert response.status_code == 200, "Unexpected status code" 69 | assert response.json() == {"detail": "URL deleted successfully."}, "Unexpected response" 70 | 71 | 72 | @pytest.mark.asyncio 73 | async def test_redirect_to_nonexistent_short_url(client: AsyncClient): 74 | response = await client.get("/nonexistent_short_url") 75 | assert response.status_code == 404, "Unexpected status code" 76 | 77 | 78 | @pytest.mark.asyncio 79 | async def test_delete_nonexistent_short_url(client: AsyncClient): 80 | response = await client.delete("/urls/delete/nonexistent_short_url") 81 | assert response.status_code == 404, "Unexpected status code" 82 | 83 | 84 | @pytest.mark.asyncio 85 | async def test_create_short_urls_concurrently_with_empty_list(client: AsyncClient): 86 | response = await client.post("/shorten/batch_conc/", json=[]) 87 | assert response.status_code == 400, "Unexpected status code" 88 | -------------------------------------------------------------------------------- /src/tests/test_url_service.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from app.services.url_service import URLShortener 3 | 4 | 5 | @pytest.fixture(scope="module", autouse=True) 6 | def shortener(): 7 | return URLShortener() 8 | 9 | 10 | def test_encode(shortener): 11 | assert shortener.encode(1) == "1" 12 | assert shortener.encode(62) == "10" 13 | unique_id = 123456789 14 | encoded_id = shortener.encode(unique_id) 15 | assert isinstance(encoded_id, str) 16 | assert encoded_id.lower() == '8M0kX'.lower() 17 | 18 | 19 | def test_add_url(shortener): 20 | long_url = "https://example.com" 21 | short_url = shortener.add_url(long_url) 22 | assert short_url.startswith("http://short.url/") 23 | assert shortener.add_url(long_url) == short_url 24 | 25 | 26 | def test_get_long_url(shortener): 27 | long_url = "https://example.com" 28 | short_url = shortener.add_url(long_url) 29 | assert shortener.get_long_url(short_url) == long_url 30 | assert shortener.get_long_url("http://short.url/nonexistent") is None 31 | 32 | 33 | def test_add_urls_concurrently(shortener): 34 | urls = ["https://example1.com", "https://example2.com"] 35 | short_urls = shortener.add_urls_concurrently(urls, max_workers=2) 36 | assert len(short_urls) == len(urls) 37 | for short_url in short_urls: 38 | assert short_url.startswith("http://short.url/") 39 | assert len(set(short_urls)) == len(urls) 40 | 41 | 42 | def test_encode_negative(shortener): 43 | with pytest.raises(ValueError): 44 | shortener.encode(-1) 45 | 46 | 47 | def test_encode_zero(shortener): 48 | with pytest.raises(ValueError): 49 | shortener.encode(0) 50 | 51 | 52 | def test_add_url_empty_string(shortener): 53 | with pytest.raises(ValueError): 54 | shortener.add_url("") 55 | 56 | 57 | def test_get_long_url_invalid_format(shortener): 58 | assert shortener.get_long_url("http://invalid.url") is None 59 | 60 | 61 | def test_add_urls_empty_list(shortener): 62 | assert shortener.add_urls([]) == [] 63 | assert shortener.add_urls_concurrently([], max_workers=2) == [] 64 | 65 | 66 | def test_add_urls_concurrently_large_number_of_urls(shortener): 67 | urls = ["https://example.com"] * 10000 68 | short_urls = shortener.add_urls_concurrently(urls, max_workers=10) 69 | assert len(short_urls) == len(urls) 70 | unique_short_urls = set(short_urls) 71 | assert len(unique_short_urls) == 1 72 | --------------------------------------------------------------------------------