diff --git a/py-kivy/pos_system/auth.py b/py-kivy/pos_system/auth.py index cb221ae..6599320 100644 --- a/py-kivy/pos_system/auth.py +++ b/py-kivy/pos_system/auth.py @@ -1,11 +1,18 @@ -from fastapi.security import OAuth2PasswordBearer -from passlib.context import CryptContext from datetime import datetime, timedelta +from typing import Optional +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer from jose import JWTError, jwt -from .database import get_db +from passlib.context import CryptContext +from pydantic import ValidationError +from bson import ObjectId -# JWT Configuration -SECRET_KEY = "YOUR_SECRET_KEY" +from .database import get_db +from .models import TokenData, User, UserInDB + +# to get a string like this run: +# openssl rand -hex 32 +SECRET_KEY = "YOUR_SECRET_KEY_HERE" ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 @@ -21,11 +28,63 @@ def get_password_hash(password): return pwd_context.hash(password) -def create_access_token(data: dict): +def get_user(db, username: str) -> Optional[UserInDB]: + user_dict = db.users.find_one({"username": username}) + if user_dict: + return UserInDB(**user_dict) + + +def authenticate_user(db, username: str, password: str) -> Optional[UserInDB]: + user = get_user(db, username) + if not user: + return None + if not verify_password(password, user.hashed_password): + return None + return user + + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() - expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt -# Add other authentication-related functions here + +async def get_current_user(token: str = Depends(oauth2_scheme)): + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_data = TokenData(username=username) + except (JWTError, ValidationError): + raise credentials_exception + + db = get_db() + user = get_user(db, username=token_data.username) + if user is None: + raise credentials_exception + return user + + +async def get_current_active_user(current_user: User = Depends(get_current_user)): + if not current_user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + +def create_user(db, user: UserInDB): + user_dict = user.model_dump(exclude={"id"}) + user_dict["_id"] = ObjectId() + result = db.users.insert_one(user_dict) + new_user = db.users.find_one({"_id": result.inserted_id}) + return UserInDB(**new_user) diff --git a/py-kivy/pos_system/models.py b/py-kivy/pos_system/models.py index d1fc49a..531dd6a 100644 --- a/py-kivy/pos_system/models.py +++ b/py-kivy/pos_system/models.py @@ -1,26 +1,18 @@ -from pydantic import BaseModel, Field, EmailStr, validator +from pydantic import BaseModel, Field, EmailStr, field_validator, validator, ConfigDict from typing import List, Optional from datetime import datetime from bson import ObjectId -from pydantic import ConfigDict -from typing import Any, ClassVar -class PyObjectId(ObjectId): - @classmethod - def __get_pydantic_core_schema__( - cls, _source_type: Any, _handler: Any - ) -> Any: - from pydantic_core import core_schema - return core_schema.json_or_python_schema( - python_schema=core_schema.is_instance_schema(ObjectId), - json_schema=core_schema.StringSchema(), - serialization=core_schema.to_string_ser_schema(), - ) +def validate_object_id(value: str) -> ObjectId: + if not ObjectId.is_valid(value): + raise ValueError("Invalid ObjectId") + return ObjectId(value) class MongoBaseModel(BaseModel): - id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + id: Optional[str] = Field( + default_factory=lambda: str(ObjectId()), alias="_id") model_config = ConfigDict( populate_by_name=True, @@ -28,6 +20,12 @@ class MongoBaseModel(BaseModel): json_encoders={ObjectId: str} ) + # @field_validator("id", pre=True) + # def validate_id(cls, v): + # if isinstance(v, ObjectId): + # return str(v) + # return v + class Item(MongoBaseModel): name: str @@ -38,13 +36,17 @@ class Item(MongoBaseModel): class OrderItem(BaseModel): - item_id: PyObjectId + item_id: str quantity: int price_at_order: float + # @field_validator("item_id") + # def validate_item_id(cls, v): + # return validate_object_id(v) + class Order(MongoBaseModel): - user_id: PyObjectId + user_id: str items: List[OrderItem] total_amount: float payment_method: Optional[str] = None @@ -55,22 +57,26 @@ class Order(MongoBaseModel): discount_applied: Optional[float] = None notes: Optional[str] = None - @validator('order_status') - def valid_order_status(cls, v): - allowed_statuses = ["created", "processing", - "shipped", "delivered", "cancelled"] - if v not in allowed_statuses: - raise ValueError(f"Invalid order status. Must be one of: { - ', '.join(allowed_statuses)}") - return v + # @validator("user_id") + # def validate_user_id(cls, v): + # return validate_object_id(v) - @validator('payment_status') - def valid_payment_status(cls, v): - allowed_statuses = ["pending", "paid", "refunded", "failed"] - if v not in allowed_statuses: - raise ValueError(f"Invalid payment status. Must be one of: { - ', '.join(allowed_statuses)}") - return v + # @validator("order_status") + # def valid_order_status(cls, v): + # allowed_statuses = ["created", "processing", + # "shipped", "delivered", "cancelled"] + # if v not in allowed_statuses: + # raise ValueError(f"Invalid order status. Must be one of: { + # ', '.join(allowed_statuses)}") + # return v + + # @validator("payment_status") + # def valid_payment_status(cls, v): + # allowed_statuses = ["pending", "paid", "refunded", "failed"] + # if v not in allowed_statuses: + # raise ValueError(f"Invalid payment status. Must be one of: { + # ', '.join(allowed_statuses)}") + # return v class UserBase(BaseModel): diff --git a/py-kivy/pos_system/models2.py b/py-kivy/pos_system/models2.py new file mode 100644 index 0000000..4af09ae --- /dev/null +++ b/py-kivy/pos_system/models2.py @@ -0,0 +1,102 @@ +from pydantic import BaseModel, Field, EmailStr, field_validator +from typing import List, Optional +from datetime import datetime +from bson import ObjectId +from pydantic import ConfigDict +from typing import Any, ClassVar + + +class PyObjectId(ObjectId): + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: Any + ) -> Any: + from pydantic_core import core_schema + return core_schema.json_or_python_schema( + python_schema=core_schema.is_instance_schema(ObjectId), + json_schema=core_schema.StringSchema(), + serialization=core_schema.to_string_ser_schema(), + ) + + +class MongoBaseModel(BaseModel): + id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + json_encoders={ObjectId: str} + ) + + +class Item(MongoBaseModel): + name: str + price: float + quantity: int + unit: str + related_items: List[str] = [] + + +class OrderItem(BaseModel): + item_id: PyObjectId + quantity: int + price_at_order: float + + +class Order(MongoBaseModel): + user_id: PyObjectId + items: List[OrderItem] + total_amount: float + payment_method: Optional[str] = None + payment_status: str = "pending" + order_status: str = "created" + created_at: datetime = Field(default_factory=datetime.now(datetime.UTC)) + updated_at: Optional[datetime] = None + discount_applied: Optional[float] = None + notes: Optional[str] = None + + @field_validator('order_status') + def valid_order_status(cls, v): + allowed_statuses = ["created", "processing", + "shipped", "delivered", "cancelled"] + if v not in allowed_statuses: + raise ValueError(f"Invalid order status. Must be one of: { + ', '.join(allowed_statuses)}") + return v + + @field_validator('payment_status') + def valid_payment_status(cls, v): + allowed_statuses = ["pending", "paid", "refunded", "failed"] + if v not in allowed_statuses: + raise ValueError(f"Invalid payment status. Must be one of: { + ', '.join(allowed_statuses)}") + return v + + +class UserBase(BaseModel): + username: str + email: EmailStr + full_name: str + is_active: bool = True + is_superuser: bool = False + + +class UserCreate(UserBase): + password: str + + +class UserInDB(UserBase, MongoBaseModel): + hashed_password: str + + +class User(UserBase, MongoBaseModel): + pass + + +class Token(BaseModel): + access_token: str + token_type: str + + +class TokenData(BaseModel): + username: Optional[str] = None