Skip to content

Commit

Permalink
Merge pull request #394 from open-contracting/mypy-auth
Browse files Browse the repository at this point in the history
Fix Mypy errors and small fixes
  • Loading branch information
yolile authored Aug 24, 2024
2 parents bad7f28 + 737d907 commit 79b69c1
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ jobs:
mypy --strict --show-error-codes app reportlab_mods.py > error-summary.txt
status=$?
cat error-summary.txt
[ $status -eq 0 ] || [ $(tail -n1 error-summary.txt | cut -d' ' -f2) -le 50 ]
[ $status -eq 0 ] || [ $(tail -n1 error-summary.txt | cut -d' ' -f2) -le 20 ]
61 changes: 26 additions & 35 deletions app/auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from typing import Any

import jwt
Expand Down Expand Up @@ -25,23 +26,31 @@ class JWTAuthorizationCredentials(BaseModel):
message: str


@functools.lru_cache
def get_keys() -> dict[str, JWK]:
return {
jwk["kid"]: jwk
for jwk in JWKS.model_validate(
requests.get(
f"https://cognito-idp.{app_settings.aws_region}.amazonaws.com/"
f"{app_settings.cognito_pool_id}/.well-known/jwks.json"
).json()
).keys
}


# https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
class JWTAuthorization(HTTPBearer):
"""
An extension of HTTPBearer authentication to verify JWT (JSON Web Tokens) with public keys.
This class loads and keeps track of public keys from an external source and verifies incoming tokens.
:param auto_error: If set to True, automatic error responses will be sent when request authentication fails.
Default is True.
"""

def __init__(self, auto_error: bool = True):
super().__init__(auto_error=auto_error)
self.kid_to_jwk: dict[str, JWK] | None = None

def load_keys(self) -> None:
if self.kid_to_jwk is None:
jwks = _get_public_keys()
self.kid_to_jwk = {jwk["kid"]: jwk for jwk in jwks.keys}
self.kid_to_jwk = get_keys()

def verify_jwk_token(self, jwt_credentials: JWTAuthorizationCredentials) -> bool:
"""
Expand All @@ -50,9 +59,16 @@ def verify_jwk_token(self, jwt_credentials: JWTAuthorizationCredentials) -> bool
:param jwt_credentials: JWT credentials extracted from the request.
:return: Returns True if the token is verified, False otherwise.
"""
self.load_keys()
try:
public_key = self.kid_to_jwk[jwt_credentials.header["kid"]]
kid = jwt_credentials.header["kid"]

# "If you receive a token with the correct issuer but a different kid, Amazon Cognito might have rotated
# the signing key. Refresh the cache from your user pool jwks_uri endpoint."
if kid not in self.kid_to_jwk:
get_keys.cache_clear()
self.kid_to_jwk = get_keys()

public_key = self.kid_to_jwk[kid]
except KeyError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
Expand All @@ -64,6 +80,7 @@ def verify_jwk_token(self, jwt_credentials: JWTAuthorizationCredentials) -> bool

obj = jwt.PyJWK(public_key)
alg_obj = obj.Algorithm
assert alg_obj
prepared_key = alg_obj.prepare_key(obj.key)

