362 lines
14 KiB
Python
362 lines
14 KiB
Python
import logging
import os
import shutil
import sqlite3
import urllib.parse
from enum import Enum

import psycopg2
from psycopg2 import sql
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from psycopg2.extras import DictCursor
from termcolor import colored

from herotools.pathtools import expand_path
from osis.id import int_to_id
# Module-wide logging: INFO and above, emitting only the message text (no
# level name or timestamp). logging.basicConfig is a no-op when the root
# logger already has handlers, so an application that configures logging
# before importing this module keeps its own setup.
_LOG_FORMAT = "%(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
class DBCat:
    """File-backed store for the records of one category.

    Each record is stored at ``<path>/id/<cat>/<id[:2]>/<id>.yaml``. An
    optional human-readable name can point at a record through a symlink
    ``<path>/human/<cat>/<humanid>``.
    """

    def __init__(self, path: str, cat: str):
        """Create (if missing) the id and human directory trees for *cat*.

        Args:
            path: Root directory of the store; passed through ``expand_path``.
            cat: Category name, used as a sub-directory under ``id`` and ``human``.
        """
        path = expand_path(path)
        self.path_id = os.path.join(path, "id", cat)
        self.path_human = os.path.join(path, "human", cat)
        self.path = path
        self._init()

    def _init(self):
        """Ensure both category directories exist."""
        os.makedirs(self.path_id, exist_ok=True)
        os.makedirs(self.path_human, exist_ok=True)

    def reset(self):
        """Delete every record and human-name link of this category (best effort)."""
        for dir_path in (self.path_id, self.path_human):
            if os.path.exists(dir_path):
                shutil.rmtree(dir_path, ignore_errors=True)
        self._init()

    def _get_path_id(self, id: str, create: bool = True) -> str:
        """Return the file path for record *id*.

        Records are sharded into sub-directories named after the first two
        characters of the id.

        Args:
            id: Record identifier.
            create: When true (the default), create the shard directory so the
                path is ready to be written. Readers pass ``False`` so that a
                lookup of a missing id leaves no empty directories behind.
        """
        dir_path = os.path.join(self.path_id, id[:2])
        if create:
            os.makedirs(dir_path, exist_ok=True)
        return os.path.join(dir_path, f"{id}.yaml")

    def set(self, id: str, data: str, humanid: str = ""):
        """Write *data* as record *id*, optionally linking it under *humanid*.

        An existing symlink with the same human name is replaced. A regular
        file occupying that name is not touched, and ``FileExistsError`` is
        re-raised.

        Returns:
            The path of the written record file.
        """
        fs_path = self._get_path_id(id=id)
        # YAML text is UTF-8; don't depend on the process locale.
        with open(fs_path, "w", encoding="utf-8") as f:
            f.write(data)

        if humanid:
            human_file_path = os.path.join(self.path_human, humanid)
            # A relative symlink target would be resolved against the link's
            # own directory and dangle, so always link to an absolute path.
            target = os.path.abspath(fs_path)
            try:
                os.symlink(target, human_file_path)
            except FileExistsError:
                if not os.path.islink(human_file_path):
                    raise  # Never clobber a real file with a link.
                os.remove(human_file_path)
                os.symlink(target, human_file_path)
        return fs_path

    def get(self, id: str) -> str:
        """Return the stored text of record *id*.

        Raises:
            FileNotFoundError: if the record does not exist.
        """
        with open(self._get_path_id(id=id, create=False), "r", encoding="utf-8") as f:
            return f.read()

    def delete(self, id: str, humanid: str = ""):
        """Remove record *id* and, if given, its human-name symlink.

        Raises:
            FileNotFoundError: if the record (or the named link) does not exist.
        """
        os.remove(self._get_path_id(id=id, create=False))
        if humanid:
            os.remove(os.path.join(self.path_human, humanid))
class DBType(Enum):
    """Database backends understood by DBConfig and DB.

    Each member's value is the lowercase backend name.
    """

    # File-based backend: one "<db_name>.db" file inside DBConfig.db_path.
    SQLITE = "sqlite"
    # Server backend reached through host/port with login and password.
    POSTGRESQL = "postgresql"
