Move common utils from distro-tools

This commit is contained in:
Mustafa Gezen 2023-07-02 10:28:48 +02:00
parent a880481cc7
commit 2d8c654271
9 changed files with 324 additions and 9 deletions

85
pdot_common/api.py Normal file
View File

@ -0,0 +1,85 @@
import datetime
import os
from typing import Any, Optional
from fastapi import Query
from fastapi.staticfiles import StaticFiles
from fastapi_pagination import Params as FastAPIParams
from pydantic import BaseModel, root_validator
class StaticFilesSym(StaticFiles):
"subclass StaticFiles middleware to allow symlinks"
def lookup_path(self, path):
for directory in self.all_directories:
full_path = os.path.realpath(os.path.join(directory, path))
try:
stat_result = os.stat(full_path)
return full_path, stat_result
except FileNotFoundError:
pass
return "", None
class RenderErrorTemplateException(Exception):
def __init__(self, msg=None, status_code=404):
self.msg = msg
self.status_code = status_code
class Params(FastAPIParams):
def get_size(self) -> int:
return self.size
def get_offset(self) -> int:
return self.size * (self.page - 1)
class DateTimeParams(BaseModel):
before: str = Query(default=None)
after: str = Query(default=None)
before_parsed: datetime.datetime = None
after_parsed: datetime.datetime = None
# noinspection PyMethodParameters
@root_validator(pre=True)
def __root_validator__(
cls, # pylint: disable=no-self-argument
value: Any,
) -> Any:
if value["before"]:
before = parse_rfc3339_date(value["before"])
if not before:
raise RenderErrorTemplateException("Invalid before date", 400)
value["before_parsed"] = before
if value["after"]:
after = parse_rfc3339_date(value["after"])
if not after:
raise RenderErrorTemplateException("Invalid after date", 400)
value["after_parsed"] = after
return value
def get_before(self) -> datetime.datetime:
return self.before_parsed
def get_after(self) -> datetime.datetime:
return self.after_parsed
def parse_rfc3339_date(date: str) -> Optional[datetime.datetime]:
if date:
try:
return datetime.datetime.fromisoformat(date.removesuffix("Z"))
except ValueError:
return None
return None
def to_rfc3339_date(date: datetime.datetime) -> str:
return date.isoformat("T").replace("+00:00", "") + "Z"

37
pdot_common/database.py Normal file
View File

@ -0,0 +1,37 @@
"""
Database helper methods
"""
from tortoise import Tortoise, connections
from pdot_common.info import Info
class Database(object):
"""
Database connection singleton class
"""
initialized = False
def __init__(self, initialize=False):
if not Database.initialized and not initialize:
raise Exception("Database connection not initialized")
@staticmethod
def conn_str():
info = Info()
return f"postgres://{info.db_user()}:{info.db_password()}@{info.db_host()}:{info.db_port()}/{info.db_name()}"
async def init(self, models):
if Database.initialized:
return
await Tortoise.init(
db_url=self.conn_str(), use_tz=True, modules={"models": models}
)
self.initialized = True
async def shutdown(self):
await connections.close_all()
self.initialized = False

16
pdot_common/env.py Normal file
View File

@ -0,0 +1,16 @@
"""
Environment variables
"""
import os
def get_env():
return os.environ.get("ENV", "development")
def is_prod():
return get_env() == "production"
def is_k8s():
return os.environ.get("KUBERNETES", "0") == "1"

66
pdot_common/info.py Normal file
View File

@ -0,0 +1,66 @@
"""
Application information
"""
import os
from pdot_common.env import get_env, is_k8s
class Info:
"""
Application information singleton class
"""
_name = None
_dbname = None
def __init__(self, name=None, dbname=None):
if not self._name and not name:
raise ValueError("Info.name is not set")
if self._name and name:
raise ValueError("Info.name is already set")
if name:
Info._name = name
Info._dbname = dbname if dbname else name
self._name = Info._name
def name(self):
return self._name
def db_name(self):
return f"{self._dbname}{get_env()}"
@staticmethod
def db_user():
return os.environ.get("DB_USER", "postgres")
@staticmethod
def db_password():
return os.environ.get("DB_PASSWORD", "postgres")
@staticmethod
def db_host():
return os.environ.get("DB_HOST", "localhost")
@staticmethod
def db_port():
return os.environ.get("DB_PORT", "5432")
@staticmethod
def db_ssl_mode():
return os.environ.get("DB_SSLMODE", "disable")
@staticmethod
def temporal_host():
if is_k8s():
return os.environ.get(
"TEMPORAL_HOSTPORT",
"workflow-temporal-frontend.workflow.svc.cluster.local:7233",
)
else:
return os.environ.get("TEMPORAL_HOSTPORT", "localhost:7233")
@staticmethod
def temporal_namespace():
return os.environ.get("TEMPORAL_NAMESPACE", "default")

