Files
airflow_neo4j/dags/task_01/tasks.py

101 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import os
import tempfile
from datetime import datetime
from random import randint
import pandas as pd
from airflow import AirflowException
from airflow.models import Variable
from task_01.target_db.repositories import TargetDBRepo
log = logging.getLogger(__name__)
class UserActionTransfer:
"""
Класс для работы с переносом пользовательских действий в Neo4j.
"""
def __init__(
self,
target_db: TargetDBRepo,
):
self.target_db = target_db
self.log = logging.getLogger(__name__)
def generate_sample_data(self):
"""
Генерирует тестовые данные действий пользователей и сохраняет их в CSV-файл.
"""
actions = self.get_fake_user_action_data()
csv_file_path = self.get_csv_file_path()
actions.to_csv(csv_file_path, index=False)
# Сохраняем путь к файлу в переменную DAG
Variable.set("user_action_data_path", csv_file_path)
logging.info("Sample data generated and saved to: %s", csv_file_path)
logging.info(f"Data preview: %s", actions.head())
def load_data_to_neo4j(self) -> None:
"""
Загружает данные из CSV-файла во временной директории в базу данных Neo4j.
"""
csv_file_path = Variable.get("user_action_data_path")
if not csv_file_path or not os.path.exists(csv_file_path):
raise AirflowException("CSV file not found: %s", csv_file_path)
# Чтение CSV файла
user_actions = pd.read_csv(csv_file_path)
logging.info("Loaded CSV data with %s rows", len(user_actions))
self.target_db.save_users(user_actions)
total_rows = self.target_db.get_number_of_users()
logging.info("Total rows: %s", total_rows)
# Очистка временного файла
os.remove(csv_file_path)
logging.info("Temporary CSV file cleaned up")
def check_neo4j_connection(self):
"""
Проверяет соединение с базой данных Neo4j.
"""
try:
result = self.target_db.check_connection()
log.info(f"Neo4j message: {result}")
log.info("Neo4j connection is healthy")
except Exception as e:
log.error(f"Neo4j connection failed: {e}")
raise
return result
@staticmethod
def get_csv_file_path() -> str:
"""
Возвращает путь к CSV-файлу во временной директории для пользовательских действий.
"""
temp_dir = tempfile.gettempdir()
return os.path.join(temp_dir, "user_action_data.csv")
@staticmethod
def get_fake_user_action_data() -> pd.DataFrame:
"""
Генерирует случайные тестовые данные действий пользователей.
"""
actions = ["login", "purchase", "view", "logout", "search"]
ids = list(range(1, 12491))
action = []
timestamp = []
for _ in ids:
action.append(actions[randint(0, len(actions) - 1)])
timestamp.append(datetime.now())
sample_data = {
"user_id": ids,
"action": action,
"timestamp": timestamp
}
return pd.DataFrame(sample_data)