добавляет dag
This commit is contained in:
84
dags/task_01/tasks.py
Normal file
84
dags/task_01/tasks.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from random import randint
|
||||
|
||||
import pandas as pd
|
||||
from airflow import AirflowException
|
||||
from airflow.models import Variable
|
||||
from task_01.target_db.repositories import TargetDBRepo
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserActionTransfer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_db: TargetDBRepo,
|
||||
):
|
||||
self.target_db = target_db
|
||||
self.log = logging.getLogger(__name__)
|
||||
|
||||
def generate_sample_data(self):
|
||||
"""Генерация тестовых данных и сохранение в CSV файл"""
|
||||
actions = self.get_fake_user_action_data()
|
||||
csv_file_path = self.get_csv_file_path()
|
||||
actions.to_csv(csv_file_path, index=False)
|
||||
|
||||
# Сохраняем путь к файлу в переменную DAG
|
||||
Variable.set("user_action_data_path", csv_file_path)
|
||||
|
||||
logging.info("Sample data generated and saved to: %s", csv_file_path)
|
||||
logging.info(f"Data preview: %s", actions.head())
|
||||
|
||||
def load_data_to_neo4j(self) -> None:
|
||||
csv_file_path = Variable.get("user_action_data_path")
|
||||
|
||||
if not csv_file_path or not os.path.exists(csv_file_path):
|
||||
raise AirflowException("CSV file not found: %s", csv_file_path)
|
||||
|
||||
# Чтение CSV файла
|
||||
user_actions = pd.read_csv(csv_file_path)
|
||||
logging.info("Loaded CSV data with %s rows", len(user_actions))
|
||||
|
||||
self.target_db.save_users(user_actions)
|
||||
total_rows = self.target_db.get_number_of_users()
|
||||
logging.info("Total rows: %s", total_rows)
|
||||
|
||||
# Очистка временного файла
|
||||
os.remove(csv_file_path)
|
||||
logging.info("Temporary CSV file cleaned up")
|
||||
|
||||
def check_neo4j_connection(self):
|
||||
"""Проверка соединения с БД"""
|
||||
try:
|
||||
result = self.target_db.check_connection()
|
||||
log.info(f"Neo4j message: {result}")
|
||||
log.info("Neo4j connection is healthy")
|
||||
except Exception as e:
|
||||
log.error(f"Neo4j connection failed: {e}")
|
||||
raise
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_csv_file_path() -> str:
|
||||
temp_dir = tempfile.gettempdir()
|
||||
return os.path.join(temp_dir, "user_action_data.csv")
|
||||
|
||||
@staticmethod
|
||||
def get_fake_user_action_data() -> pd.DataFrame:
|
||||
actions = ["login", "purchase", "view", "logout", "search"]
|
||||
ids = list(range(1, 10001))
|
||||
action = []
|
||||
timestamp = []
|
||||
for _ in ids:
|
||||
action.append(actions[randint(0, len(actions) - 1)])
|
||||
timestamp.append(datetime.now())
|
||||
sample_data = {
|
||||
"user_id": ids,
|
||||
"action": action,
|
||||
"timestamp": timestamp
|
||||
}
|
||||
return pd.DataFrame(sample_data)
|
||||
Reference in New Issue
Block a user