52
pdot_common/logger.py Normal file
View File

@ -0,0 +1,52 @@
"""
This module provides a logger class that
can be used to log messages to the console.
"""
import logging
from pdot_common.env import is_prod
from pdot_common.info import Info
class Logger(object):
"""
This class provides a logger that can be used to log messages to the console.
"""
logger = None
def __init__(self):
info = Info()
if Logger.logger is None:
level = logging.INFO
if not is_prod():
level = logging.DEBUG
logging.basicConfig(
level=level,
format="[%(name)s:%(levelname)s:%(asctime)s] %(message)s",
)
Logger.logger = logging.getLogger(info.name())
self.logger = Logger.logger
def warning(self, msg, *args, **kwargs):
self.logger.warning(msg, *args, **kwargs)
def error(self, msg, *args, **kwargs):
self.logger.error(msg, *args, **kwargs)
def info(self, msg, *args, **kwargs):
self.logger.info(msg, *args, **kwargs)
def debug(self, msg, *args, **kwargs):
self.logger.debug(msg, *args, **kwargs)
def exception(self, msg, *args, **kwargs):
self.logger.exception(msg, *args, **kwargs)
def critical(self, msg, *args, **kwargs):
self.logger.critical(msg, *args, **kwargs)
def fatal(self, msg, *args, **kwargs):
self.logger.fatal(msg, *args, **kwargs)

View File

@ -10,11 +10,23 @@ from fastapi.responses import JSONResponse
class OIDCConfig: class OIDCConfig:
userinfo_endpoint: str userinfo_endpoint: str
required_groups: list[str] = field(default_factory=list) required_groups: list[str] = field(default_factory=list)
exclude_paths: list[str] = field(default_factory=list)
exclude_paths_prefix: list[str] = field(default_factory=list)
def add_oidc_middleware(app: FastAPI, config: OIDCConfig): def add_oidc_middleware(app: FastAPI, config: OIDCConfig):
@app.middleware("http") @app.middleware("http")
async def verify_oidc_auth(request: Request, call_next): async def verify_oidc_auth(request: Request, call_next):
# If path is in exclude_paths, skip OIDC verification
if request.url.path in config.exclude_paths:
return await call_next(request)
# If path starts with any of the exclude_paths_prefix,
# skip OIDC verification
for prefix in config.exclude_paths_prefix:
if request.url.path.startswith(prefix):
return await call_next(request)
# First verify that there is an Authorization header # First verify that there is an Authorization header
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
if not auth_header: if not auth_header:
@ -69,16 +81,18 @@ def add_oidc_middleware(app: FastAPI, config: OIDCConfig):
if config.required_groups: if config.required_groups:
if not userinfo.get("groups"): if not userinfo.get("groups"):
return JSONResponse( return JSONResponse(
status_code=401, status_code=403,
content={"detail": "User does not have any groups"}, content={"detail": "User does not have any groups"},
) )
if not any( required_groups = set(config.required_groups)
group in userinfo["groups"] for group in config.required_groups has_group = any(
): group in userinfo["groups"] for group in required_groups
)
if not has_group:
return JSONResponse( return JSONResponse(
status_code=401, status_code=403,
content={"detail": "User does not have required groups"}, content={"detail": "User not authorized"},
) )
request.state.userinfo = userinfo request.state.userinfo = userinfo

29
pdot_common/temporal.py Normal file
View File

@ -0,0 +1,29 @@
"""
Temporal helper methods
"""
from temporalio.client import Client
from pdot_common.info import Info
class Temporal(object):
"""
Temporal helper singleton class
"""
client = None
def __init__(self, initialize=False):
if Temporal.client is None and not initialize:
raise Exception("Temporal client not initialized")
self.client = Temporal.client
async def connect(self):
info = Info()
Temporal.client = await Client.connect(
info.temporal_host(),
namespace=info.temporal_namespace(),
)
self.client = Temporal.client

13
pdot_common/testing.py Normal file
View File

@ -0,0 +1,13 @@
class MockResponse:
def __init__(self, text, status):
self._text = text
self.status = status
async def text(self):
return self._text
async def __aexit__(self, exc_type, exc, tb):
pass
async def __aenter__(self):
return self

View File

@ -5,9 +5,12 @@ description = "Common Python library for Peridot projects"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = [
"fastapi >= 0.99.0", "fastapi == 0.99.0",
"authlib >= 1.2.1", "fastapi-pagination == 0.12.5",
"httpx >= 0.24.1", "authlib == 1.2.1",
"httpx == 0.24.1",
"tortoise-orm[asyncpg] == 0.19.3",
"temporalio == 1.2.0"
] ]
authors = [ authors = [
{ name = "Mustafa Gezen", email = "mustafa@rockylinux.org" } { name = "Mustafa Gezen", email = "mustafa@rockylinux.org" }