mirror of
https://github.com/peridotbuild/pdot_common.git
synced 2024-12-04 02:26:26 +00:00
Move common utils from distro-tools
This commit is contained in:
parent
a880481cc7
commit
2d8c654271
85
pdot_common/api.py
Normal file
85
pdot_common/api.py
Normal file
@ -0,0 +1,85 @@
|
||||
import datetime
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import Query
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi_pagination import Params as FastAPIParams
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
|
||||
class StaticFilesSym(StaticFiles):
|
||||
"subclass StaticFiles middleware to allow symlinks"
|
||||
|
||||
def lookup_path(self, path):
|
||||
for directory in self.all_directories:
|
||||
full_path = os.path.realpath(os.path.join(directory, path))
|
||||
try:
|
||||
stat_result = os.stat(full_path)
|
||||
return full_path, stat_result
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return "", None
|
||||
|
||||
|
||||
class RenderErrorTemplateException(Exception):
|
||||
def __init__(self, msg=None, status_code=404):
|
||||
self.msg = msg
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class Params(FastAPIParams):
|
||||
def get_size(self) -> int:
|
||||
return self.size
|
||||
|
||||
def get_offset(self) -> int:
|
||||
return self.size * (self.page - 1)
|
||||
|
||||
|
||||
class DateTimeParams(BaseModel):
|
||||
before: str = Query(default=None)
|
||||
after: str = Query(default=None)
|
||||
|
||||
before_parsed: datetime.datetime = None
|
||||
after_parsed: datetime.datetime = None
|
||||
|
||||
# noinspection PyMethodParameters
|
||||
@root_validator(pre=True)
|
||||
def __root_validator__(
|
||||
cls, # pylint: disable=no-self-argument
|
||||
value: Any,
|
||||
) -> Any:
|
||||
if value["before"]:
|
||||
before = parse_rfc3339_date(value["before"])
|
||||
if not before:
|
||||
raise RenderErrorTemplateException("Invalid before date", 400)
|
||||
value["before_parsed"] = before
|
||||
|
||||
if value["after"]:
|
||||
after = parse_rfc3339_date(value["after"])
|
||||
if not after:
|
||||
raise RenderErrorTemplateException("Invalid after date", 400)
|
||||
value["after_parsed"] = after
|
||||
|
||||
return value
|
||||
|
||||
def get_before(self) -> datetime.datetime:
|
||||
return self.before_parsed
|
||||
|
||||
def get_after(self) -> datetime.datetime:
|
||||
return self.after_parsed
|
||||
|
||||
|
||||
def parse_rfc3339_date(date: str) -> Optional[datetime.datetime]:
|
||||
if date:
|
||||
try:
|
||||
return datetime.datetime.fromisoformat(date.removesuffix("Z"))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def to_rfc3339_date(date: datetime.datetime) -> str:
|
||||
return date.isoformat("T").replace("+00:00", "") + "Z"
|
37
pdot_common/database.py
Normal file
37
pdot_common/database.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""
|
||||
Database helper methods
|
||||
"""
|
||||
from tortoise import Tortoise, connections
|
||||
|
||||
from pdot_common.info import Info
|
||||
|
||||
|
||||
class Database(object):
|
||||
"""
|
||||
Database connection singleton class
|
||||
"""
|
||||
|
||||
initialized = False
|
||||
|
||||
def __init__(self, initialize=False):
|
||||
if not Database.initialized and not initialize:
|
||||
raise Exception("Database connection not initialized")
|
||||
|
||||
@staticmethod
|
||||
def conn_str():
|
||||
info = Info()
|
||||
|
||||
return f"postgres://{info.db_user()}:{info.db_password()}@{info.db_host()}:{info.db_port()}/{info.db_name()}"
|
||||
|
||||
async def init(self, models):
|
||||
if Database.initialized:
|
||||
return
|
||||
await Tortoise.init(
|
||||
db_url=self.conn_str(), use_tz=True, modules={"models": models}
|
||||
)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
async def shutdown(self):
|
||||
await connections.close_all()
|
||||
self.initialized = False
|
16
pdot_common/env.py
Normal file
16
pdot_common/env.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Environment variables
|
||||
"""
|
||||
import os
|
||||
|
||||
|
||||
def get_env():
|
||||
return os.environ.get("ENV", "development")
|
||||
|
||||
|
||||
def is_prod():
|
||||
return get_env() == "production"
|
||||
|
||||
|
||||
def is_k8s():
|
||||
return os.environ.get("KUBERNETES", "0") == "1"
|
66
pdot_common/info.py
Normal file
66
pdot_common/info.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""
|
||||
Application information
|
||||
"""
|
||||
import os
|
||||
|
||||
from pdot_common.env import get_env, is_k8s
|
||||
|
||||
|
||||
class Info:
|
||||
"""
|
||||
Application information singleton class
|
||||
"""
|
||||
|
||||
_name = None
|
||||
_dbname = None
|
||||
|
||||
def __init__(self, name=None, dbname=None):
|
||||
if not self._name and not name:
|
||||
raise ValueError("Info.name is not set")
|
||||
if self._name and name:
|
||||
raise ValueError("Info.name is already set")
|
||||
if name:
|
||||
Info._name = name
|
||||
Info._dbname = dbname if dbname else name
|
||||
|
||||
self._name = Info._name
|
||||
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
def db_name(self):
|
||||
return f"{self._dbname}{get_env()}"
|
||||
|
||||
@staticmethod
|
||||
def db_user():
|
||||
return os.environ.get("DB_USER", "postgres")
|
||||
|
||||
@staticmethod
|
||||
def db_password():
|
||||
return os.environ.get("DB_PASSWORD", "postgres")
|
||||
|
||||
@staticmethod
|
||||
def db_host():
|
||||
return os.environ.get("DB_HOST", "localhost")
|
||||
|
||||
@staticmethod
|
||||
def db_port():
|
||||
return os.environ.get("DB_PORT", "5432")
|
||||
|
||||
@staticmethod
|
||||
def db_ssl_mode():
|
||||
return os.environ.get("DB_SSLMODE", "disable")
|
||||
|
||||
@staticmethod
|
||||
def temporal_host():
|
||||
if is_k8s():
|
||||
return os.environ.get(
|
||||
"TEMPORAL_HOSTPORT",
|
||||
"workflow-temporal-frontend.workflow.svc.cluster.local:7233",
|
||||
)
|
||||
else:
|
||||
return os.environ.get("TEMPORAL_HOSTPORT", "localhost:7233")
|
||||
|
||||
@staticmethod
|
||||
def temporal_namespace():
|
||||
return os.environ.get("TEMPORAL_NAMESPACE", "default")
|
52
pdot_common/logger.py
Normal file
52
pdot_common/logger.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""
|
||||
This module provides a logger class that
|
||||
can be used to log messages to the console.
|
||||
"""
|
||||
import logging
|
||||
|
||||
from pdot_common.env import is_prod
|
||||
from pdot_common.info import Info
|
||||
|
||||
|
||||
class Logger(object):
|
||||
"""
|
||||
This class provides a logger that can be used to log messages to the console.
|
||||
"""
|
||||
|
||||
logger = None
|
||||
|
||||
def __init__(self):
|
||||
info = Info()
|
||||
|
||||
if Logger.logger is None:
|
||||
level = logging.INFO
|
||||
if not is_prod():
|
||||
level = logging.DEBUG
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="[%(name)s:%(levelname)s:%(asctime)s] %(message)s",
|
||||
)
|
||||
Logger.logger = logging.getLogger(info.name())
|
||||
|
||||
self.logger = Logger.logger
|
||||
|
||||
def warning(self, msg, *args, **kwargs):
|
||||
self.logger.warning(msg, *args, **kwargs)
|
||||
|
||||
def error(self, msg, *args, **kwargs):
|
||||
self.logger.error(msg, *args, **kwargs)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
self.logger.info(msg, *args, **kwargs)
|
||||
|
||||
def debug(self, msg, *args, **kwargs):
|
||||
self.logger.debug(msg, *args, **kwargs)
|
||||
|
||||
def exception(self, msg, *args, **kwargs):
|
||||
self.logger.exception(msg, *args, **kwargs)
|
||||
|
||||
def critical(self, msg, *args, **kwargs):
|
||||
self.logger.critical(msg, *args, **kwargs)
|
||||
|
||||
def fatal(self, msg, *args, **kwargs):
|
||||
self.logger.fatal(msg, *args, **kwargs)
|
@ -10,11 +10,23 @@ from fastapi.responses import JSONResponse
|
||||
class OIDCConfig:
|
||||
userinfo_endpoint: str
|
||||
required_groups: list[str] = field(default_factory=list)
|
||||
exclude_paths: list[str] = field(default_factory=list)
|
||||
exclude_paths_prefix: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def add_oidc_middleware(app: FastAPI, config: OIDCConfig):
|
||||
@app.middleware("http")
|
||||
async def verify_oidc_auth(request: Request, call_next):
|
||||
# If path is in exclude_paths, skip OIDC verification
|
||||
if request.url.path in config.exclude_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# If path starts with any of the exclude_paths_prefix,
|
||||
# skip OIDC verification
|
||||
for prefix in config.exclude_paths_prefix:
|
||||
if request.url.path.startswith(prefix):
|
||||
return await call_next(request)
|
||||
|
||||
# First verify that there is an Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
@ -69,16 +81,18 @@ def add_oidc_middleware(app: FastAPI, config: OIDCConfig):
|
||||
if config.required_groups:
|
||||
if not userinfo.get("groups"):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
status_code=403,
|
||||
content={"detail": "User does not have any groups"},
|
||||
)
|
||||
|
||||
if not any(
|
||||
group in userinfo["groups"] for group in config.required_groups
|
||||
):
|
||||
required_groups = set(config.required_groups)
|
||||
has_group = any(
|
||||
group in userinfo["groups"] for group in required_groups
|
||||
)
|
||||
if not has_group:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "User does not have required groups"},
|
||||
status_code=403,
|
||||
content={"detail": "User not authorized"},
|
||||
)
|
||||
|
||||
request.state.userinfo = userinfo
|
29
pdot_common/temporal.py
Normal file
29
pdot_common/temporal.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""
|
||||
Temporal helper methods
|
||||
"""
|
||||
|
||||
from temporalio.client import Client
|
||||
|
||||
from pdot_common.info import Info
|
||||
|
||||
|
||||
class Temporal(object):
|
||||
"""
|
||||
Temporal helper singleton class
|
||||
"""
|
||||
|
||||
client = None
|
||||
|
||||
def __init__(self, initialize=False):
|
||||
if Temporal.client is None and not initialize:
|
||||
raise Exception("Temporal client not initialized")
|
||||
|
||||
self.client = Temporal.client
|
||||
|
||||
async def connect(self):
|
||||
info = Info()
|
||||
Temporal.client = await Client.connect(
|
||||
info.temporal_host(),
|
||||
namespace=info.temporal_namespace(),
|
||||
)
|
||||
self.client = Temporal.client
|
13
pdot_common/testing.py
Normal file
13
pdot_common/testing.py
Normal file
@ -0,0 +1,13 @@
|
||||
class MockResponse:
|
||||
def __init__(self, text, status):
|
||||
self._text = text
|
||||
self.status = status
|
||||
|
||||
async def text(self):
|
||||
return self._text
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
@ -5,9 +5,12 @@ description = "Common Python library for Peridot projects"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"fastapi >= 0.99.0",
|
||||
"authlib >= 1.2.1",
|
||||
"httpx >= 0.24.1",
|
||||
"fastapi == 0.99.0",
|
||||
"fastapi-pagination == 0.12.5",
|
||||
"authlib == 1.2.1",
|
||||
"httpx == 0.24.1",
|
||||
"tortoise-orm[asyncpg] == 0.19.3",
|
||||
"temporalio == 1.2.0"
|
||||
]
|
||||
authors = [
|
||||
{ name = "Mustafa Gezen", email = "mustafa@rockylinux.org" }
|
||||
|
Loading…
Reference in New Issue
Block a user