autocommit 13-07-2024-10-35
This commit is contained in:
parent
7eb6f8f5e8
commit
e21e01e009
|
|
@ -1,11 +1,18 @@
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from passlib.context import CryptContext
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from .database import get_db
|
from passlib.context import CryptContext
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
# JWT Configuration
|
from .database import get_db
|
||||||
SECRET_KEY = "YOUR_SECRET_KEY"
|
from .models import TokenData, User, UserInDB
|
||||||
|
|
||||||
|
# to get a string like this run:
|
||||||
|
# openssl rand -hex 32
|
||||||
|
SECRET_KEY = "YOUR_SECRET_KEY_HERE"
|
||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||||
|
|
||||||
|
|
@ -21,11 +28,63 @@ def get_password_hash(password):
|
||||||
return pwd_context.hash(password)
|
return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(data: dict):
|
def get_user(db, username: str) -> Optional[UserInDB]:
|
||||||
|
user_dict = db.users.find_one({"username": username})
|
||||||
|
if user_dict:
|
||||||
|
return UserInDB(**user_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def authenticate_user(db, username: str, password: str) -> Optional[UserInDB]:
|
||||||
|
user = get_user(db, username)
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
if not verify_password(password, user.hashed_password):
|
||||||
|
return None
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
if expires_delta:
|
||||||
|
expire = datetime.utcnow() + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||||
to_encode.update({"exp": expire})
|
to_encode.update({"exp": expire})
|
||||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
# Add other authentication-related functions here
|
|
||||||
|
async def get_current_user(token: str = Depends(oauth2_scheme)):
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise credentials_exception
|
||||||
|
token_data = TokenData(username=username)
|
||||||
|
except (JWTError, ValidationError):
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
db = get_db()
|
||||||
|
user = get_user(db, username=token_data.username)
|
||||||
|
if user is None:
|
||||||
|
raise credentials_exception
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_active_user(current_user: User = Depends(get_current_user)):
|
||||||
|
if not current_user.is_active:
|
||||||
|
raise HTTPException(status_code=400, detail="Inactive user")
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def create_user(db, user: UserInDB):
|
||||||
|
user_dict = user.model_dump(exclude={"id"})
|
||||||
|
user_dict["_id"] = ObjectId()
|
||||||
|
result = db.users.insert_one(user_dict)
|
||||||
|
new_user = db.users.find_one({"_id": result.inserted_id})
|
||||||
|
return UserInDB(**new_user)
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,18 @@
|
||||||
from pydantic import BaseModel, Field, EmailStr, validator
|
from pydantic import BaseModel, Field, EmailStr, field_validator, validator, ConfigDict
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from pydantic import ConfigDict
|
|
||||||
from typing import Any, ClassVar
|
|
||||||
|
|
||||||
|
|
||||||
class PyObjectId(ObjectId):
|
def validate_object_id(value: str) -> ObjectId:
|
||||||
@classmethod
|
if not ObjectId.is_valid(value):
|
||||||
def __get_pydantic_core_schema__(
|
raise ValueError("Invalid ObjectId")
|
||||||
cls, _source_type: Any, _handler: Any
|
return ObjectId(value)
|
||||||
) -> Any:
|
|
||||||
from pydantic_core import core_schema
|
|
||||||
return core_schema.json_or_python_schema(
|
|
||||||
python_schema=core_schema.is_instance_schema(ObjectId),
|
|
||||||
json_schema=core_schema.StringSchema(),
|
|
||||||
serialization=core_schema.to_string_ser_schema(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MongoBaseModel(BaseModel):
|
class MongoBaseModel(BaseModel):
|
||||||
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
id: Optional[str] = Field(
|
||||||
|
default_factory=lambda: str(ObjectId()), alias="_id")
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
populate_by_name=True,
|
populate_by_name=True,
|
||||||
|
|
@ -28,6 +20,12 @@ class MongoBaseModel(BaseModel):
|
||||||
json_encoders={ObjectId: str}
|
json_encoders={ObjectId: str}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# @field_validator("id", pre=True)
|
||||||
|
# def validate_id(cls, v):
|
||||||
|
# if isinstance(v, ObjectId):
|
||||||
|
# return str(v)
|
||||||
|
# return v
|
||||||
|
|
||||||
|
|
||||||
class Item(MongoBaseModel):
|
class Item(MongoBaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
@ -38,13 +36,17 @@ class Item(MongoBaseModel):
|
||||||
|
|
||||||
|
|
||||||
class OrderItem(BaseModel):
|
class OrderItem(BaseModel):
|
||||||
item_id: PyObjectId
|
item_id: str
|
||||||
quantity: int
|
quantity: int
|
||||||
price_at_order: float
|
price_at_order: float
|
||||||
|
|
||||||
|
# @field_validator("item_id")
|
||||||
|
# def validate_item_id(cls, v):
|
||||||
|
# return validate_object_id(v)
|
||||||
|
|
||||||
|
|
||||||
class Order(MongoBaseModel):
|
class Order(MongoBaseModel):
|
||||||
user_id: PyObjectId
|
user_id: str
|
||||||
items: List[OrderItem]
|
items: List[OrderItem]
|
||||||
total_amount: float
|
total_amount: float
|
||||||
payment_method: Optional[str] = None
|
payment_method: Optional[str] = None
|
||||||
|
|
@ -55,22 +57,26 @@ class Order(MongoBaseModel):
|
||||||
discount_applied: Optional[float] = None
|
discount_applied: Optional[float] = None
|
||||||
notes: Optional[str] = None
|
notes: Optional[str] = None
|
||||||
|
|
||||||
@validator('order_status')
|
# @validator("user_id")
|
||||||
def valid_order_status(cls, v):
|
# def validate_user_id(cls, v):
|
||||||
allowed_statuses = ["created", "processing",
|
# return validate_object_id(v)
|
||||||
"shipped", "delivered", "cancelled"]
|
|
||||||
if v not in allowed_statuses:
|
|
||||||
raise ValueError(f"Invalid order status. Must be one of: {
|
|
||||||
', '.join(allowed_statuses)}")
|
|
||||||
return v
|
|
||||||
|
|
||||||
@validator('payment_status')
|
# @validator("order_status")
|
||||||
def valid_payment_status(cls, v):
|
# def valid_order_status(cls, v):
|
||||||
allowed_statuses = ["pending", "paid", "refunded", "failed"]
|
# allowed_statuses = ["created", "processing",
|
||||||
if v not in allowed_statuses:
|
# "shipped", "delivered", "cancelled"]
|
||||||
raise ValueError(f"Invalid payment status. Must be one of: {
|
# if v not in allowed_statuses:
|
||||||
', '.join(allowed_statuses)}")
|
# raise ValueError(f"Invalid order status. Must be one of: {
|
||||||
return v
|
# ', '.join(allowed_statuses)}")
|
||||||
|
# return v
|
||||||
|
|
||||||
|
# @validator("payment_status")
|
||||||
|
# def valid_payment_status(cls, v):
|
||||||
|
# allowed_statuses = ["pending", "paid", "refunded", "failed"]
|
||||||
|
# if v not in allowed_statuses:
|
||||||
|
# raise ValueError(f"Invalid payment status. Must be one of: {
|
||||||
|
# ', '.join(allowed_statuses)}")
|
||||||
|
# return v
|
||||||
|
|
||||||
|
|
||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,102 @@
|
||||||
|
from pydantic import BaseModel, Field, EmailStr, field_validator
|
||||||
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from bson import ObjectId
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
|
||||||
|
class PyObjectId(ObjectId):
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(
|
||||||
|
cls, _source_type: Any, _handler: Any
|
||||||
|
) -> Any:
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
return core_schema.json_or_python_schema(
|
||||||
|
python_schema=core_schema.is_instance_schema(ObjectId),
|
||||||
|
json_schema=core_schema.StringSchema(),
|
||||||
|
serialization=core_schema.to_string_ser_schema(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MongoBaseModel(BaseModel):
|
||||||
|
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
populate_by_name=True,
|
||||||
|
arbitrary_types_allowed=True,
|
||||||
|
json_encoders={ObjectId: str}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Item(MongoBaseModel):
|
||||||
|
name: str
|
||||||
|
price: float
|
||||||
|
quantity: int
|
||||||
|
unit: str
|
||||||
|
related_items: List[str] = []
|
||||||
|
|
||||||
|
|
||||||
|
class OrderItem(BaseModel):
|
||||||
|
item_id: PyObjectId
|
||||||
|
quantity: int
|
||||||
|
price_at_order: float
|
||||||
|
|
||||||
|
|
||||||
|
class Order(MongoBaseModel):
|
||||||
|
user_id: PyObjectId
|
||||||
|
items: List[OrderItem]
|
||||||
|
total_amount: float
|
||||||
|
payment_method: Optional[str] = None
|
||||||
|
payment_status: str = "pending"
|
||||||
|
order_status: str = "created"
|
||||||
|
created_at: datetime = Field(default_factory=datetime.now(datetime.UTC))
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
discount_applied: Optional[float] = None
|
||||||
|
notes: Optional[str] = None
|
||||||
|
|
||||||
|
@field_validator('order_status')
|
||||||
|
def valid_order_status(cls, v):
|
||||||
|
allowed_statuses = ["created", "processing",
|
||||||
|
"shipped", "delivered", "cancelled"]
|
||||||
|
if v not in allowed_statuses:
|
||||||
|
raise ValueError(f"Invalid order status. Must be one of: {
|
||||||
|
', '.join(allowed_statuses)}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator('payment_status')
|
||||||
|
def valid_payment_status(cls, v):
|
||||||
|
allowed_statuses = ["pending", "paid", "refunded", "failed"]
|
||||||
|
if v not in allowed_statuses:
|
||||||
|
raise ValueError(f"Invalid payment status. Must be one of: {
|
||||||
|
', '.join(allowed_statuses)}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class UserBase(BaseModel):
|
||||||
|
username: str
|
||||||
|
email: EmailStr
|
||||||
|
full_name: str
|
||||||
|
is_active: bool = True
|
||||||
|
is_superuser: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class UserCreate(UserBase):
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserInDB(UserBase, MongoBaseModel):
|
||||||
|
hashed_password: str
|
||||||
|
|
||||||
|
|
||||||
|
class User(UserBase, MongoBaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Token(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class TokenData(BaseModel):
|
||||||
|
username: Optional[str] = None
|
||||||
Loading…
Reference in New Issue