diff --git a/api/Dockerfile.dev b/api/Dockerfile.dev new file mode 100644 index 0000000..1c439b9 --- /dev/null +++ b/api/Dockerfile.dev @@ -0,0 +1,52 @@ +ARG PYTHON_VERSION_CODE=3.10 +ARG ENVIRONMENT="dev" +# ENVIRONMENT: dev or prod, refer to project.optional-dependencies in pyproject.toml + +FROM python:${PYTHON_VERSION_CODE}-bookworm as builder +ARG PYTHON_VERSION_CODE +ARG ENVIRONMENT + +WORKDIR /app +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +COPY api/pyproject.toml api/README.md ./ +COPY api/birdxplorer_api/__init__.py ./birdxplorer_api/ + +RUN if [ "${ENVIRONMENT}" = "prod" ]; then \ + apt-get update && apt-get install -y --no-install-recommends \ + postgresql-client-15 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/*; \ + fi + +RUN python -m pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir -e ".[${ENVIRONMENT}]" + +COPY ../common ./common +RUN if [ "${ENVIRONMENT}" = "dev" ]; then \ + pip install -e ./common; \ + fi + +FROM python:${PYTHON_VERSION_CODE}-slim-bookworm as runner +ARG PYTHON_VERSION_CODE +ARG ENVIRONMENT + +WORKDIR /app + +RUN if [ "${ENVIRONMENT}" = "prod" ]; then \ + apt-get update && apt-get install -y --no-install-recommends \ + libpq5 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/*; \ + fi + +RUN groupadd -r app && useradd -r -g app app +RUN chown -R app:app /app +USER app + +COPY --from=builder /usr/local/lib/python${PYTHON_VERSION_CODE}/site-packages /usr/local/lib/python${PYTHON_VERSION_CODE}/site-packages +COPY --chown=app:app api ./ +COPY ../common ./common + +ENTRYPOINT ["python", "-m", "uvicorn", "birdxplorer_api.main:app", "--host", "0.0.0.0"] diff --git a/api/birdxplorer_api/app.py b/api/birdxplorer_api/app.py index 38a6e86..6a91989 100644 --- a/api/birdxplorer_api/app.py +++ b/api/birdxplorer_api/app.py @@ -4,6 +4,7 @@ from urllib.parse import urlencode as encode_query_string from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from pydantic.alias_generators import to_snake from starlette.types import ASGIApp, Receive, Scope, Send @@ -41,6 +42,7 @@ def gen_app(settings: GlobalSettings) -> FastAPI: _ = get_logger(level=settings.logger_settings.level) storage = gen_storage(settings=settings) app = FastAPI() + app.add_middleware(CORSMiddleware, **settings.cors_settings.model_dump()) app.add_middleware(QueryStringFlatteningMiddleware) app.include_router(gen_system_router(), prefix="/api/v1/system") app.include_router(gen_data_router(storage=storage), prefix="/api/v1/data") diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index 2398df8..59ac30d 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -14,7 +14,6 @@ PostId, Topic, TopicId, - TweetId, TwitterTimestamp, UserEnrollment, ) @@ -73,7 +72,7 @@ def get_notes( created_at_from: Union[None, TwitterTimestamp] = Query(default=None), created_at_to: Union[None, TwitterTimestamp] = Query(default=None), topic_ids: Union[List[TopicId], None] = Query(default=None), - post_ids: Union[List[TweetId], None] = Query(default=None), + post_ids: Union[List[PostId], None] = Query(default=None), language: Union[LanguageIdentifier, None] = Query(default=None), ) -> NoteListResponse: return NoteListResponse( @@ -92,11 +91,14 @@ def get_notes( @router.get("/posts", response_model=PostListResponse) def get_posts( post_id: Union[List[PostId], None] = Query(default=None), + note_id: Union[List[NoteId], None] = Query(default=None), created_at_start: Union[None, TwitterTimestamp, str] = Query(default=None), created_at_end: Union[None, TwitterTimestamp, str] = Query(default=None), ) -> PostListResponse: if post_id is not None: return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id))) + if note_id is not None: + return PostListResponse(data=list(storage.get_posts_by_note_ids(note_ids=note_id))) if created_at_start is not None: if created_at_end is not None: return PostListResponse( diff --git a/api/pyproject.toml b/api/pyproject.toml index f600636..af6f324 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -25,7 +25,6 @@ classifiers = [ ] dependencies = [ - "birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@main#subdirectory=common", "fastapi", "python-dateutil", "pydantic", @@ -37,10 +36,10 @@ dependencies = [ Source = "https://github.com/codeforjapan/BirdXplorer" [tool.setuptools] -packages=["birdxplorer"] +packages=["birdxplorer_api"] [tool.setuptools.package-data] -birdxplorer = ["py.typed"] +birdxplorer_api = ["py.typed"] [project.optional-dependencies] dev=[ @@ -62,6 +61,7 @@ dev=[ "httpx", ] prod=[ + "birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@main#subdirectory=common", "psycopg2", "gunicorn", ] @@ -106,6 +106,7 @@ legacy_tox_ini = """ VIRTUALENV_PIP = 24.0 deps = -e .[dev] + -e ../common commands = black birdxplorer_api tests isort birdxplorer_api tests diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 093d718..c6b5bee 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -21,12 +21,15 @@ PostId, Topic, TopicId, - TweetId, TwitterTimestamp, UserEnrollment, XUser, ) -from birdxplorer_common.settings import GlobalSettings, PostgresStorageSettings +from birdxplorer_common.settings import ( + CORSSettings, + GlobalSettings, + PostgresStorageSettings, +) from birdxplorer_common.storage import Storage @@ -223,7 +226,7 @@ def _get_notes( created_at_from: Union[None, TwitterTimestamp] = None, created_at_to: Union[None, TwitterTimestamp] = None, topic_ids: Union[List[TopicId], None] = None, - post_ids: Union[List[TweetId], None] = None, + post_ids: Union[List[PostId], None] = None, language: Union[LanguageIdentifier, None] = None, ) -> Generator[Note, None, None]: for note in note_samples: @@ -258,6 +261,15 @@ def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]: mock.get_posts_by_ids.side_effect = _get_posts_by_ids + def _get_posts_by_note_ids(note_ids: List[NoteId]) -> Generator[Post, None, None]: + for post in post_samples: + for note in note_samples: + if note.note_id in note_ids and post.post_id == note.post_id: + yield post + break + + mock.get_posts_by_note_ids.side_effect = _get_posts_by_note_ids + def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]: for post in post_samples: if start <= post.created_at < end: @@ -294,6 +306,19 @@ def load_dotenv_fixture() -> None: load_dotenv() +@fixture +def cors_settings_factory(load_dotenv_fixture: None) -> Type[ModelFactory[CORSSettings]]: + class CORSSettingsFactory(ModelFactory[CORSSettings]): + __model__ = CORSSettings + + allow_credentials = True + allow_methods = ["*"] + allow_headers = ["*"] + allow_origins = ["*"] + + return CORSSettingsFactory + + @fixture def postgres_storage_settings_factory( load_dotenv_fixture: None, @@ -312,11 +337,13 @@ class PostgresStorageSettingsFactory(ModelFactory[PostgresStorageSettings]): @fixture def global_settings_factory( + cors_settings_factory: Type[ModelFactory[CORSSettings]], postgres_storage_settings_factory: Type[ModelFactory[PostgresStorageSettings]], ) -> Type[ModelFactory[GlobalSettings]]: class GlobalSettingsFactory(ModelFactory[GlobalSettings]): __model__ = GlobalSettings + cors_settings = cors_settings_factory.build() storage_settings = postgres_storage_settings_factory.build() return GlobalSettingsFactory @@ -325,10 +352,12 @@ class GlobalSettingsFactory(ModelFactory[GlobalSettings]): @fixture def settings_for_test( global_settings_factory: Type[ModelFactory[GlobalSettings]], + cors_settings_factory: Type[ModelFactory[CORSSettings]], postgres_storage_settings_factory: Type[ModelFactory[PostgresStorageSettings]], ) -> Generator[GlobalSettings, None, None]: yield global_settings_factory.build( - storage_settings=postgres_storage_settings_factory.build(database=TEST_DATABASE_NAME) + cors_settings=cors_settings_factory.build(allow_origins=["http://allowed.example.com"]), + storage_settings=postgres_storage_settings_factory.build(database=TEST_DATABASE_NAME), ) diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index 8bf271c..59fef49 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -39,6 +39,13 @@ def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Pos } +def test_posts_get_has_note_id_filter(client: TestClient, post_samples: List[Post], note_samples: List[Note]) -> None: + response = client.get(f"/api/v1/data/posts/?noteId={','.join([n.note_id for n in note_samples])}") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(post_samples[0].model_dump_json())]} + + def test_posts_get_has_created_at_filter_start_and_end(client: TestClient, post_samples: List[Post]) -> None: response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00&createdAtEnd=2006-7-30 23:59:59") assert response.status_code == 200 diff --git a/api/tests/routers/test_system.py b/api/tests/routers/test_system.py index 36751a7..65046e0 100644 --- a/api/tests/routers/test_system.py +++ b/api/tests/routers/test_system.py @@ -1,3 +1,4 @@ +from fastapi import status from fastapi.testclient import TestClient @@ -5,3 +6,25 @@ def test_ping(client: TestClient) -> None: response = client.get("/api/v1/system/ping") assert response.status_code == 200 assert response.json() == {"message": "pong"} + + +def test_allowed_cors(client: TestClient) -> None: + headers = { + "Access-Control-Request-Method": "GET", + "Origin": "http://allowed.example.com", + } + + response = client.options("/api/v1/system/ping", headers=headers) + assert response.status_code == status.HTTP_200_OK + assert response.headers["access-control-allow-origin"] == headers["Origin"] + + +def test_disallowed_cors(client: TestClient) -> None: + headers = { + "Origin": "http://disallowed.example.com", + "Access-Control-Request-Method": "GET", + } + + response = client.options("/api/v1/system/ping", headers=headers) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "access-control-allow-origin" not in response.headers diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index 5d19baa..3712d35 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -565,7 +565,7 @@ class NotesValidationDifficulty(str, Enum): empty = "" -class TweetId(UpToNineteenDigitsDecimalString): ... +class PostId(UpToNineteenDigitsDecimalString): ... class NoteData(BaseModel): @@ -576,7 +576,7 @@ class NoteData(BaseModel): note_id: NoteId note_author_participant_id: ParticipantId created_at_millis: TwitterTimestamp - tweet_id: TweetId + tweet_id: PostId believable: NotesBelievable misleading_other: BinaryBool misleading_factual_error: BinaryBool @@ -629,7 +629,7 @@ class SummaryString(NonEmptyTrimmedString): ... class Note(BaseModel): note_id: NoteId - post_id: TweetId + post_id: PostId language: LanguageIdentifier topics: List[Topic] summary: SummaryString @@ -650,9 +650,6 @@ class XUser(BaseModel): following_count: NonNegativeInt -class PostId(UpToNineteenDigitsDecimalString): ... - - MediaDetails: TypeAlias = List[HttpUrl] | None diff --git a/common/birdxplorer_common/settings.py b/common/birdxplorer_common/settings.py index 701a890..923ad98 100644 --- a/common/birdxplorer_common/settings.py +++ b/common/birdxplorer_common/settings.py @@ -18,7 +18,7 @@ class PostgresStorageSettings(BaseSettings): port: int = 5432 database: str = "postgres" - @computed_field # type: ignore[misc] + @computed_field # type: ignore[prop-decorator] @property def sqlalchemy_database_url(self) -> str: return PostgresDsn( @@ -27,7 +27,16 @@ def sqlalchemy_database_url(self) -> str: ).unicode_string() +class CORSSettings(BaseSettings): + allow_credentials: bool = True + allow_methods: list[str] = ["GET"] + allow_headers: list[str] = ["*"] + + allow_origins: list[str] = [] + + class GlobalSettings(BaseSettings): + cors_settings: CORSSettings = Field(default_factory=CORSSettings) model_config = SettingsConfigDict(env_file=".env") logger_settings: LoggerSettings = Field(default_factory=LoggerSettings) storage_settings: PostgresStorageSettings diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index d71a832..2b70590 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -16,7 +16,6 @@ from .models import ( TopicId, TopicLabel, - TweetId, TwitterTimestamp, UserEnrollment, UserId, @@ -39,7 +38,7 @@ class Base(DeclarativeBase): TopicLabel: JSON, NoteId: String, ParticipantId: String, - TweetId: String, + PostId: String, LanguageIdentifier: String, TwitterTimestamp: DECIMAL, SummaryString: String, @@ -49,6 +48,7 @@ class Base(DeclarativeBase): NonNegativeInt: DECIMAL, MediaDetails: JSON, BinaryBool: CHAR, + String: String, } @@ -64,7 +64,7 @@ class NoteRecord(Base): __tablename__ = "notes" note_id: Mapped[NoteId] = mapped_column(primary_key=True) - post_id: Mapped[TweetId] = mapped_column(nullable=False) + post_id: Mapped[PostId] = mapped_column(nullable=False) topics: Mapped[List[NoteTopicAssociation]] = relationship() language: Mapped[LanguageIdentifier] = mapped_column(nullable=False) summary: Mapped[SummaryString] = mapped_column(nullable=False) @@ -91,7 +91,7 @@ class XUserRecord(Base): class PostRecord(Base): __tablename__ = "posts" - post_id: Mapped[TweetId] = mapped_column(primary_key=True) + post_id: Mapped[PostId] = mapped_column(primary_key=True) user_id: Mapped[UserId] = mapped_column(ForeignKey("x_users.user_id"), nullable=False) user: Mapped[XUserRecord] = relationship() text: Mapped[SummaryString] = mapped_column(nullable=False) @@ -108,7 +108,7 @@ class RowNoteRecord(Base): note_id: Mapped[NoteId] = mapped_column(primary_key=True) note_author_participant_id: Mapped[ParticipantId] = mapped_column(nullable=False) created_at_millis: Mapped[TwitterTimestamp] = mapped_column(nullable=False) - tweet_id: Mapped[TweetId] = mapped_column(nullable=False) + tweet_id: Mapped[PostId] = mapped_column(nullable=False) believable: Mapped[BinaryBool] = mapped_column(nullable=False) misleading_other: Mapped[BinaryBool] = mapped_column(nullable=False) misleading_factual_error: Mapped[BinaryBool] = mapped_column(nullable=False) @@ -128,6 +128,46 @@ class RowNoteRecord(Base): harmful: Mapped[NotesHarmful] = mapped_column(nullable=False) validation_difficulty: Mapped[SummaryString] = mapped_column(nullable=False) summary: Mapped[SummaryString] = mapped_column(nullable=False) + row_post_id: Mapped[PostId] = mapped_column(ForeignKey("row_posts.post_id"), nullable=True) + row_post: Mapped["RowPostRecord"] = relationship("RowPostRecord", back_populates="row_notes") + + +class RowPostRecord(Base): + __tablename__ = "row_posts" + + post_id: Mapped[PostId] = mapped_column(primary_key=True) + author_id: Mapped[UserId] = mapped_column(ForeignKey("row_users.user_id"), nullable=False) + text: Mapped[SummaryString] = mapped_column(nullable=False) + media_type: Mapped[String] = mapped_column(nullable=True) + media_url: Mapped[String] = mapped_column(nullable=True) + created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False) + like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + bookmark_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + quote_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + reply_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + lang: Mapped[String] = mapped_column() + row_notes: Mapped["RowNoteRecord"] = relationship("RowNoteRecord", back_populates="row_post") + user: Mapped["RowUserRecord"] = relationship("RowUserRecord", back_populates="row_post") + + +class RowUserRecord(Base): + __tablename__ = "row_users" + + user_id: Mapped[UserId] = mapped_column(primary_key=True) + name: Mapped[UserName] = mapped_column(nullable=False) + user_name: Mapped[UserName] = mapped_column(nullable=False) + description: Mapped[SummaryString] = mapped_column(nullable=False) + profile_image_url: Mapped[String] = mapped_column(nullable=False) + followers_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + tweet_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + verified: Mapped[BinaryBool] = mapped_column(nullable=False) + verified_type: Mapped[String] = mapped_column(nullable=False) + location: Mapped[String] = mapped_column(nullable=False) + url: Mapped[String] = mapped_column(nullable=False) + row_post: Mapped["RowPostRecord"] = relationship("RowPostRecord", back_populates="user") class Storage: @@ -183,7 +223,7 @@ def get_notes( created_at_from: Union[None, TwitterTimestamp] = None, created_at_to: Union[None, TwitterTimestamp] = None, topic_ids: Union[List[TopicId], None] = None, - post_ids: Union[List[TweetId], None] = None, + post_ids: Union[List[PostId], None] = None, language: Union[LanguageIdentifier, None] = None, ) -> Generator[NoteModel, None, None]: with Session(self.engine) as sess: @@ -255,6 +295,16 @@ def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostMo for post_record in sess.query(PostRecord).filter(PostRecord.created_at < end).all(): yield self._post_record_to_model(post_record) + def get_posts_by_note_ids(self, note_ids: List[NoteId]) -> Generator[PostModel, None, None]: + query = ( + select(PostRecord) + .join(NoteRecord, NoteRecord.post_id == PostRecord.post_id) + .where(NoteRecord.note_id.in_(note_ids)) + ) + with Session(self.engine) as sess: + for post_record in sess.execute(query).scalars().all(): + yield self._post_record_to_model(post_record) + def gen_storage(settings: GlobalSettings) -> Storage: engine = create_engine(settings.storage_settings.sqlalchemy_database_url) diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index 3b74975..ff898e9 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -10,7 +10,6 @@ PostId, Topic, TopicId, - TweetId, TwitterTimestamp, ) from birdxplorer_common.storage import NoteRecord, PostRecord, Storage, TopicRecord @@ -208,8 +207,8 @@ def test_get_notes_by_post_ids( ) -> None: storage = Storage(engine=engine_for_test) post_ids = [ - TweetId.from_str("2234567890123456781"), - TweetId.from_str("2234567890123456782"), + PostId.from_str("2234567890123456781"), + PostId.from_str("2234567890123456782"), ] expected = [note for note in note_samples if note.post_id in post_ids] actual = list(storage.get_notes(post_ids=post_ids)) @@ -222,7 +221,7 @@ def test_get_notes_by_post_ids_empty( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - post_ids: List[TweetId] = [] + post_ids: List[PostId] = [] expected: List[Note] = [] actual = list(storage.get_notes(post_ids=post_ids)) assert expected == actual diff --git a/compose.yml b/compose.yml index 7d36098..df2676f 100644 --- a/compose.yml +++ b/compose.yml @@ -1,4 +1,4 @@ -version: '3.1' +version: "3.1" services: db: @@ -14,7 +14,7 @@ services: timeout: 5s retries: 5 ports: - - '5432:5432' + - "5432:5432" volumes: - postgres_data:/var/lib/postgresql/data app: @@ -24,17 +24,17 @@ services: build: args: - ENVIRONMENT=dev - context: ./api - dockerfile: Dockerfile + context: ./ + dockerfile: ./api/Dockerfile.dev env_file: - .env ports: - - '8000:8000' + - "8000:8000" develop: watch: - action: rebuild - path: ./api - target: /app/api + path: ./ + target: /app migrate: depends_on: db: @@ -42,13 +42,12 @@ services: build: args: - ENVIRONMENT=dev - context: ./migrate - dockerfile: Dockerfile + context: ./ + dockerfile: ./migrate/Dockerfile.dev environment: - WAIT_HOSTS=db:5432 env_file: - .env - volumes: postgres_data: diff --git a/etl/.env.example b/etl/.env.example new file mode 100644 index 0000000..3d2cf0e --- /dev/null +++ b/etl/.env.example @@ -0,0 +1 @@ +X_BEARER_TOKEN= \ No newline at end of file diff --git a/etl/README.md b/etl/README.md new file mode 100644 index 0000000..ce6fca3 --- /dev/null +++ b/etl/README.md @@ -0,0 +1,24 @@ +# BirdXplorer ETL + +This is an ETL to get and process data of community notes and X posts, in order to prepare API. + +## Setup development environment + +### Set setting variables + +``` +$ cp .env.example .env +``` + +| Key | Description | +| -------------- | ----------------------- | +| X_BEARER_TOKEN | API key for Twitter API | + +### Run + +``` +$ pwd +/your_dir/BirdXplorer/etl +$ pip install . +$ python src/birdxplorer_etl/main.py +``` diff --git a/etl/pyproject.toml b/etl/pyproject.toml index 618089c..cc45a14 100644 --- a/etl/pyproject.toml +++ b/etl/pyproject.toml @@ -24,7 +24,6 @@ classifiers = [ ] dependencies = [ - "birdxplorer_common[dev] @ git+https://github.com/codeforjapan/BirdXplorer.git@etl/main#subdirectory=common", "pandas", "sqlalchemy", "requests", @@ -48,7 +47,11 @@ dev=[ "flake8", "pyproject-flake8", "pytest", - 'prefect' + 'prefect', + "isort", +] +prod=[ + "birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@main#subdirectory=common", ] [tool.pytest.ini_options] @@ -91,11 +94,10 @@ legacy_tox_ini = """ DATA_DIR = {env:BX_DATA_DIR} deps = -e .[dev] + -e ../common commands = - black birdxplorer_etl tests - isort birdxplorer_etl tests + black src/birdxplorer_etl tests + isort src/birdxplorer_etl tests pytest - pflake8 birdxplorer_etl/ tests/ - mypy birdxplorer_etl --strict - mypy tests --strict + pflake8 src/birdxplorer_etl/ tests/ """ diff --git a/etl/src/birdxplorer_etl/extract.py b/etl/src/birdxplorer_etl/extract.py index a6f3590..81abf65 100644 --- a/etl/src/birdxplorer_etl/extract.py +++ b/etl/src/birdxplorer_etl/extract.py @@ -1,13 +1,11 @@ import csv from datetime import datetime, timedelta - import requests import stringcase from prefect import get_run_logger from sqlalchemy.orm import Session - -from birdxplorer_common.storage import RowNoteRecord - +from lib.x.postlookup import lookup +from birdxplorer_common.storage import RowNoteRecord, RowPostRecord, RowUserRecord import settings @@ -15,6 +13,10 @@ def extract_data(db: Session): logger = get_run_logger() logger.info("Downloading community notes data") + # get columns of post table + columns = db.query(RowUserRecord).statement.columns.keys() + logger.info(columns) + # Noteデータを取得してSQLiteに保存 date = datetime.now() latest_note = db.query(RowNoteRecord).order_by(RowNoteRecord.created_at_millis.desc()).first() @@ -36,20 +38,95 @@ def extract_data(db: Session): reader = csv.DictReader(tsv_data, delimiter="\t") reader.fieldnames = [stringcase.snakecase(field) for field in reader.fieldnames] + rows_to_add = [] for row in reader: - db.add(RowNoteRecord(**row)) + if db.query(RowNoteRecord).filter(RowNoteRecord.note_id == row["note_id"]).first(): + continue + rows_to_add.append(RowNoteRecord(**row)) + db.bulk_save_objects(rows_to_add) + break date = date - timedelta(days=1) db.commit() - row1 = db.query(RowNoteRecord).first() - logger.info(row1) + # Noteに紐づくtweetデータを取得 + postExtract_targetNotes = ( + db.query(RowNoteRecord) + .filter(RowNoteRecord.tweet_id != None) + .filter(RowNoteRecord.created_at_millis >= settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND) + .filter(RowNoteRecord.created_at_millis <= settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND) + .all() + ) + logger.info(len(postExtract_targetNotes)) + for note in postExtract_targetNotes: + tweet_id = note.tweet_id + + is_tweetExist = db.query(RowPostRecord).filter(RowPostRecord.post_id == str(tweet_id)).first() + if is_tweetExist is not None: + logger.info(f"tweet_id {tweet_id} is already exist") + note.row_post_id = tweet_id + continue + + logger.info(tweet_id) + post = lookup(tweet_id) + + if post == None or "data" not in post: + continue + + created_at = datetime.strptime(post["data"]["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ") + created_at_millis = int(created_at.timestamp() * 1000) + + is_userExist = db.query(RowUserRecord).filter(RowUserRecord.user_id == post["data"]["author_id"]).first() + logger.info(is_userExist) + if is_userExist is None: + user_data = ( + post["includes"]["users"][0] + if "includes" in post and "users" in post["includes"] and len(post["includes"]["users"]) > 0 + else {} + ) + db_user = RowUserRecord( + user_id=post["data"]["author_id"], + name=user_data.get("name"), + user_name=user_data.get("username"), + description=user_data.get("description"), + profile_image_url=user_data.get("profile_image_url"), + followers_count=user_data.get("public_metrics", {}).get("followers_count"), + following_count=user_data.get("public_metrics", {}).get("following_count"), + tweet_count=user_data.get("public_metrics", {}).get("tweet_count"), + verified=user_data.get("verified", False), + verified_type=user_data.get("verified_type", ""), + location=user_data.get("location", ""), + url=user_data.get("url", ""), + ) + db.add(db_user) + + media_data = ( + post["includes"]["media"][0] + if "includes" in post and "media" in post["includes"] and len(post["includes"]["media"]) > 0 + else {} + ) + db_post = RowPostRecord( + post_id=post["data"]["id"], + author_id=post["data"]["author_id"], + text=post["data"]["text"], + media_type=media_data.get("type", ""), + media_url=media_data.get("url", ""), + created_at=created_at_millis, + like_count=post["data"]["public_metrics"]["like_count"], + repost_count=post["data"]["public_metrics"]["retweet_count"], + bookmark_count=post["data"]["public_metrics"]["bookmark_count"], + impression_count=post["data"]["public_metrics"]["impression_count"], + quote_count=post["data"]["public_metrics"]["quote_count"], + reply_count=post["data"]["public_metrics"]["reply_count"], + lang=post["data"]["lang"], + ) + db.add(db_post) + note.row_post_id = tweet_id + db.commit() + continue + + # select note from db, get relation tweet and user data + note = db.query(RowNoteRecord).filter(RowNoteRecord.tweet_id == "1797617478950170784").first() - # # Noteに紐づくtweetデータを取得 - # for note in notes_data: - # note_created_at = note.created_at_millis.serialize() - # if note_created_at >= settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND and note_created_at <= settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND: # noqa E501 - # tweet_id = note.tweet_id.serialize() - # continue return diff --git a/etl/src/birdxplorer_etl/lib/sqlite/init.py b/etl/src/birdxplorer_etl/lib/sqlite/init.py index 8dd8477..c167352 100644 --- a/etl/src/birdxplorer_etl/lib/sqlite/init.py +++ b/etl/src/birdxplorer_etl/lib/sqlite/init.py @@ -5,7 +5,7 @@ from sqlalchemy import create_engine, inspect from sqlalchemy.orm import sessionmaker -from birdxplorer_common.storage import RowNoteRecord +from birdxplorer_common.storage import RowNoteRecord, RowPostRecord, RowUserRecord def init_db(): @@ -18,9 +18,15 @@ def init_db(): # 一時データベースのテーブル作成する # ToDo: noteテーブル以外に必要なものを追加 - if not inspect(engine).has_table("note"): + if not inspect(engine).has_table("row_notes"): logger.info("Creating table note") RowNoteRecord.metadata.create_all(engine) + if not inspect(engine).has_table("row_posts"): + logger.info("Creating table post") + RowPostRecord.metadata.create_all(engine) + if not inspect(engine).has_table("row_users"): + logger.info("Creating table user") + RowUserRecord.metadata.create_all(engine) Session = sessionmaker(bind=engine) diff --git a/etl/src/birdxplorer_etl/lib/x/postlookup.py b/etl/src/birdxplorer_etl/lib/x/postlookup.py new file mode 100644 index 0000000..1410ceb --- /dev/null +++ b/etl/src/birdxplorer_etl/lib/x/postlookup.py @@ -0,0 +1,60 @@ +import requests +import settings +from prefect import get_run_logger +import time + + +def create_url(id): + expansions = "expansions=attachments.poll_ids,attachments.media_keys,author_id,edit_history_tweet_ids,entities.mentions.username,geo.place_id,in_reply_to_user_id,referenced_tweets.id,referenced_tweets.id.author_id" + tweet_fields = "tweet.fields=attachments,author_id,context_annotations,conversation_id,created_at,edit_controls,entities,geo,id,in_reply_to_user_id,lang,public_metrics,possibly_sensitive,referenced_tweets,reply_settings,source,text,withheld" + media_fields = ( + "media.fields=duration_ms,height,media_key,preview_image_url,type,url,width,public_metrics,alt_text,variants" + ) + place_fields = "place.fields=contained_within,country,country_code,full_name,geo,id,name,place_type" + user_fields = "user.fields=created_at,description,entities,id,location,most_recent_tweet_id,name,pinned_tweet_id,profile_image_url,protected,public_metrics,url,username,verified,verified_type,withheld" + + url = "https://api.twitter.com/2/tweets/{}?{}&{}&{}&{}&{}".format( + id, tweet_fields, expansions, media_fields, place_fields, user_fields + ) + return url + + +def bearer_oauth(r): + """ + Method required by bearer token authentication. + """ + + r.headers["Authorization"] = f"Bearer {settings.X_BEARER_TOKEN}" + r.headers["User-Agent"] = "v2TweetLookupPython" + return r + + +def connect_to_endpoint(url): + logger = get_run_logger() + response = requests.request("GET", url, auth=bearer_oauth) + if response.status_code == 429: + limit = response.headers["x-rate-limit-reset"] + logger.info("Waiting for rate limit reset...") + time.sleep(int(limit) - int(time.time()) + 1) + data = connect_to_endpoint(url) + return data + elif response.status_code != 200: + raise Exception("Request returned an error: {} {}".format(response.status_code, response.text)) + return response.json() + + +def check_existence(id): + url = "https://publish.twitter.com/oembed?url=https://x.com/CommunityNotes/status/{}&partner=&hide_thread=false".format( + id + ) + status = requests.get(url).status_code + return status == 200 + + +def lookup(id): + isExist = check_existence(id) + if not isExist: + return None + url = create_url(id) + json_response = connect_to_endpoint(url) + return json_response diff --git a/etl/src/birdxplorer_etl/main.py b/etl/src/birdxplorer_etl/main.py index 4ec2db6..87d131d 100644 --- a/etl/src/birdxplorer_etl/main.py +++ b/etl/src/birdxplorer_etl/main.py @@ -19,8 +19,8 @@ def extract(db: Session): @task -def transform(): - return transform_data() +def transform(db: Session): + transform_data(db) @task @@ -32,8 +32,9 @@ def load(): def run_etl(): i = initialize() _ = extract(i["db"]) - _ = transform() + _ = transform(i["db"]) _ = load() + if __name__ == "__main__": run_etl() diff --git a/etl/src/birdxplorer_etl/settings.py b/etl/src/birdxplorer_etl/settings.py index 9541f55..7a12c70 100644 --- a/etl/src/birdxplorer_etl/settings.py +++ b/etl/src/birdxplorer_etl/settings.py @@ -1,5 +1,14 @@ -TARGET_TWITTER_POST_START_UNIX_MILLISECOND = 1577836800000 -TARGET_TWITTER_POST_END_UNIX_MILLISECOND = 1577836799000 +import os +from dotenv import load_dotenv + +load_dotenv() + +TARGET_TWITTER_POST_START_UNIX_MILLISECOND = int( + os.getenv("TARGET_TWITTER_POST_START_UNIX_MILLISECOND", "1717729500000") +) +TARGET_TWITTER_POST_END_UNIX_MILLISECOND = int(os.getenv("TARGET_TWITTER_POST_END_UNIX_MILLISECOND", "1717729610000")) # Extractで何日前のデータを最新と定義するか。開発中は3日前が楽。 -COMMUNITY_NOTE_DAYS_AGO = 3 +COMMUNITY_NOTE_DAYS_AGO = int(os.getenv("COMMUNITY_NOTE_DAYS_AGO", "3")) + +X_BEARER_TOKEN = os.getenv("X_BEARER_TOKEN") diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py index da2d7c8..40cd00b 100644 --- a/etl/src/birdxplorer_etl/transform.py +++ b/etl/src/birdxplorer_etl/transform.py @@ -1,6 +1,110 @@ import logging +from sqlalchemy import select, func +from sqlalchemy.orm import Session +from birdxplorer_common.storage import RowNoteRecord, RowPostRecord, RowUserRecord +import csv +import os -def transform_data(): +def transform_data(db: Session): logging.info("Transforming data") + + if not os.path.exists("./data/transformed"): + os.makedirs("./data/transformed") + + # Transform row note data and generate note.csv + if os.path.exists("./data/transformed/note.csv"): + os.remove("./data/transformed/note.csv") + + offset = 0 + limit = 1000 + + num_of_notes = db.query(func.count(RowNoteRecord.note_id)).scalar() + + while offset < num_of_notes: + notes = db.execute( + select( + RowNoteRecord.note_id, RowNoteRecord.row_post_id, RowNoteRecord.summary, RowNoteRecord.created_at_millis + ) + .limit(limit) + .offset(offset) + ) + + with open("./data/transformed/note.csv", "a") as file: + writer = csv.writer(file) + writer.writerow(["note_id", "post_id", "summary", "created_at"]) + for note in notes: + writer.writerow(note) + offset += limit + + # Transform row post data and generate post.csv + if os.path.exists("./data/transformed/post.csv"): + os.remove("./data/transformed/post.csv") + + offset = 0 + limit = 1000 + + num_of_posts = db.query(func.count(RowPostRecord.post_id)).scalar() + + while offset < num_of_posts: + posts = db.execute( + select( + RowPostRecord.post_id, + RowPostRecord.author_id.label("user_id"), + RowPostRecord.text, + RowPostRecord.created_at, + RowPostRecord.like_count, + RowPostRecord.repost_count, + RowPostRecord.impression_count, + ) + .limit(limit) + .offset(offset) + ) + + with open("./data/transformed/post.csv", "a") as file: + writer = csv.writer(file) + writer.writerow( + ["post_id", "user_id", "text", "created_at", "like_count", "repost_count", "impression_count"] + ) + for post in posts: + writer.writerow(post) + offset += limit + + # Transform row user data and generate user.csv + if os.path.exists("./data/transformed/user.csv"): + os.remove("./data/transformed/user.csv") + + offset = 0 + limit = 1000 + + num_of_users = db.query(func.count(RowUserRecord.user_id)).scalar() + + while offset < num_of_users: + users = db.execute( + select( + RowUserRecord.user_id, + RowUserRecord.user_name.label("name"), + RowUserRecord.profile_image_url.label("profile_image"), + RowUserRecord.followers_count, + RowUserRecord.following_count, + ) + .limit(limit) + .offset(offset) + ) + + with open("./data/transformed/user.csv", "a") as file: + writer = csv.writer(file) + writer.writerow( + [ + "user_id", + "name", + "profile_image", + "followers_count", + "following_count", + ] + ) + for user in users: + writer.writerow(user) + offset += limit + return diff --git a/migrate/Dockerfile.dev b/migrate/Dockerfile.dev new file mode 100644 index 0000000..3057101 --- /dev/null +++ b/migrate/Dockerfile.dev @@ -0,0 +1,52 @@ +ARG PYTHON_VERSION_CODE=3.10 +ARG ENVIRONMENT="dev" +# ENVIRONMENT: dev or prod, refer to project.optional-dependencies in pyproject.toml + +FROM python:${PYTHON_VERSION_CODE}-bookworm as builder +ARG PYTHON_VERSION_CODE +ARG ENVIRONMENT + +WORKDIR /app +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +COPY migrate/pyproject.toml migrate/README.md ./ +COPY migrate/birdxplorer_migration/__init__.py ./birdxplorer_migration/ + +RUN if [ "${ENVIRONMENT}" = "prod" ]; then \ + apt-get update && apt-get install -y --no-install-recommends \ + postgresql-client-15 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/*; \ + fi + +RUN python -m pip install --no-cache-dir --upgrade pip && \ + pip install --no-cache-dir -e ".[${ENVIRONMENT}]" + +COPY ../common ./common +RUN if [ "${ENVIRONMENT}" = "dev" ]; then \ + pip install -e ./common; \ + fi + +FROM python:${PYTHON_VERSION_CODE}-slim-bookworm as runner +ARG PYTHON_VERSION_CODE +ARG ENVIRONMENT + +WORKDIR /app + +RUN if [ "${ENVIRONMENT}" = "prod" ]; then \ + apt-get update && apt-get install -y --no-install-recommends \ + libpq5 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/*; \ + fi + +RUN groupadd -r app && useradd -r -g app app +RUN chown -R app:app /app +USER app + +COPY --from=builder /usr/local/lib/python${PYTHON_VERSION_CODE}/site-packages /usr/local/lib/python${PYTHON_VERSION_CODE}/site-packages +COPY --chown=app:app migrate ./ +COPY ../common ./common + +ENTRYPOINT ["python", "birdxplorer_migration/scripts/migrate_all.py", "birdxplorer_migration/data/appv1/"]