From 873cbb0a42eb3c9de4d029014832dfb0040a138e Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Tue, 3 Sep 2024 22:07:36 +0200 Subject: [PATCH] fix: user migration confusion with user_id (#401) + added tests --- server/reflector/db/migrate_user.py | 14 ++-- server/tests/test_transcripts.py | 103 +++++++++++++++++++++++++++- 2 files changed, 110 insertions(+), 7 deletions(-) diff --git a/server/reflector/db/migrate_user.py b/server/reflector/db/migrate_user.py index 507e8f96..f058bf21 100644 --- a/server/reflector/db/migrate_user.py +++ b/server/reflector/db/migrate_user.py @@ -31,20 +31,26 @@ async def migrate_user(email, user_id): if not user_ids: return - for user_id in user_ids: + # do not migrate back + if user_id in user_ids: + return + + for old_user_id in user_ids: query = ( transcripts.update() - .where(transcripts.c.user_id == user_id) + .where(transcripts.c.user_id == old_user_id) .values(user_id=user_id) ) await database.execute(query) - query = rooms.update().where(rooms.c.user_id == user_id).values(user_id=user_id) + query = ( + rooms.update().where(rooms.c.user_id == old_user_id).values(user_id=user_id) + ) await database.execute(query) query = ( meetings.update() - .where(meetings.c.user_id == user_id) + .where(meetings.c.user_id == old_user_id) .values(user_id=user_id) ) await database.execute(query) diff --git a/server/tests/test_transcripts.py b/server/tests/test_transcripts.py index 5fac353a..be0779ff 100644 --- a/server/tests/test_transcripts.py +++ b/server/tests/test_transcripts.py @@ -1,4 +1,6 @@ import pytest +from unittest.mock import patch +from contextlib import asynccontextmanager from httpx import AsyncClient @@ -144,9 +146,8 @@ async def test_transcripts_list_anonymous(): settings.PUBLIC_MODE = False -@pytest.fixture -@pytest.mark.asyncio -async def authenticated_client(): +@asynccontextmanager +async def authenticated_client_ctx(): from reflector.app import app from reflector.auth import current_user, current_user_optional @@ -163,6 +164,38 @@ async def authenticated_client(): del app.dependency_overrides[current_user_optional] +@asynccontextmanager +async def authenticated_client2_ctx(): + from reflector.app import app + from reflector.auth import current_user, current_user_optional + + app.dependency_overrides[current_user] = lambda: { + "sub": "randomuserid2", + "email": "test@mail.com", + } + app.dependency_overrides[current_user_optional] = lambda: { + "sub": "randomuserid2", + "email": "test@mail.com", + } + yield + del app.dependency_overrides[current_user] + del app.dependency_overrides[current_user_optional] + + +@pytest.fixture +@pytest.mark.asyncio +async def authenticated_client(): + async with authenticated_client_ctx(): + yield + + +@pytest.fixture +@pytest.mark.asyncio +async def authenticated_client2(): + async with authenticated_client2_ctx(): + yield + + @pytest.mark.asyncio async def test_transcripts_list_authenticated(authenticated_client): # XXX this test is a bit fragile, as it depends on the storage which @@ -228,3 +261,67 @@ async def test_transcript_mark_reviewed(): response = await ac.get(f"/transcripts/{tid}") assert response.status_code == 200 assert response.json()["reviewed"] is True + + +@asynccontextmanager +async def patch_migrate_user(): + with patch( + "reflector.db.migrate_user.users_to_migrate", + [["test@mail.com", "randomuserid", None]], + ): + yield + + +@pytest.mark.asyncio +async def test_transcripts_list_authenticated_migration(): + # XXX this test is a bit fragile, as it depends on the storage which + # is shared between tests + from reflector.app import app + + testx1 = "testmigration1" + testx2 = "testmigration2" + + async with patch_migrate_user(), AsyncClient( + app=app, base_url="http://test/v1" + ) as ac: + # first ensure client 2 does not have any transcripts related to this test + async with authenticated_client2_ctx(): + response = await ac.get("/transcripts") + assert response.status_code == 200 + # assert len(response.json()["items"]) == 0 + names = [t["name"] for t in response.json()["items"]] + assert testx1 not in names + assert testx2 not in names + + # create 2 transcripts with client 1 + async with authenticated_client_ctx(): + response = await ac.post("/transcripts", json={"name": testx1}) + assert response.status_code == 200 + assert response.json()["name"] == testx1 + + response = await ac.post("/transcripts", json={"name": testx2}) + assert response.status_code == 200 + assert response.json()["name"] == testx2 + + response = await ac.get("/transcripts") + assert response.status_code == 200 + assert len(response.json()["items"]) >= 2 + names = [t["name"] for t in response.json()["items"]] + assert testx1 in names + assert testx2 in names + + # now going back to client 2, migration should happen + async with authenticated_client2_ctx(): + response = await ac.get("/transcripts") + assert response.status_code == 200 + names = [t["name"] for t in response.json()["items"]] + assert testx1 in names + assert testx2 in names + + # and client 1 should have nothing now + async with authenticated_client_ctx(): + response = await ac.get("/transcripts") + assert response.status_code == 200 + names = [t["name"] for t in response.json()["items"]] + assert testx1 not in names + assert testx2 not in names