autocommit 13-07-2024-10-35

This commit is contained in:
Jasen Qin 2024-07-13 10:35:56 +10:00
parent 7eb6f8f5e8
commit e21e01e009
3 changed files with 207 additions and 40 deletions

View File

@ -1,11 +1,18 @@
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt from jose import JWTError, jwt
from .database import get_db from passlib.context import CryptContext
from pydantic import ValidationError
from bson import ObjectId
# JWT Configuration from .database import get_db
SECRET_KEY = "YOUR_SECRET_KEY" from .models import TokenData, User, UserInDB
# to get a string like this run:
# openssl rand -hex 32
SECRET_KEY = "YOUR_SECRET_KEY_HERE"
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 30
@ -21,11 +28,63 @@ def get_password_hash(password):
return pwd_context.hash(password) return pwd_context.hash(password)
def create_access_token(data: dict): def get_user(db, username: str) -> Optional[UserInDB]:
user_dict = db.users.find_one({"username": username})
if user_dict:
return UserInDB(**user_dict)
def authenticate_user(db, username: str, password: str) -> Optional[UserInDB]:
user = get_user(db, username)
if not user:
return None
if not verify_password(password, user.hashed_password):
return None
return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy() to_encode = data.copy()
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
# Add other authentication-related functions here
async def get_current_user(token: str = Depends(oauth2_scheme)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except (JWTError, ValidationError):
raise credentials_exception
db = get_db()
user = get_user(db, username=token_data.username)
if user is None:
raise credentials_exception
return user
async def get_current_active_user(current_user: User = Depends(get_current_user)):
if not current_user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
def create_user(db, user: UserInDB):
user_dict = user.model_dump(exclude={"id"})
user_dict["_id"] = ObjectId()
result = db.users.insert_one(user_dict)
new_user = db.users.find_one({"_id": result.inserted_id})
return UserInDB(**new_user)

View File

@ -1,26 +1,18 @@
from pydantic import BaseModel, Field, EmailStr, validator from pydantic import BaseModel, Field, EmailStr, field_validator, validator, ConfigDict
from typing import List, Optional from typing import List, Optional
from datetime import datetime from datetime import datetime
from bson import ObjectId from bson import ObjectId
from pydantic import ConfigDict
from typing import Any, ClassVar
class PyObjectId(ObjectId): def validate_object_id(value: str) -> ObjectId:
@classmethod if not ObjectId.is_valid(value):
def __get_pydantic_core_schema__( raise ValueError("Invalid ObjectId")
cls, _source_type: Any, _handler: Any return ObjectId(value)
) -> Any:
from pydantic_core import core_schema
return core_schema.json_or_python_schema(
python_schema=core_schema.is_instance_schema(ObjectId),
json_schema=core_schema.StringSchema(),
serialization=core_schema.to_string_ser_schema(),
)
class MongoBaseModel(BaseModel): class MongoBaseModel(BaseModel):
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") id: Optional[str] = Field(
default_factory=lambda: str(ObjectId()), alias="_id")
model_config = ConfigDict( model_config = ConfigDict(
populate_by_name=True, populate_by_name=True,
@ -28,6 +20,12 @@ class MongoBaseModel(BaseModel):
json_encoders={ObjectId: str} json_encoders={ObjectId: str}
) )
# @field_validator("id", pre=True)
# def validate_id(cls, v):
# if isinstance(v, ObjectId):
# return str(v)
# return v
class Item(MongoBaseModel): class Item(MongoBaseModel):
name: str name: str
@ -38,13 +36,17 @@ class Item(MongoBaseModel):
class OrderItem(BaseModel): class OrderItem(BaseModel):
item_id: PyObjectId item_id: str
quantity: int quantity: int
price_at_order: float price_at_order: float
# @field_validator("item_id")
# def validate_item_id(cls, v):
# return validate_object_id(v)
class Order(MongoBaseModel): class Order(MongoBaseModel):
user_id: PyObjectId user_id: str
items: List[OrderItem] items: List[OrderItem]
total_amount: float total_amount: float
payment_method: Optional[str] = None payment_method: Optional[str] = None
@ -55,22 +57,26 @@ class Order(MongoBaseModel):
discount_applied: Optional[float] = None discount_applied: Optional[float] = None
notes: Optional[str] = None notes: Optional[str] = None
@validator('order_status') # @validator("user_id")
def valid_order_status(cls, v): # def validate_user_id(cls, v):
allowed_statuses = ["created", "processing", # return validate_object_id(v)
"shipped", "delivered", "cancelled"]
if v not in allowed_statuses:
raise ValueError(f"Invalid order status. Must be one of: {
', '.join(allowed_statuses)}")
return v
@validator('payment_status') # @validator("order_status")
def valid_payment_status(cls, v): # def valid_order_status(cls, v):
allowed_statuses = ["pending", "paid", "refunded", "failed"] # allowed_statuses = ["created", "processing",
if v not in allowed_statuses: # "shipped", "delivered", "cancelled"]
raise ValueError(f"Invalid payment status. Must be one of: { # if v not in allowed_statuses:
', '.join(allowed_statuses)}") # raise ValueError(f"Invalid order status. Must be one of: {
return v # ', '.join(allowed_statuses)}")
# return v
# @validator("payment_status")
# def valid_payment_status(cls, v):
# allowed_statuses = ["pending", "paid", "refunded", "failed"]
# if v not in allowed_statuses:
# raise ValueError(f"Invalid payment status. Must be one of: {
# ', '.join(allowed_statuses)}")
# return v
class UserBase(BaseModel): class UserBase(BaseModel):

View File

@ -0,0 +1,102 @@
from pydantic import BaseModel, Field, EmailStr, field_validator
from typing import List, Optional
from datetime import datetime
from bson import ObjectId
from pydantic import ConfigDict
from typing import Any, ClassVar
class PyObjectId(ObjectId):
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: Any
) -> Any:
from pydantic_core import core_schema
return core_schema.json_or_python_schema(
python_schema=core_schema.is_instance_schema(ObjectId),
json_schema=core_schema.StringSchema(),
serialization=core_schema.to_string_ser_schema(),
)
class MongoBaseModel(BaseModel):
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
json_encoders={ObjectId: str}
)
class Item(MongoBaseModel):
name: str
price: float
quantity: int
unit: str
related_items: List[str] = []
class OrderItem(BaseModel):
item_id: PyObjectId
quantity: int
price_at_order: float
class Order(MongoBaseModel):
user_id: PyObjectId
items: List[OrderItem]
total_amount: float
payment_method: Optional[str] = None
payment_status: str = "pending"
order_status: str = "created"
created_at: datetime = Field(default_factory=datetime.now(datetime.UTC))
updated_at: Optional[datetime] = None
discount_applied: Optional[float] = None
notes: Optional[str] = None
@field_validator('order_status')
def valid_order_status(cls, v):
allowed_statuses = ["created", "processing",
"shipped", "delivered", "cancelled"]
if v not in allowed_statuses:
raise ValueError(f"Invalid order status. Must be one of: {
', '.join(allowed_statuses)}")
return v
@field_validator('payment_status')
def valid_payment_status(cls, v):
allowed_statuses = ["pending", "paid", "refunded", "failed"]
if v not in allowed_statuses:
raise ValueError(f"Invalid payment status. Must be one of: {
', '.join(allowed_statuses)}")
return v
class UserBase(BaseModel):
username: str
email: EmailStr
full_name: str
is_active: bool = True
is_superuser: bool = False
class UserCreate(UserBase):
password: str
class UserInDB(UserBase, MongoBaseModel):
hashed_password: str
class User(UserBase, MongoBaseModel):
pass
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: Optional[str] = None