Files
airflow_neo4j/dags/task_01/tasks.py
2025-11-04 00:17:47 +03:00

85 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import os
import tempfile
from datetime import datetime
from random import randint
import pandas as pd
from airflow import AirflowException
from airflow.models import Variable
from task_01.target_db.repositories import TargetDBRepo
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class UserActionTransfer:
    """Generate sample user-action data and load it into the target database.

    The generate and load steps hand the CSV location to each other through
    the Airflow Variable named by ``_CSV_PATH_VARIABLE``.
    """

    # Key of the Airflow Variable holding the path of the intermediate CSV.
    _CSV_PATH_VARIABLE = "user_action_data_path"
    # Action names the fake-data generator picks from.
    _ACTIONS = ("login", "purchase", "view", "logout", "search")

    def __init__(
        self,
        target_db: "TargetDBRepo",
    ):
        """Store the repository used to talk to the target database.

        Args:
            target_db: Repository object; this class calls its ``save_users``,
                ``get_number_of_users`` and ``check_connection`` methods.
        """
        self.target_db = target_db
        self.log = logging.getLogger(__name__)

    def generate_sample_data(self) -> None:
        """Generate test data, save it to a CSV file and publish the file path."""
        actions = self.get_fake_user_action_data()
        csv_file_path = self.get_csv_file_path()
        actions.to_csv(csv_file_path, index=False)

        # Publish the file path through an Airflow Variable so that the load
        # step can find the file.
        Variable.set(self._CSV_PATH_VARIABLE, csv_file_path)

        self.log.info("Sample data generated and saved to: %s", csv_file_path)
        self.log.info("Data preview: %s", actions.head())

    def load_data_to_neo4j(self) -> None:
        """Read the generated CSV, hand it to the target DB and delete the file.

        Raises:
            AirflowException: If the stored path is empty or the file does not
                exist on this host.
        """
        csv_file_path = Variable.get(self._CSV_PATH_VARIABLE)
        if not csv_file_path or not os.path.exists(csv_file_path):
            # Exceptions do not interpolate "%s" arguments the way logging
            # calls do, so the message has to be formatted explicitly.
            raise AirflowException(f"CSV file not found: {csv_file_path}")

        # Read the CSV file.
        user_actions = pd.read_csv(csv_file_path)
        self.log.info("Loaded CSV data with %s rows", len(user_actions))

        self.target_db.save_users(user_actions)
        total_rows = self.target_db.get_number_of_users()
        self.log.info("Total rows: %s", total_rows)

        # Clean up the temporary file. This is only reached when the steps
        # above succeeded, so the file is still there after a failure.
        os.remove(csv_file_path)
        self.log.info("Temporary CSV file cleaned up")

    def check_neo4j_connection(self):
        """Check the connection to the database.

        Returns:
            Whatever ``target_db.check_connection()`` returns.

        Raises:
            Exception: Any error raised by the connection check is logged and
                re-raised unchanged.
        """
        try:
            result = self.target_db.check_connection()
        except Exception as e:
            self.log.error("Neo4j connection failed: %s", e)
            raise
        self.log.info("Neo4j message: %s", result)
        self.log.info("Neo4j connection is healthy")
        return result

    @staticmethod
    def get_csv_file_path() -> str:
        """Return the path of the intermediate CSV inside the system temp dir."""
        # NOTE(review): the file name is fixed, so every run on the same host
        # resolves to the same file — confirm that overlapping runs are not
        # expected.
        return os.path.join(tempfile.gettempdir(), "user_action_data.csv")

    @staticmethod
    def get_fake_user_action_data(num_users: int = 10000) -> pd.DataFrame:
        """Build a DataFrame of random user actions.

        Args:
            num_users: Number of rows to generate. User ids run from 1 to
                ``num_users`` inclusive; the default keeps the original size.

        Returns:
            A DataFrame with columns ``user_id``, ``action`` and ``timestamp``,
            where each action is drawn at random from a fixed list and each
            timestamp is the naive local time at which the row was generated.
        """
        choices = UserActionTransfer._ACTIONS
        last_index = len(choices) - 1
        user_ids = list(range(1, num_users + 1))
        return pd.DataFrame(
            {
                "user_id": user_ids,
                "action": [choices[randint(0, last_index)] for _ in user_ids],
                "timestamp": [datetime.now() for _ in user_ids],
            }
        )