diff --git a/posthog/api/query.py b/posthog/api/query.py index edabae8beb7a8c..861afd3510fc99 100644 --- a/posthog/api/query.py +++ b/posthog/api/query.py @@ -46,6 +46,7 @@ QueryResponseAlternative, QueryStatusResponse, ) +from typing import cast class QueryViewSet(TeamAndOrgViewSetMixin, PydanticModelMixin, viewsets.ViewSet): @@ -204,7 +205,7 @@ async def query_async(request: Request, *args, **kwargs) -> HttpResponse: response = await sync_to_async(view)(request) if response.status_code != 200: - return HttpResponse(response.rendered_content, status=response.status_code) + return response response.render() data = json.loads(response.rendered_content) @@ -215,7 +216,8 @@ async def query_async(request: Request, *args, **kwargs) -> HttpResponse: # For async responses, poll until complete or timeout async def check_query_status(): - manager = QueryStatusManager(data["query_id"], kwargs.get("project_id")) + assert kwargs.get("project_id") is not None + manager = QueryStatusManager(data["query_id"], cast(int, kwargs["project_id"])) start_time = time.time() sleep_time = 0.1 # Start with 100ms max_sleep_time = 1.0 # Don't wait more than 1 second between checks diff --git a/posthog/api/routing.py b/posthog/api/routing.py index dfd627e8667a0c..0262d1ced8812d 100644 --- a/posthog/api/routing.py +++ b/posthog/api/routing.py @@ -252,37 +252,41 @@ def team_id(self) -> int: @cached_property def team(self) -> Team: if team_from_token := self._get_team_from_request(): - return team_from_token - - if self._is_project_view: - return Team.objects.get( + team = team_from_token + elif self._is_project_view: + team = Team.objects.get( id=self.project_id # KLUDGE: This is just for the period of transition to project environments ) - - if self.param_derived_from_user_current_team == "team_id": + elif self.param_derived_from_user_current_team == "team_id": user = cast(User, self.request.user) + assert user.team is not None team = user.team - assert team is not None - return team - try: - return Team.objects.get(id=self.team_id) - except Team.DoesNotExist: - raise NotFound( - detail="Project not found." # TODO: "Environment" instead of "Project" when project environments are rolled out - ) + else: + try: + team = Team.objects.get(id=self.team_id) + except Team.DoesNotExist: + raise NotFound( + detail="Project not found." # TODO: "Environment" instead of "Project" when project environments are rolled out + ) + + tag_queries(team_id=team.pk) + return team @cached_property def project_id(self) -> int: if team_from_token := self._get_team_from_request(): - return team_from_token.project_id + project_id = team_from_token.project_id - if self.param_derived_from_user_current_team == "project_id": + elif self.param_derived_from_user_current_team == "project_id": user = cast(User, self.request.user) team = user.team assert team is not None - return team.project_id + project_id = team.project_id + else: + project_id = self.parents_query_dict["project_id"] - return self.parents_query_dict["project_id"] + tag_queries(team_id=project_id) + return project_id @cached_property def project(self) -> Project: diff --git a/posthog/api/test/test_query.py b/posthog/api/test/test_query.py index 5032be94570615..b10029cdf2bbf7 100644 --- a/posthog/api/test/test_query.py +++ b/posthog/api/test/test_query.py @@ -1022,7 +1022,7 @@ def test_async_query_invalid_json(self): f"/api/environments/{self.team.pk}/query_async/", "invalid json", content_type="application/json" ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.json(), {"error": "Invalid JSON in request body"}) + self.assertEqual(response.json()["type"], "invalid_request") def test_async_auth(self): self.client.logout()