autocommit 13-07-2024-10-35

This commit is contained in:
Jasen Qin 2024-07-13 10:35:56 +10:00
parent 7eb6f8f5e8
commit e21e01e009
3 changed files with 207 additions and 40 deletions

View File

@ -1,11 +1,18 @@
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from datetime import datetime, timedelta
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from .database import get_db
from passlib.context import CryptContext
from pydantic import ValidationError
from bson import ObjectId
# JWT Configuration
SECRET_KEY = "YOUR_SECRET_KEY"
from .database import get_db
from .models import TokenData, User, UserInDB
# to get a string like this run:
# openssl rand -hex 32
SECRET_KEY = "YOUR_SECRET_KEY_HERE"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
@ -21,11 +28,63 @@ def get_password_hash(password):
return pwd_context.hash(password)
def create_access_token(data: dict):
def get_user(db, username: str) -> Optional[UserInDB]:
user_dict = db.users.find_one({"username": username})
if user_dict:
return UserInDB(**user_dict)
def authenticate_user(db, username: str, password: str) -> Optional[UserInDB]:
user = get_user(db, username)
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# Add other authentication-related functions here
async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except (JWTError, ValidationError):
raise credentials_exception
db = get_db()
user = get_user(db, username=token_data.username)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(current_user: User = Depends(get_current_user)):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
def create_user(db, user: UserInDB):
user_dict = user.model_dump(exclude={"id"})
user_dict["_id"] = ObjectId()
result = db.users.insert_one(user_dict)
new_user = db.users.find_one({"_id": result.inserted_id})
return UserInDB(**new_user)

View File

@ -1,26 +1,18 @@
from pydantic import BaseModel, Field, EmailStr, validator
from pydantic import BaseModel, Field, EmailStr, field_validator, validator, ConfigDict
from typing import List, Optional
from datetime import datetime
from bson import ObjectId
from pydantic import ConfigDict
from typing import Any, ClassVar
class PyObjectId(ObjectId):
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
) -> Any:
from pydantic_core import core_schema
return core_schema.json_or_python_schema(
python_schema=core_schema.is_instance_schema(ObjectId),
json_schema=core_schema.StringSchema(),
serialization=core_schema.to_string_ser_schema(),
)
def validate_object_id(value: str) -> ObjectId:
if not ObjectId.is_valid(value):
raise ValueError("Invalid ObjectId")
return ObjectId(value)
class MongoBaseModel(BaseModel):
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
id: Optional[str] = Field(
default_factory=lambda: str(ObjectId()), alias="_id")
model_config = ConfigDict(
populate_by_name=True,
@ -28,6 +20,12 @@ class MongoBaseModel(BaseModel):
json_encoders={ObjectId: str}
)
# @field_validator("id", pre=True)
# def validate_id(cls, v):
# if isinstance(v, ObjectId):
# return str(v)
# return v
class Item(MongoBaseModel):
name: str
@ -38,13 +36,17 @@ class Item(MongoBaseModel):
class OrderItem(BaseModel):
item_id: PyObjectId
item_id: str
quantity: int
price_at_order: float
# @field_validator("item_id")
# def validate_item_id(cls, v):
# return validate_object_id(v)
class Order(MongoBaseModel):
user_id: PyObjectId
user_id: str
items: List[OrderItem]
total_amount: float
payment_method: Optional[str] = None
@ -55,22 +57,26 @@ class Order(MongoBaseModel):
discount_applied: Optional[float] = None
notes: Optional[str] = None
@validator('order_status')
def valid_order_status(cls, v):
allowed_statuses = ["created", "processing",
"shipped", "delivered", "cancelled"]
if v not in allowed_statuses:
raise ValueError(f"Invalid order status. Must be one of: {
', '.join(allowed_statuses)}")
return v
# @validator("user_id")
# def validate_user_id(cls, v):
# return validate_object_id(v)
@validator('payment_status')
def valid_payment_status(cls, v):
allowed_statuses = ["pending", "paid", "refunded", "failed"]
if v not in allowed_statuses:
raise ValueError(f"Invalid payment status. Must be one of: {
', '.join(allowed_statuses)}")
return v
# @validator("order_status")
# def valid_order_status(cls, v):
# allowed_statuses = ["created", "processing",
# "shipped", "delivered", "cancelled"]
# if v not in allowed_statuses:
# raise ValueError(f"Invalid order status. Must be one of: {
# ', '.join(allowed_statuses)}")
# return v
# @validator("payment_status")
# def valid_payment_status(cls, v):
# allowed_statuses = ["pending", "paid", "refunded", "failed"]
# if v not in allowed_statuses:
# raise ValueError(f"Invalid payment status. Must be one of: {
# ', '.join(allowed_statuses)}")
# return v
class UserBase(BaseModel):

View File

@ -0,0 +1,102 @@
from pydantic import BaseModel, Field, EmailStr, field_validator
from typing import List, Optional
from datetime import datetime
from bson import ObjectId
from pydantic import ConfigDict
from typing import Any, ClassVar
class PyObjectId(ObjectId):
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
) -> Any:
from pydantic_core import core_schema
return core_schema.json_or_python_schema(
python_schema=core_schema.is_instance_schema(ObjectId),
json_schema=core_schema.StringSchema(),
serialization=core_schema.to_string_ser_schema(),
)
class MongoBaseModel(BaseModel):
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
json_encoders={ObjectId: str}
)
class Item(MongoBaseModel):
name: str
price: float
quantity: int
unit: str
related_items: List[str] = []
class OrderItem(BaseModel):
item_id: PyObjectId
quantity: int
price_at_order: float
class Order(MongoBaseModel):
user_id: PyObjectId
items: List[OrderItem]
total_amount: float
payment_method: Optional[str] = None
payment_status: str = "pending"
order_status: str = "created"
created_at: datetime = Field(default_factory=datetime.now(datetime.UTC))
updated_at: Optional[datetime] = None
discount_applied: Optional[float] = None
notes: Optional[str] = None
@field_validator('order_status')
def valid_order_status(cls, v):
allowed_statuses = ["created", "processing",
"shipped", "delivered", "cancelled"]
if v not in allowed_statuses:
raise ValueError(f"Invalid order status. Must be one of: {
', '.join(allowed_statuses)}")
return v
@field_validator('payment_status')
def valid_payment_status(cls, v):
allowed_statuses = ["pending", "paid", "refunded", "failed"]
if v not in allowed_statuses:
raise ValueError(f"Invalid payment status. Must be one of: {
', '.join(allowed_statuses)}")
return v
class UserBase(BaseModel):
username: str
email: EmailStr
full_name: str
is_active: bool = True
is_superuser: bool = False
class UserCreate(UserBase):
password: str
class UserInDB(UserBase, MongoBaseModel):
hashed_password: str
class User(UserBase, MongoBaseModel):
pass
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: Optional[str] = None