class DBConfig:
    """Connection settings for one database on either supported backend.

    PostgreSQL uses ``db_login``/``db_passwd``/``db_addr``/``db_port``.
    SQLite uses the file ``<db_path>/<db_name>.db``. ``db_path`` is passed
    through ``expand_path`` at construction time.
    """

    def __init__(
        self,
        db_type: DBType = DBType.POSTGRESQL,
        db_name: str = "main",
        db_login: str = "admin",
        db_passwd: str = "admin",
        db_addr: str = "localhost",
        db_port: int = 5432,
        db_path: str = "/tmp/db",
    ):
        self.db_type = db_type
        self.db_name = db_name
        self.db_login = db_login
        self.db_passwd = db_passwd
        self.db_addr = db_addr
        self.db_port = db_port
        self.db_path = expand_path(db_path)

    def __str__(self):
        # The password is left out so a config can be printed or logged
        # without exposing the secret.
        return (
            f"DBConfig(db_type={self.db_type}, db_name='{self.db_name}', "
            f"db_login='{self.db_login}', db_addr='{self.db_addr}', "
            f"db_port={self.db_port}, db_path='{self.db_path}')"
        )

    def __repr__(self):
        return self.__str__()

    def url(self) -> str:
        """Return a connection URL for this configuration.

        For PostgreSQL, the login and password are percent-encoded. Otherwise
        reserved characters such as ``@``, ``:``, ``/`` or ``%`` in a password
        would change how the URL is parsed.

        Raises:
            ValueError: if ``db_type`` is not a supported backend.
        """
        if self.db_type == DBType.POSTGRESQL:
            login = urllib.parse.quote(self.db_login, safe="")
            passwd = urllib.parse.quote(self.db_passwd, safe="")
            return f"postgresql://{login}:{passwd}@{self.db_addr}:{self.db_port}/{self.db_name}"
        if self.db_type == DBType.SQLITE:
            return f"sqlite:///{self.db_path}/{self.db_name}.db"
        raise ValueError(f"Unsupported database type: {self.db_type}")
