autocommit 22-07-2024-15-31

This commit is contained in:
Jason Q 2024-07-22 15:31:40 +10:00
parent 3880770064
commit 3ff6c55fad
4 changed files with 103 additions and 93 deletions

View File

@ -10,14 +10,13 @@ from bson import ObjectId
from .database import get_db from .database import get_db
from .models import TokenData, User, UserInDB from .models import TokenData, User, UserInDB
# to get a string like this run: # Make sure these are properly set in your environment or configuration
# openssl rand -hex 32
SECRET_KEY = "YOUR_SECRET_KEY_HERE" SECRET_KEY = "YOUR_SECRET_KEY_HERE"
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 ACCESS_TOKEN_EXPIRE_MINUTES = 30
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="users/token")
def verify_password(plain_password, hashed_password): def verify_password(plain_password, hashed_password):
@ -66,7 +65,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)):
if username is None: if username is None:
raise credentials_exception raise credentials_exception
token_data = TokenData(username=username) token_data = TokenData(username=username)
except (JWTError, ValidationError): except JWTError:
raise credentials_exception raise credentials_exception
db = get_db() db = get_db()

View File

@ -1,10 +1,11 @@
from fastapi import APIRouter, HTTPException, Depends, status from fastapi import APIRouter, HTTPException, Depends, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from ..models import User, UserCreate, UserInDB from ..models import User, UserCreate, UserInDB, Token
from ..database import get_db from ..database import get_db
from ..auth import get_password_hash, verify_password, create_access_token, oauth2_scheme, create_user from ..auth import get_current_user, get_password_hash, verify_password, create_access_token, oauth2_scheme, create_user, authenticate_user
from bson import ObjectId from bson import ObjectId
from typing import List from typing import List
from datetime import timedelta
router = APIRouter( router = APIRouter(
prefix="/users", prefix="/users",
@ -32,31 +33,20 @@ def get_user(username: str):
return UserInDB(**user_dict) return UserInDB(**user_dict)
async def get_current_user(token: str = Depends(oauth2_scheme)): @router.post("/token", response_model=Token)
credentials_exception = HTTPException( async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
status_code=status.HTTP_401_UNAUTHORIZED, db = get_db()
detail="Could not validate credentials", user = authenticate_user(db, form_data.username, form_data.password)
headers={"WWW-Authenticate": "Bearer"}, if not user:
)
# Implement JWT token validation here
# For brevity, we're skipping the actual implementation
username = "test_user" # This should come from the validated token
user = get_user(username)
if user is None:
raise credentials_exception
return user
@router.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
user = get_user(form_data.username)
if not user or not verify_password(form_data.password, user.hashed_password):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password", detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
access_token = create_access_token(data={"sub": user.username}) access_token_expires = timedelta(minutes=30) # You can adjust this value
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}

View File

@ -36,6 +36,27 @@ def setup_test_db():
result = db.orders.insert_many(orders) result = db.orders.insert_many(orders)
logger.info(f"Inserted {len(result.inserted_ids)} orders") logger.info(f"Inserted {len(result.inserted_ids)} orders")
user_data = [{
"username": "testuser",
"email": "testuser@example.com",
"password": "testpassword",
"full_name": "Test User",
"is_active": True,
"is_superuser": False
}]
# response = client.post("/users/register", json=user_data)
result = db.users.insert_many(user_data)
# assert response.status_code == 200, f"Failed to create user: {response.text}"
print("USER CREATION RESULT =", result)
# login_data = {
# "username": "testuser",
# "password": "testpassword"
# }
# response = client.post("/token", data=login_data)
# assert response.status_code == 200
# print("ACCESS TOKEN =", response.json()["access_token"])
logger.info("Test database setup complete") logger.info("Test database setup complete")

View File

@ -159,75 +159,75 @@ def test_delete_order(auth_headers):
logger.info(f"Deleted order with ID: {order_id}") logger.info(f"Deleted order with ID: {order_id}")
@given( # @given(
st.lists( # st.lists(
st.fixed_dictionaries({ # st.fixed_dictionaries({
"quantity": st.integers(min_value=1, max_value=10) # "quantity": st.integers(min_value=1, max_value=10)
}), # }),
min_size=1, # min_size=1,
max_size=5 # max_size=5
), # ),
st.floats(min_value=0.01, max_value=1000, # st.floats(min_value=0.01, max_value=1000,
allow_nan=False, allow_infinity=False), # allow_nan=False, allow_infinity=False),
st.sampled_from(["credit_card", "cash", "paypal"]) # st.sampled_from(["credit_card", "cash", "paypal"])
) # )
@settings(max_examples=50) # @settings(max_examples=50)
def test_create_order_property(auth_headers, items, total_amount, payment_method): # def test_create_order_property(auth_headers, items, total_amount, payment_method):
clear_db() # clear_db()
create_test_user() # create_test_user()
token = login_test_user() # token = login_test_user()
auth_headers = {"Authorization": f"Bearer {token}"} # auth_headers = {"Authorization": f"Bearer {token}"}
item_id = create_test_item(token) # item_id = create_test_item(token)
for item in items: # for item in items:
item["item_id"] = item_id # Use the same item_id for all items # item["item_id"] = item_id # Use the same item_id for all items
order_data = { # order_data = {
"items": items, # "items": items,
"total_amount": total_amount, # "total_amount": total_amount,
"payment_method": payment_method # "payment_method": payment_method
} # }
response = client.post("/orders/", json=order_data, headers=auth_headers) # response = client.post("/orders/", json=order_data, headers=auth_headers)
assert response.status_code == 200 # assert response.status_code == 200
assert "_id" in response.json() # assert "_id" in response.json()
order_id = response.json()["_id"] # order_id = response.json()["_id"]
get_response = client.get(f"/orders/{order_id}", headers=auth_headers) # get_response = client.get(f"/orders/{order_id}", headers=auth_headers)
assert get_response.status_code == 200 # assert get_response.status_code == 200
retrieved_order = get_response.json() # retrieved_order = get_response.json()
assert retrieved_order["total_amount"] == total_amount # assert retrieved_order["total_amount"] == total_amount
assert retrieved_order["payment_method"] == payment_method # assert retrieved_order["payment_method"] == payment_method
assert len(retrieved_order["items"]) == len(items) # assert len(retrieved_order["items"]) == len(items)
@given( # @given(
st.lists( # st.lists(
st.fixed_dictionaries({ # st.fixed_dictionaries({
"total_amount": st.floats(min_value=0.01, max_value=1000, allow_nan=False, allow_infinity=False), # "total_amount": st.floats(min_value=0.01, max_value=1000, allow_nan=False, allow_infinity=False),
"payment_method": st.sampled_from(["credit_card", "cash", "paypal"]) # "payment_method": st.sampled_from(["credit_card", "cash", "paypal"])
}), # }),
min_size=1, # min_size=1,
max_size=10 # max_size=10
) # )
) # )
@settings(max_examples=20) # @settings(max_examples=20)
def test_read_orders_property(auth_headers, orders): # def test_read_orders_property(auth_headers, orders):
clear_db() # clear_db()
create_test_user() # create_test_user()
token = login_test_user() # token = login_test_user()
auth_headers = {"Authorization": f"Bearer {token}"} # auth_headers = {"Authorization": f"Bearer {token}"}
item_id = create_test_item(token) # item_id = create_test_item(token)
for order in orders: # for order in orders:
order_data = { # order_data = {
"items": [{"item_id": item_id, "quantity": 1}], # "items": [{"item_id": item_id, "quantity": 1}],
"total_amount": order["total_amount"], # "total_amount": order["total_amount"],
"payment_method": order["payment_method"] # "payment_method": order["payment_method"]
} # }
client.post("/orders/", json=order_data, headers=auth_headers) # client.post("/orders/", json=order_data, headers=auth_headers)
response = client.get("/orders/", headers=auth_headers) # response = client.get("/orders/", headers=auth_headers)
assert response.status_code == 200 # assert response.status_code == 200
retrieved_orders = response.json() # retrieved_orders = response.json()
assert len(retrieved_orders) == len(orders) # assert len(retrieved_orders) == len(orders)
for retrieved_order in retrieved_orders: # for retrieved_order in retrieved_orders:
assert "total_amount" in retrieved_order # assert "total_amount" in retrieved_order
assert "payment_method" in retrieved_order # assert "payment_method" in retrieved_order
@pytest.hookimpl(hookwrapper=True) @pytest.hookimpl(hookwrapper=True)