return alg_obj.verify(msg, prepared_key, sig)
Expand All @@ -75,8 +92,6 @@ async def __call__(self, request: Request) -> JWTAuthorizationCredentials:
:param request: Incoming request instance.
:return: JWT credentials if the token is verified.
"""
self.load_keys()

if credentials := await super().__call__(request):
if not credentials.scheme == "Bearer":
raise HTTPException(
Expand Down Expand Up @@ -120,27 +135,3 @@ async def __call__(self, request: Request) -> JWTAuthorizationCredentials:
status_code=status.HTTP_403_FORBIDDEN,
detail=_("Not authenticated"),
)


public_keys = None


def _get_public_keys() -> JWKS:
"""
Retrieves the public keys from the well-known JWKS (JSON Web Key Set) endpoint of Cognito.
The function caches the fetched keys in a global variable `public_keys` to avoid repetitive calls
to the endpoint.
:return: The parsed JWKS, an object which holds a list of keys.
"""
global public_keys
if public_keys is None:
public_keys = JWKS.model_validate(
# https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
requests.get(
f"https://cognito-idp.{app_settings.aws_region}.amazonaws.com/"
f"{app_settings.cognito_pool_id}/.well-known/jwks.json"
).json()
)
return public_keys
9 changes: 2 additions & 7 deletions app/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def respond_to_auth_challenge(
challenge_name: literals.ChallengeNameTypeType,
new_password: str = "",
mfa_code: str = "",
) -> type_defs.RespondToAuthChallengeResponseTypeDef | dict[str, str]:
) -> type_defs.RespondToAuthChallengeResponseTypeDef:
"""
Responds to the authentication challenge provided by Cognito.
Expand Down Expand Up @@ -156,7 +156,7 @@ def respond_to_auth_challenge(
Session=verify_response["Session"],
)
case "SOFTWARE_TOKEN_MFA":
challenge_response = self.cognito.respond_to_auth_challenge(
return self.cognito.respond_to_auth_challenge(
ClientId=app_settings.cognito_client_id,
ChallengeName=challenge_name,
ChallengeResponses={
Expand All @@ -166,11 +166,6 @@ def respond_to_auth_challenge(
},
Session=session,
)

return {
"access_token": challenge_response["AuthenticationResult"]["AccessToken"],
"refresh_token": challenge_response["AuthenticationResult"]["RefreshToken"],
}
case _:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
Expand Down
17 changes: 9 additions & 8 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def create(cls, session: Session, **data: Any) -> Self:
return obj

@classmethod
def create_from_object(cls, session: Session, obj: Any) -> Self:
def create_from_object(cls, session: Session, obj: Self) -> Self:
"""
Insert a new instance into the database.
Expand Down Expand Up @@ -529,7 +529,7 @@ class Award(AwardBase, ActiveRecordMixin, table=True):

# Relationships
applications: list["Application"] = Relationship(back_populates="award")
borrower: Borrower = Relationship(back_populates="awards")
borrower: Borrower | None = Relationship(back_populates="awards")

# Timestamps
created_at: datetime = Field(
Expand Down Expand Up @@ -653,8 +653,8 @@ class ApplicationBase(SQLModel):
archived_at: datetime | None = Field(sa_column=Column(DateTime(timezone=True)))

# Relationships
award_id: int | None = Field(foreign_key="award.id", index=True)
borrower_id: int | None = Field(foreign_key="borrower.id", index=True)
award_id: int = Field(foreign_key="award.id", index=True)
borrower_id: int = Field(foreign_key="borrower.id", index=True)
lender_id: int | None = Field(foreign_key="lender.id")
credit_product_id: int | None = Field(foreign_key="credit_product.id")

Expand All @@ -678,9 +678,10 @@ class Application(ApplicationPrivate, ActiveRecordMixin, table=True):
borrower_documents: list["BorrowerDocument"] = Relationship(back_populates="application")
award: Award = Relationship(back_populates="applications")
borrower: Borrower = Relationship(back_populates="applications")
lender: Lender | None = Relationship(back_populates="applications")
lender: Lender = Relationship(back_populates="applications")
messages: list["Message"] = Relationship(back_populates="application")
actions: list["ApplicationAction"] = Relationship(back_populates="application")
# no back_populates, because models.CreditProduct is used as a request and response format. :issue:`376`
credit_product: CreditProduct = Relationship()

@classmethod
Expand Down Expand Up @@ -946,7 +947,7 @@ class BorrowerDocument(BorrowerDocumentBase, ActiveRecordMixin, table=True):
file: bytes

# Relationships
application: Application | None = Relationship(back_populates="borrower_documents")
application: Application = Relationship(back_populates="borrower_documents")


class Message(SQLModel, ActiveRecordMixin, table=True):
Expand All @@ -960,7 +961,7 @@ class Message(SQLModel, ActiveRecordMixin, table=True):

# Relationships
application_id: int = Field(foreign_key="application.id")
application: Application | None = Relationship(back_populates="messages")
application: Application = Relationship(back_populates="messages")
lender_id: int | None = Field(default=None, foreign_key="lender.id")

# Timestamps
Expand Down Expand Up @@ -1036,7 +1037,7 @@ class ApplicationAction(SQLModel, ActiveRecordMixin, table=True):

# Relationships
application_id: int = Field(foreign_key="application.id")
application: Application | None = Relationship(back_populates="actions")
application: Application = Relationship(back_populates="actions")
user_id: int | None = Field(default=None, foreign_key="credere_user.id")
user: User | None = Relationship(back_populates="application_actions")

Expand Down
8 changes: 6 additions & 2 deletions app/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
class BasicUser(BaseModel):
username: str
name: str | None = None
password: str | None = None
temp_password: str | None = None
password: str
temp_password: str


class ResetPassword(BaseModel):
username: str


class SetupMFA(BaseModel):
Expand Down
18 changes: 9 additions & 9 deletions app/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def change_password(
and handles different scenarios such as new password requirement, MFA setup, and error handling.
"""
# This endpoint is only called for new users, to replace the generated password.
response = client.initiate_auth(payload.username, payload.temp_password)
if response["ChallengeName"] == "NEW_PASSWORD_REQUIRED":
response = client.respond_to_auth_challenge(
initiate_response = client.initiate_auth(payload.username, payload.temp_password)
if initiate_response["ChallengeName"] == "NEW_PASSWORD_REQUIRED":
respond_response = client.respond_to_auth_challenge(
username=payload.username,
session=response["Session"],
session=initiate_response["Session"],
challenge_name="NEW_PASSWORD_REQUIRED",
new_password=payload.password,
)
Expand All @@ -88,9 +88,9 @@ def change_password(
UserAttributes=[{"Name": "email_verified", "Value": "true"}],
)

if "ChallengeName" in response and response["ChallengeName"] == "MFA_SETUP":
if "ChallengeName" in respond_response and respond_response["ChallengeName"] == "MFA_SETUP":
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/cognito-idp/client/associate_software_token.html
associate_response = client.cognito.associate_software_token(Session=response["Session"])
associate_response = client.cognito.associate_software_token(Session=respond_response["Session"])

return serializers.ChangePasswordResponse(
detail=_("Password changed with MFA setup required"),
Expand Down Expand Up @@ -189,8 +189,8 @@ def login(

return serializers.LoginResponse(
user=user,
access_token=mfa_login_response["access_token"],
refresh_token=mfa_login_response["refresh_token"],
access_token=mfa_login_response["AuthenticationResult"]["AccessToken"],
refresh_token=mfa_login_response["AuthenticationResult"]["RefreshToken"],
)


Expand Down Expand Up @@ -240,7 +240,7 @@ def me(
"/users/forgot-password",
)
def forgot_password(
payload: parsers.BasicUser,
payload: parsers.ResetPassword,
client: aws.Client = Depends(dependencies.get_aws_client),
) -> serializers.ResponseBase:
"""
Expand Down
2 changes: 1 addition & 1 deletion docs/_static/routes.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ PUT,/users/setup-mfa,SetupMFA,ResponseBase,SetupMFAInput,IResponse
POST,/users/login,BasicUser,LoginResponse,LoginInput,ILoginResponse
GET,/users/logout,,ResponseBase,,IResponse
GET,/users/me,,UserResponse,,IUserResponse
POST,/users/forgot-password,BasicUser,ResponseBase,ResetPasswordInput,IResponse
POST,/users/forgot-password,ResetPassword,ResponseBase,ResetPasswordInput,IResponse
GET,/users/{user_id},user_id,models.User,id,IUser
GET,/users,"page, page_size, sort_field, sort_order",UserListResponse,PaginationInput,IUsersListResponse
PUT,/users/{user_id},models.User,models.UserWithLender,UpdateUserInput,IUser
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""award id borrower id nonnullable
Revision ID: d9b564fd6859
Revises: eaef8e99aee2
Create Date: 2024-08-23 15:38:20.028504
"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "d9b564fd6859"
down_revision = "eaef8e99aee2"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column("application", "award_id", existing_type=sa.INTEGER(), nullable=False)
op.alter_column("application", "borrower_id", existing_type=sa.INTEGER(), nullable=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column("application", "borrower_id", existing_type=sa.INTEGER(), nullable=True)
op.alter_column("application", "award_id", existing_type=sa.INTEGER(), nullable=True)
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion reportlab_mods.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

width, height = A4
styles = getSampleStyleSheet()
styles = getSampleStyleSheet() # type: ignore[no-untyped-call]
styleN = styles["BodyText"]
styleN.fontName = "GTEestiProDisplay"
styleN.alignment = TA_LEFT
Expand Down
2 changes: 1 addition & 1 deletion tests/routers/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_duplicate_user(client, admin_header, user_payload):


def test_login_invalid_username(client):
response = client.post("/users/login", json={"username": "nonexistent"})
response = client.post("/users/login", json={"username": "nonexistent", "password": "", "temp_password": ""})

assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.json() == {"detail": _("Invalid username or password")}
Expand Down

0 comments on commit 79b69c1

Please sign in to comment.