mirror of
https://github.com/peridotbuild/pdot_common.git
synced 2024-12-04 18:46:26 +00:00
Move common utils from distro-tools
This commit is contained in:
parent
a880481cc7
commit
2d8c654271
85
pdot_common/api.py
Normal file
85
pdot_common/api.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from fastapi import Query
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi_pagination import Params as FastAPIParams
|
||||||
|
|
||||||
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
|
||||||
|
class StaticFilesSym(StaticFiles):
|
||||||
|
"subclass StaticFiles middleware to allow symlinks"
|
||||||
|
|
||||||
|
def lookup_path(self, path):
|
||||||
|
for directory in self.all_directories:
|
||||||
|
full_path = os.path.realpath(os.path.join(directory, path))
|
||||||
|
try:
|
||||||
|
stat_result = os.stat(full_path)
|
||||||
|
return full_path, stat_result
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
return "", None
|
||||||
|
|
||||||
|
|
||||||
|
class RenderErrorTemplateException(Exception):
|
||||||
|
def __init__(self, msg=None, status_code=404):
|
||||||
|
self.msg = msg
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
|
||||||
|
class Params(FastAPIParams):
|
||||||
|
def get_size(self) -> int:
|
||||||
|
return self.size
|
||||||
|
|
||||||
|
def get_offset(self) -> int:
|
||||||
|
return self.size * (self.page - 1)
|
||||||
|
|
||||||
|
|
||||||
|
class DateTimeParams(BaseModel):
|
||||||
|
before: str = Query(default=None)
|
||||||
|
after: str = Query(default=None)
|
||||||
|
|
||||||
|
before_parsed: datetime.datetime = None
|
||||||
|
after_parsed: datetime.datetime = None
|
||||||
|
|
||||||
|
# noinspection PyMethodParameters
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def __root_validator__(
|
||||||
|
cls, # pylint: disable=no-self-argument
|
||||||
|
value: Any,
|
||||||
|
) -> Any:
|
||||||
|
if value["before"]:
|
||||||
|
before = parse_rfc3339_date(value["before"])
|
||||||
|
if not before:
|
||||||
|
raise RenderErrorTemplateException("Invalid before date", 400)
|
||||||
|
value["before_parsed"] = before
|
||||||
|
|
||||||
|
if value["after"]:
|
||||||
|
after = parse_rfc3339_date(value["after"])
|
||||||
|
if not after:
|
||||||
|
raise RenderErrorTemplateException("Invalid after date", 400)
|
||||||
|
value["after_parsed"] = after
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def get_before(self) -> datetime.datetime:
|
||||||
|
return self.before_parsed
|
||||||
|
|
||||||
|
def get_after(self) -> datetime.datetime:
|
||||||
|
return self.after_parsed
|
||||||
|
|
||||||
|
|
||||||
|
def parse_rfc3339_date(date: str) -> Optional[datetime.datetime]:
|
||||||
|
if date:
|
||||||
|
try:
|
||||||
|
return datetime.datetime.fromisoformat(date.removesuffix("Z"))
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def to_rfc3339_date(date: datetime.datetime) -> str:
|
||||||
|
return date.isoformat("T").replace("+00:00", "") + "Z"
|
37
pdot_common/database.py
Normal file
37
pdot_common/database.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
"""
|
||||||
|
Database helper methods
|
||||||
|
"""
|
||||||
|
from tortoise import Tortoise, connections
|
||||||
|
|
||||||
|
from pdot_common.info import Info
|
||||||
|
|
||||||
|
|
||||||
|
class Database(object):
|
||||||
|
"""
|
||||||
|
Database connection singleton class
|
||||||
|
"""
|
||||||
|
|
||||||
|
initialized = False
|
||||||
|
|
||||||
|
def __init__(self, initialize=False):
|
||||||
|
if not Database.initialized and not initialize:
|
||||||
|
raise Exception("Database connection not initialized")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def conn_str():
|
||||||
|
info = Info()
|
||||||
|
|
||||||
|
return f"postgres://{info.db_user()}:{info.db_password()}@{info.db_host()}:{info.db_port()}/{info.db_name()}"
|
||||||
|
|
||||||
|
async def init(self, models):
|
||||||
|
if Database.initialized:
|
||||||
|
return
|
||||||
|
await Tortoise.init(
|
||||||
|
db_url=self.conn_str(), use_tz=True, modules={"models": models}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
await connections.close_all()
|
||||||
|
self.initialized = False
|
16
pdot_common/env.py
Normal file
16
pdot_common/env.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
Environment variables
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def get_env():
|
||||||
|
return os.environ.get("ENV", "development")
|
||||||
|
|
||||||
|
|
||||||
|
def is_prod():
|
||||||
|
return get_env() == "production"
|
||||||
|
|
||||||
|
|
||||||
|
def is_k8s():
|
||||||
|
return os.environ.get("KUBERNETES", "0") == "1"
|
66
pdot_common/info.py
Normal file
66
pdot_common/info.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
"""
|
||||||
|
Application information
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from pdot_common.env import get_env, is_k8s
|
||||||
|
|
||||||
|
|
||||||
|
class Info:
|
||||||
|
"""
|
||||||
|
Application information singleton class
|
||||||
|
"""
|
||||||
|
|
||||||
|
_name = None
|
||||||
|
_dbname = None
|
||||||
|
|
||||||
|
def __init__(self, name=None, dbname=None):
|
||||||
|
if not self._name and not name:
|
||||||
|
raise ValueError("Info.name is not set")
|
||||||
|
if self._name and name:
|
||||||
|
raise ValueError("Info.name is already set")
|
||||||
|
if name:
|
||||||
|
Info._name = name
|
||||||
|
Info._dbname = dbname if dbname else name
|
||||||
|
|
||||||
|
self._name = Info._name
|
||||||
|
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
def db_name(self):
|
||||||
|
return f"{self._dbname}{get_env()}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def db_user():
|
||||||
|
return os.environ.get("DB_USER", "postgres")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def db_password():
|
||||||
|
return os.environ.get("DB_PASSWORD", "postgres")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def db_host():
|
||||||
|
return os.environ.get("DB_HOST", "localhost")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def db_port():
|
||||||
|
return os.environ.get("DB_PORT", "5432")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def db_ssl_mode():
|
||||||
|
return os.environ.get("DB_SSLMODE", "disable")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def temporal_host():
|
||||||
|
if is_k8s():
|
||||||
|
return os.environ.get(
|
||||||
|
"TEMPORAL_HOSTPORT",
|
||||||
|
"workflow-temporal-frontend.workflow.svc.cluster.local:7233",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return os.environ.get("TEMPORAL_HOSTPORT", "localhost:7233")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def temporal_namespace():
|
||||||
|
return os.environ.get("TEMPORAL_NAMESPACE", "default")
|
52
pdot_common/logger.py
Normal file
52
pdot_common/logger.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
"""
|
||||||
|
This module provides a logger class that
|
||||||
|
can be used to log messages to the console.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from pdot_common.env import is_prod
|
||||||
|
from pdot_common.info import Info
|
||||||
|
|
||||||
|
|
||||||
|
class Logger(object):
|
||||||
|
"""
|
||||||
|
This class provides a logger that can be used to log messages to the console.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
info = Info()
|
||||||
|
|
||||||
|
if Logger.logger is None:
|
||||||
|
level = logging.INFO
|
||||||
|
if not is_prod():
|
||||||
|
level = logging.DEBUG
|
||||||
|
logging.basicConfig(
|
||||||
|
level=level,
|
||||||
|
format="[%(name)s:%(levelname)s:%(asctime)s] %(message)s",
|
||||||
|
)
|
||||||
|
Logger.logger = logging.getLogger(info.name())
|
||||||
|
|
||||||
|
self.logger = Logger.logger
|
||||||
|
|
||||||
|
def warning(self, msg, *args, **kwargs):
|
||||||
|
self.logger.warning(msg, *args, **kwargs)
|
||||||
|
|
||||||
|
def error(self, msg, *args, **kwargs):
|
||||||
|
self.logger.error(msg, *args, **kwargs)
|
||||||
|
|
||||||
|
def info(self, msg, *args, **kwargs):
|
||||||
|
self.logger.info(msg, *args, **kwargs)
|
||||||
|
|
||||||
|
def debug(self, msg, *args, **kwargs):
|
||||||
|
self.logger.debug(msg, *args, **kwargs)
|
||||||
|
|
||||||
|
def exception(self, msg, *args, **kwargs):
|
||||||
|
self.logger.exception(msg, *args, **kwargs)
|
||||||
|
|
||||||
|
def critical(self, msg, *args, **kwargs):
|
||||||
|
self.logger.critical(msg, *args, **kwargs)
|
||||||
|
|
||||||
|
def fatal(self, msg, *args, **kwargs):
|
||||||
|
self.logger.fatal(msg, *args, **kwargs)
|
@ -10,11 +10,23 @@ from fastapi.responses import JSONResponse
|
|||||||
class OIDCConfig:
|
class OIDCConfig:
|
||||||
userinfo_endpoint: str
|
userinfo_endpoint: str
|
||||||
required_groups: list[str] = field(default_factory=list)
|
required_groups: list[str] = field(default_factory=list)
|
||||||
|
exclude_paths: list[str] = field(default_factory=list)
|
||||||
|
exclude_paths_prefix: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def add_oidc_middleware(app: FastAPI, config: OIDCConfig):
|
def add_oidc_middleware(app: FastAPI, config: OIDCConfig):
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def verify_oidc_auth(request: Request, call_next):
|
async def verify_oidc_auth(request: Request, call_next):
|
||||||
|
# If path is in exclude_paths, skip OIDC verification
|
||||||
|
if request.url.path in config.exclude_paths:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# If path starts with any of the exclude_paths_prefix,
|
||||||
|
# skip OIDC verification
|
||||||
|
for prefix in config.exclude_paths_prefix:
|
||||||
|
if request.url.path.startswith(prefix):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
# First verify that there is an Authorization header
|
# First verify that there is an Authorization header
|
||||||
auth_header = request.headers.get("Authorization")
|
auth_header = request.headers.get("Authorization")
|
||||||
if not auth_header:
|
if not auth_header:
|
||||||
@ -69,16 +81,18 @@ def add_oidc_middleware(app: FastAPI, config: OIDCConfig):
|
|||||||
if config.required_groups:
|
if config.required_groups:
|
||||||
if not userinfo.get("groups"):
|
if not userinfo.get("groups"):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=401,
|
status_code=403,
|
||||||
content={"detail": "User does not have any groups"},
|
content={"detail": "User does not have any groups"},
|
||||||
)
|
)
|
||||||
|
|
||||||
if not any(
|
required_groups = set(config.required_groups)
|
||||||
group in userinfo["groups"] for group in config.required_groups
|
has_group = any(
|
||||||
):
|
group in userinfo["groups"] for group in required_groups
|
||||||
|
)
|
||||||
|
if not has_group:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=401,
|
status_code=403,
|
||||||
content={"detail": "User does not have required groups"},
|
content={"detail": "User not authorized"},
|
||||||
)
|
)
|
||||||
|
|
||||||
request.state.userinfo = userinfo
|
request.state.userinfo = userinfo
|
29
pdot_common/temporal.py
Normal file
29
pdot_common/temporal.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
Temporal helper methods
|
||||||
|
"""
|
||||||
|
|
||||||
|
from temporalio.client import Client
|
||||||
|
|
||||||
|
from pdot_common.info import Info
|
||||||
|
|
||||||
|
|
||||||
|
class Temporal(object):
|
||||||
|
"""
|
||||||
|
Temporal helper singleton class
|
||||||
|
"""
|
||||||
|
|
||||||
|
client = None
|
||||||
|
|
||||||
|
def __init__(self, initialize=False):
|
||||||
|
if Temporal.client is None and not initialize:
|
||||||
|
raise Exception("Temporal client not initialized")
|
||||||
|
|
||||||
|
self.client = Temporal.client
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
info = Info()
|
||||||
|
Temporal.client = await Client.connect(
|
||||||
|
info.temporal_host(),
|
||||||
|
namespace=info.temporal_namespace(),
|
||||||
|
)
|
||||||
|
self.client = Temporal.client
|
13
pdot_common/testing.py
Normal file
13
pdot_common/testing.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
class MockResponse:
|
||||||
|
def __init__(self, text, status):
|
||||||
|
self._text = text
|
||||||
|
self.status = status
|
||||||
|
|
||||||
|
async def text(self):
|
||||||
|
return self._text
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
@ -5,9 +5,12 @@ description = "Common Python library for Peridot projects"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"fastapi >= 0.99.0",
|
"fastapi == 0.99.0",
|
||||||
"authlib >= 1.2.1",
|
"fastapi-pagination == 0.12.5",
|
||||||
"httpx >= 0.24.1",
|
"authlib == 1.2.1",
|
||||||
|
"httpx == 0.24.1",
|
||||||
|
"tortoise-orm[asyncpg] == 0.19.3",
|
||||||
|
"temporalio == 1.2.0"
|
||||||
]
|
]
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Mustafa Gezen", email = "mustafa@rockylinux.org" }
|
{ name = "Mustafa Gezen", email = "mustafa@rockylinux.org" }
|
||||||
|
Loading…
Reference in New Issue
Block a user