From 2d8c6542718a57887c8d4ab634bac2e561813588 Mon Sep 17 00:00:00 2001 From: Mustafa Gezen Date: Sun, 2 Jul 2023 10:28:48 +0200 Subject: [PATCH] Move common utils from distro-tools --- pdot_common/api.py | 85 +++++++++++++++++++++++ pdot_common/database.py | 37 ++++++++++ pdot_common/env.py | 16 +++++ pdot_common/info.py | 66 ++++++++++++++++++ pdot_common/logger.py | 52 ++++++++++++++ pdot_common/{oidc/__init__.py => oidc.py} | 26 +++++-- pdot_common/temporal.py | 29 ++++++++ pdot_common/testing.py | 13 ++++ pyproject.toml | 9 ++- 9 files changed, 324 insertions(+), 9 deletions(-) create mode 100644 pdot_common/api.py create mode 100644 pdot_common/database.py create mode 100644 pdot_common/env.py create mode 100644 pdot_common/info.py create mode 100644 pdot_common/logger.py rename pdot_common/{oidc/__init__.py => oidc.py} (72%) create mode 100644 pdot_common/temporal.py create mode 100644 pdot_common/testing.py diff --git a/pdot_common/api.py b/pdot_common/api.py new file mode 100644 index 0000000..aeb1181 --- /dev/null +++ b/pdot_common/api.py @@ -0,0 +1,85 @@ +import datetime +import os +from typing import Any, Optional + +from fastapi import Query +from fastapi.staticfiles import StaticFiles +from fastapi_pagination import Params as FastAPIParams + +from pydantic import BaseModel, root_validator + + +class StaticFilesSym(StaticFiles): + "subclass StaticFiles middleware to allow symlinks" + + def lookup_path(self, path): + for directory in self.all_directories: + full_path = os.path.realpath(os.path.join(directory, path)) + try: + stat_result = os.stat(full_path) + return full_path, stat_result + except FileNotFoundError: + pass + return "", None + + +class RenderErrorTemplateException(Exception): + def __init__(self, msg=None, status_code=404): + self.msg = msg + self.status_code = status_code + + +class Params(FastAPIParams): + def get_size(self) -> int: + return self.size + + def get_offset(self) -> int: + return self.size * (self.page - 1) + + +class DateTimeParams(BaseModel): + before: str = Query(default=None) + after: str = Query(default=None) + + before_parsed: datetime.datetime = None + after_parsed: datetime.datetime = None + + # noinspection PyMethodParameters + @root_validator(pre=True) + def __root_validator__( + cls, # pylint: disable=no-self-argument + value: Any, + ) -> Any: + if value["before"]: + before = parse_rfc3339_date(value["before"]) + if not before: + raise RenderErrorTemplateException("Invalid before date", 400) + value["before_parsed"] = before + + if value["after"]: + after = parse_rfc3339_date(value["after"]) + if not after: + raise RenderErrorTemplateException("Invalid after date", 400) + value["after_parsed"] = after + + return value + + def get_before(self) -> datetime.datetime: + return self.before_parsed + + def get_after(self) -> datetime.datetime: + return self.after_parsed + + +def parse_rfc3339_date(date: str) -> Optional[datetime.datetime]: + if date: + try: + return datetime.datetime.fromisoformat(date.removesuffix("Z")) + except ValueError: + return None + + return None + + +def to_rfc3339_date(date: datetime.datetime) -> str: + return date.isoformat("T").replace("+00:00", "") + "Z" diff --git a/pdot_common/database.py b/pdot_common/database.py new file mode 100644 index 0000000..a4f967b --- /dev/null +++ b/pdot_common/database.py @@ -0,0 +1,37 @@ +""" +Database helper methods +""" +from tortoise import Tortoise, connections + +from pdot_common.info import Info + + +class Database(object): + """ + Database connection singleton class + """ + + initialized = False + + def __init__(self, initialize=False): + if not Database.initialized and not initialize: + raise Exception("Database connection not initialized") + + @staticmethod + def conn_str(): + info = Info() + + return f"postgres://{info.db_user()}:{info.db_password()}@{info.db_host()}:{info.db_port()}/{info.db_name()}" + + async def init(self, models): + if Database.initialized: + return + await Tortoise.init( + db_url=self.conn_str(), use_tz=True, modules={"models": models} + ) + + self.initialized = True + + async def shutdown(self): + await connections.close_all() + self.initialized = False diff --git a/pdot_common/env.py b/pdot_common/env.py new file mode 100644 index 0000000..4e3515c --- /dev/null +++ b/pdot_common/env.py @@ -0,0 +1,16 @@ +""" +Environment variables +""" +import os + + +def get_env(): + return os.environ.get("ENV", "development") + + +def is_prod(): + return get_env() == "production" + + +def is_k8s(): + return os.environ.get("KUBERNETES", "0") == "1" diff --git a/pdot_common/info.py b/pdot_common/info.py new file mode 100644 index 0000000..4bdc7b3 --- /dev/null +++ b/pdot_common/info.py @@ -0,0 +1,66 @@ +""" +Application information +""" +import os + +from pdot_common.env import get_env, is_k8s + + +class Info: + """ + Application information singleton class + """ + + _name = None + _dbname = None + + def __init__(self, name=None, dbname=None): + if not self._name and not name: + raise ValueError("Info.name is not set") + if self._name and name: + raise ValueError("Info.name is already set") + if name: + Info._name = name + Info._dbname = dbname if dbname else name + + self._name = Info._name + + def name(self): + return self._name + + def db_name(self): + return f"{self._dbname}{get_env()}" + + @staticmethod + def db_user(): + return os.environ.get("DB_USER", "postgres") + + @staticmethod + def db_password(): + return os.environ.get("DB_PASSWORD", "postgres") + + @staticmethod + def db_host(): + return os.environ.get("DB_HOST", "localhost") + + @staticmethod + def db_port(): + return os.environ.get("DB_PORT", "5432") + + @staticmethod + def db_ssl_mode(): + return os.environ.get("DB_SSLMODE", "disable") + + @staticmethod + def temporal_host(): + if is_k8s(): + return os.environ.get( + "TEMPORAL_HOSTPORT", + "workflow-temporal-frontend.workflow.svc.cluster.local:7233", + ) + else: + return os.environ.get("TEMPORAL_HOSTPORT", "localhost:7233") + + @staticmethod + def temporal_namespace(): + return os.environ.get("TEMPORAL_NAMESPACE", "default") diff --git a/pdot_common/logger.py b/pdot_common/logger.py new file mode 100644 index 0000000..4b2f3f3 --- /dev/null +++ b/pdot_common/logger.py @@ -0,0 +1,52 @@ +""" +This module provides a logger class that +can be used to log messages to the console. +""" +import logging + +from pdot_common.env import is_prod +from pdot_common.info import Info + + +class Logger(object): + """ + This class provides a logger that can be used to log messages to the console. + """ + + logger = None + + def __init__(self): + info = Info() + + if Logger.logger is None: + level = logging.INFO + if not is_prod(): + level = logging.DEBUG + logging.basicConfig( + level=level, + format="[%(name)s:%(levelname)s:%(asctime)s] %(message)s", + ) + Logger.logger = logging.getLogger(info.name()) + + self.logger = Logger.logger + + def warning(self, msg, *args, **kwargs): + self.logger.warning(msg, *args, **kwargs) + + def error(self, msg, *args, **kwargs): + self.logger.error(msg, *args, **kwargs) + + def info(self, msg, *args, **kwargs): + self.logger.info(msg, *args, **kwargs) + + def debug(self, msg, *args, **kwargs): + self.logger.debug(msg, *args, **kwargs) + + def exception(self, msg, *args, **kwargs): + self.logger.exception(msg, *args, **kwargs) + + def critical(self, msg, *args, **kwargs): + self.logger.critical(msg, *args, **kwargs) + + def fatal(self, msg, *args, **kwargs): + self.logger.fatal(msg, *args, **kwargs) diff --git a/pdot_common/oidc/__init__.py b/pdot_common/oidc.py similarity index 72% rename from pdot_common/oidc/__init__.py rename to pdot_common/oidc.py index 9cb0c5b..860ef92 100644 --- a/pdot_common/oidc/__init__.py +++ b/pdot_common/oidc.py @@ -10,11 +10,23 @@ from fastapi.responses import JSONResponse class OIDCConfig: userinfo_endpoint: str required_groups: list[str] = field(default_factory=list) + exclude_paths: list[str] = field(default_factory=list) + exclude_paths_prefix: list[str] = field(default_factory=list) def add_oidc_middleware(app: FastAPI, config: OIDCConfig): @app.middleware("http") async def verify_oidc_auth(request: Request, call_next): + # If path is in exclude_paths, skip OIDC verification + if request.url.path in config.exclude_paths: + return await call_next(request) + + # If path starts with any of the exclude_paths_prefix, + # skip OIDC verification + for prefix in config.exclude_paths_prefix: + if request.url.path.startswith(prefix): + return await call_next(request) + # First verify that there is an Authorization header auth_header = request.headers.get("Authorization") if not auth_header: @@ -69,16 +81,18 @@ def add_oidc_middleware(app: FastAPI, config: OIDCConfig): if config.required_groups: if not userinfo.get("groups"): return JSONResponse( - status_code=401, + status_code=403, content={"detail": "User does not have any groups"}, ) - if not any( - group in userinfo["groups"] for group in config.required_groups - ): + required_groups = set(config.required_groups) + has_group = any( + group in userinfo["groups"] for group in required_groups + ) + if not has_group: return JSONResponse( - status_code=401, - content={"detail": "User does not have required groups"}, + status_code=403, + content={"detail": "User not authorized"}, ) request.state.userinfo = userinfo diff --git a/pdot_common/temporal.py b/pdot_common/temporal.py new file mode 100644 index 0000000..6d06dbd --- /dev/null +++ b/pdot_common/temporal.py @@ -0,0 +1,29 @@ +""" +Temporal helper methods +""" + +from temporalio.client import Client + +from pdot_common.info import Info + + +class Temporal(object): + """ + Temporal helper singleton class + """ + + client = None + + def __init__(self, initialize=False): + if Temporal.client is None and not initialize: + raise Exception("Temporal client not initialized") + + self.client = Temporal.client + + async def connect(self): + info = Info() + Temporal.client = await Client.connect( + info.temporal_host(), + namespace=info.temporal_namespace(), + ) + self.client = Temporal.client diff --git a/pdot_common/testing.py b/pdot_common/testing.py new file mode 100644 index 0000000..8ba10b3 --- /dev/null +++ b/pdot_common/testing.py @@ -0,0 +1,13 @@ +class MockResponse: + def __init__(self, text, status): + self._text = text + self.status = status + + async def text(self): + return self._text + + async def __aexit__(self, exc_type, exc, tb): + pass + + async def __aenter__(self): + return self diff --git a/pyproject.toml b/pyproject.toml index 7a2156c..dcbf3ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,9 +5,12 @@ description = "Common Python library for Peridot projects" readme = "README.md" requires-python = ">=3.10" dependencies = [ - "fastapi >= 0.99.0", - "authlib >= 1.2.1", - "httpx >= 0.24.1", + "fastapi == 0.99.0", + "fastapi-pagination == 0.12.5", + "authlib == 1.2.1", + "httpx == 0.24.1", + "tortoise-orm[asyncpg] == 0.19.3", + "temporalio == 1.2.0" ] authors = [ { name = "Mustafa Gezen", email = "mustafa@rockylinux.org" }