Files
airflow_neo4j/dags/task_01/target_db/repositories.py
2025-11-04 00:17:47 +03:00

42 lines
1.2 KiB
Python

import logging
from neo4j import Driver
from pandas import DataFrame
# Default number of DataFrame rows sent to Neo4j in a single query.
BATCH_SIZE = 1000


class TargetDBRepo:
    """Repository that writes and counts ``User`` nodes through a Neo4j driver."""

    def __init__(self, driver: Driver):
        """Store the driver and set up a logger named after this module.

        Args:
            driver: Neo4j driver; a new session is opened from it for every
                repository operation.
        """
        self.driver = driver
        self.log = logging.getLogger(__name__)

    def save_users(self, users: DataFrame, batch_size: int = BATCH_SIZE) -> None:
        """Create one ``User`` node per row of *users*, sending rows in batches.

        Every row is converted to a dict and passed to the query as an element
        of ``$rows``; the node's ``user_id``, ``action`` and ``timestamp``
        properties are read from the row's entries of the same names.

        The query uses ``CREATE``, so it does not deduplicate: saving the same
        rows twice creates the nodes twice.

        Args:
            users: Rows to store. An empty frame results in no queries.
            batch_size: Maximum number of rows per query. Defaults to
                ``BATCH_SIZE``.

        Raises:
            ValueError: If *batch_size* is not a positive integer.
        """
        if batch_size <= 0:
            # A zero step makes range() fail and a negative one would silently
            # save nothing, so reject both up front with a clear message.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        query = """
            UNWIND $rows AS row
            CREATE (u:User {
                user_id: row.user_id,
                action: row.action,
                timestamp: row.timestamp
            })
        """
        total = len(users)
        with self.driver.session() as session:
            for start in range(0, total, batch_size):
                batch = users.iloc[start:start + batch_size]
                records = batch.to_dict(orient="records")
                # Consume the result so the batch has finished (and any
                # server-side error has been raised) before progress is logged.
                session.run(query, {"rows": records}).consume()
                # The last batch may be shorter than batch_size, so report the
                # number of rows actually sent rather than start + batch_size.
                self.log.info("rows saved %s", start + len(records))

    def get_number_of_users(self) -> int:
        """Return the number of ``User`` nodes currently in the database."""
        with self.driver.session() as session:
            result = session.run(
                "MATCH (u:User) RETURN count(u) as user_count")
            return result.single()["user_count"]

    def check_connection(self) -> str:
        """Run a trivial query and return the message it produces.

        Returns:
            The ``message`` field of the single returned record, i.e. the
            literal ``"Connection successful"`` echoed back by the server.
        """
        with self.driver.session() as session:
            result = session.run('RETURN "Connection successful" AS message')
            return result.single()["message"]