class DB:
    """Handle combining a SQL backend with a file-backed category store.

    The SQL side (PostgreSQL or SQLite) is described by ``cfg``. The file side
    lives under ``path``: ``path/id`` and ``path/human`` hold the per-category
    trees of the DBCat objects created through :meth:`dbcat_new`.
    """

    def __init__(self, cfg: "DBConfig", path: str, reset: bool = False):
        """Set up paths, then either reset the store or create missing directories.

        Args:
            cfg: Backend connection settings.
            path: Root of the file-backed store; passed through ``expand_path``.
            reset: If true, call :meth:`reset` (wipes the directory trees and,
                for PostgreSQL, drops the database) instead of :meth:`_init`.
        """
        self.cfg = cfg
        self.path = expand_path(path)
        self.path_id = os.path.join(self.path, "id")
        self.path_human = os.path.join(self.path, "human")
        # Categories registered via dbcat_new, keyed by category name.
        self.dbcats: dict[str, "DBCat"] = {}

        if reset:
            self.reset()
        else:
            self._init()

    def reset(self):
        """Wipe the on-disk store and, for PostgreSQL, drop the database.

        The ``id`` and ``human`` trees are deleted (best effort) and recreated
        empty by :meth:`_init`. For PostgreSQL, sessions on ``cfg.db_name`` are
        terminated and the database is dropped if it exists. The next
        :meth:`db_connection` re-creates it. SQLite database files are not
        touched.
        """
        for dir_path in (self.path_id, self.path_human):
            if os.path.exists(dir_path):
                shutil.rmtree(dir_path, ignore_errors=True)
                logger.info(colored(f"Removed db dir: {dir_path}", "red"))

        if self.cfg.db_type == DBType.POSTGRESQL:
            self._pg_drop_database()

        self._init()

    def _pg_drop_database(self):
        """Terminate all sessions on ``cfg.db_name`` and drop it, if it exists."""
        # Work from the maintenance database: PostgreSQL refuses to drop the
        # database the current session is connected to. Checking existence
        # here (rather than via db_connection) also avoids auto-creating the
        # target just to drop it again.
        conn = self._pg_admin_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT 1 FROM pg_database WHERE datname = %s",
                    (self.cfg.db_name,),
                )
                if cur.fetchone() is None:
                    return

                # DROP DATABASE fails while other sessions are connected to it.
                cur.execute(
                    """
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = %s
                      AND pid <> pg_backend_pid();
                    """,
                    (self.cfg.db_name,),
                )
                print(f"Terminated all connections to database '{self.cfg.db_name}'")

                # Quote the identifier: prevents injection and keeps the exact
                # (case-sensitive) name that the datname check above matched.
                cur.execute(
                    sql.SQL("DROP DATABASE {}").format(sql.Identifier(self.cfg.db_name))
                )
                print(f"Database '{self.cfg.db_name}' dropped successfully.")
        finally:
            conn.close()

    def _init(self):
        """Ensure the store directories and every registered category's dirs exist."""
        os.makedirs(self.path_human, exist_ok=True)
        os.makedirs(self.path_id, exist_ok=True)
        # Iterate the DBCat objects themselves; iterating the dict directly
        # would yield the key strings.
        for dbcat in self.dbcats.values():
            dbcat._init()

    def dbcat_new(self, cat: str, reset: bool = False) -> "DBCat":
        """Create and register a category store under this DB's path.

        Args:
            cat: Category name.
            reset: If true, wipe any existing records of this category.

        Returns:
            The new DBCat, which replaces any previously registered one for *cat*.
        """
        dbc = DBCat(cat=cat, path=self.path)
        if reset:
            dbc.reset()
        self.dbcats[cat] = dbc
        return dbc

    def dbcat_get(self, cat: str) -> "DBCat":
        """Return the registered category store for *cat*.

        Raises:
            Exception: if no category with that name was registered.
        """
        if cat in self.dbcats:
            return self.dbcats[cat]
        raise Exception(f"can't find dbcat with cat:{cat}")

    def _pg_admin_connection(self):
        """Open an autocommit connection to the ``postgres`` maintenance database."""
        conn = psycopg2.connect(
            dbname="postgres",
            user=self.cfg.db_login,
            password=self.cfg.db_passwd,
            host=self.cfg.db_addr,
            port=self.cfg.db_port,
        )
        conn.autocommit = True
        return conn

    def _pg_connect(self):
        """Connect to ``cfg.db_name``, creating the database first if it is missing."""
        try:
            conn = psycopg2.connect(
                dbname=self.cfg.db_name,
                user=self.cfg.db_login,
                password=self.cfg.db_passwd,
                host=self.cfg.db_addr,
                port=self.cfg.db_port,
            )
        except psycopg2.OperationalError as e:
            # Connection failures carry no SQLSTATE, so match the server message.
            if f"database \"{self.cfg.db_name}\" does not exist" not in str(e):
                raise
            admin = self._pg_admin_connection()
            try:
                with admin.cursor() as cur:
                    cur.execute(
                        sql.SQL("CREATE DATABASE {}").format(sql.Identifier(self.cfg.db_name))
                    )
            finally:
                admin.close()

            conn = psycopg2.connect(
                dbname=self.cfg.db_name,
                user=self.cfg.db_login,
                password=self.cfg.db_passwd,
                host=self.cfg.db_addr,
                port=self.cfg.db_port,
            )
            print(f"Database '{self.cfg.db_name}' created successfully.")
        conn.autocommit = True
        return conn

    def db_connection(self):
        """Open and return a new connection to the configured database.

        PostgreSQL connections are in autocommit mode, and the database is
        created on first use if missing. For SQLite, ``cfg.db_path`` is created
        if needed and ``<db_path>/<db_name>.db`` is opened. The caller owns the
        returned connection and must close it.

        Raises:
            ValueError: if ``cfg.db_type`` is not a supported backend.
        """
        if self.cfg.db_type == DBType.POSTGRESQL:
            return self._pg_connect()
        if self.cfg.db_type == DBType.SQLITE:
            # sqlite3 cannot create the file if its directory is missing.
            os.makedirs(self.cfg.db_path, exist_ok=True)
            db_file = os.path.join(self.cfg.db_path, f"{self.cfg.db_name}.db")
            return sqlite3.connect(db_file)
        raise ValueError(f"Unsupported database type: {self.cfg.db_type}")

    def db_create(self, db_name: str = "", user_name: str = "", user_password: str = ""):
        """Ensure a database (and optionally a user) exists.

        Args:
            db_name: Database to create; defaults to ``cfg.db_name`` when empty.
            user_name: PostgreSQL only: role to create (if missing) and grant
                all privileges on the database to. Requires *user_password*.
            user_password: Password for a newly created role.

        Raises:
            Exception: wrapping any ``psycopg2.Error`` (with the original as cause).
            ValueError: if ``cfg.db_type`` is not a supported backend.
        """
        if not db_name:
            db_name = self.cfg.db_name

        if self.cfg.db_type == DBType.POSTGRESQL:
            self._pg_db_create(db_name, user_name, user_password)
        elif self.cfg.db_type == DBType.SQLITE:
            self._sqlite_db_create(db_name, user_name)
        else:
            raise ValueError(f"Unsupported database type: {self.cfg.db_type}")

    def _pg_db_create(self, db_name: str, user_name: str, user_password: str):
        """PostgreSQL branch of db_create: id table, database, optional role + grant."""
        self.db_create_id()
        conn = self.db_connection()
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (db_name,))
                if cur.fetchone() is None:
                    cur.execute(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(db_name)))
                    print(f"Database '{db_name}' created successfully.")

                if user_name and user_password:
                    cur.execute("SELECT 1 FROM pg_roles WHERE rolname = %s", (user_name,))
                    if cur.fetchone() is None:
                        cur.execute(
                            sql.SQL("CREATE USER {} WITH PASSWORD %s").format(
                                sql.Identifier(user_name)
                            ),
                            (user_password,),
                        )
                        print(f"User '{user_name}' created successfully.")

                    cur.execute(
                        sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {}").format(
                            sql.Identifier(db_name), sql.Identifier(user_name)
                        )
                    )
                    print(f"Privileges granted to '{user_name}' on '{db_name}'.")
        except psycopg2.Error as e:
            raise Exception(f"Postgresql error: {e}") from e
        finally:
            conn.close()

    def _sqlite_db_create(self, db_name: str, user_name: str):
        """SQLite branch of db_create: make sure ``<db_path>/<db_name>.db`` exists."""
        os.makedirs(self.cfg.db_path, exist_ok=True)
        db_file = os.path.join(self.cfg.db_path, f"{db_name}.db")
        if not os.path.exists(db_file):
            # Opening a connection is enough to create the file.
            sqlite3.connect(db_file).close()
            print(f"SQLite database '{db_name}' created successfully at {db_file}.")
        else:
            print(f"SQLite database '{db_name}' already exists at {db_file}.")

        if user_name:
            print("Note: SQLite doesn't support user management like PostgreSQL.")

    def db_create_id(self):
        """Create the ``user_id_counters`` table used by :meth:`new_id` if missing.

        Uses psycopg2 cursor semantics; only called on the PostgreSQL path.
        """
        # psycopg2's ``with conn:`` only ends the transaction and does not close
        # the connection, so close it explicitly.
        conn = self.db_connection()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS user_id_counters (
                        user_id INTEGER PRIMARY KEY,
                        last_id_given INTEGER NOT NULL DEFAULT 0
                    )
                """)
            conn.commit()
        finally:
            conn.close()

    def new_id(self, user_id: int) -> str:
        """Allocate the next id from *user_id*'s private range.

        The id space is split into 51 equal slices, one per user_id 0..50.
        The n-th id handed to a user is ``user_id * ids_per_user + n``,
        encoded with ``int_to_id``. PostgreSQL only (``%s`` placeholders,
        ``ON CONFLICT``).

        Raises:
            ValueError: if *user_id* is outside 0..50, or the user's slice is exhausted.
        """
        if not 0 <= user_id <= 50:
            raise ValueError("User ID must be between 0 and 50")

        # NOTE(review): 60466175 == 36**5 - 1, presumably the largest 5-char
        # id int_to_id can encode — confirm against osis.id.
        max_ids = 60466175
        ids_per_user = max_ids // 51  # 51 slices keep user 50's range within max_ids

        # One atomic statement: insert the counter at 1 or increment it, but
        # only while below the slice size. A separate SELECT then UPDATE on an
        # autocommit connection would let concurrent callers read the same
        # counter and hand out duplicate ids. No row is returned once the
        # slice is exhausted.
        conn = self.db_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    INSERT INTO user_id_counters AS c (user_id, last_id_given)
                    VALUES (%s, 1)
                    ON CONFLICT (user_id) DO UPDATE
                        SET last_id_given = c.last_id_given + 1
                        WHERE c.last_id_given < %s
                    RETURNING last_id_given
                    """,
                    (user_id, ids_per_user),
                )
                row = cur.fetchone()
            conn.commit()
        finally:
            conn.close()

        if row is None:
            raise ValueError(f"No more IDs available for user {user_id}")

        return int_to_id(user_id * ids_per_user + row[0])
def db_new(
    db_type: DBType = DBType.POSTGRESQL,
    db_name: str = "main",
    db_login: str = "admin",
    db_passwd: str = "admin",
    db_addr: str = "localhost",
    db_port: int = 5432,
    db_path: str = "/tmp/db",
    reset: bool = False,
) -> "DB":
    """Build a DB from connection settings and ensure its database exists.

    ``db_path`` is used twice: as ``DBConfig.db_path`` (SQLite file location)
    and as the root directory of the DB's file-backed category store.

    Args:
        db_type: Backend to use.
        db_name, db_login, db_passwd, db_addr, db_port: Connection settings
            forwarded to DBConfig.
        db_path: Directory for the SQLite file and the category store.
        reset: Forwarded to DB; wipes the store (and drops a PostgreSQL
            database) before ``db_create`` runs.

    Returns:
        The DB handle, after ``db_create()`` has been called on it.
    """
    settings = dict(
        db_type=db_type,
        db_name=db_name,
        db_login=db_login,
        db_passwd=db_passwd,
        db_addr=db_addr,
        db_port=db_port,
        db_path=db_path,
    )
    handle = DB(cfg=DBConfig(**settings), path=db_path, reset=reset)
    handle.db_create()
    return handle