fix: user migration confusion with user_id (#401)

+ added tests
This commit is contained in:
2024-09-03 22:07:36 +02:00
committed by GitHub
parent a358bfadf1
commit 873cbb0a42
2 changed files with 110 additions and 7 deletions

View File

@@ -31,20 +31,26 @@ async def migrate_user(email, user_id):
if not user_ids:
return
for user_id in user_ids:
# do not migrate back
if user_id in user_ids:
return
for old_user_id in user_ids:
query = (
transcripts.update()
.where(transcripts.c.user_id == user_id)
.where(transcripts.c.user_id == old_user_id)
.values(user_id=user_id)
)
await database.execute(query)
query = rooms.update().where(rooms.c.user_id == user_id).values(user_id=user_id)
query = (
rooms.update().where(rooms.c.user_id == old_user_id).values(user_id=user_id)
)
await database.execute(query)
query = (
meetings.update()
.where(meetings.c.user_id == user_id)
.where(meetings.c.user_id == old_user_id)
.values(user_id=user_id)
)
await database.execute(query)

View File

@@ -1,4 +1,6 @@
import pytest
from unittest.mock import patch
from contextlib import asynccontextmanager
from httpx import AsyncClient
@@ -144,9 +146,8 @@ async def test_transcripts_list_anonymous():
settings.PUBLIC_MODE = False
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client():
@asynccontextmanager
async def authenticated_client_ctx():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
@@ -163,6 +164,38 @@ async def authenticated_client():
del app.dependency_overrides[current_user_optional]
@asynccontextmanager
async def authenticated_client2_ctx():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {
"sub": "randomuserid2",
"email": "test@mail.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "randomuserid2",
"email": "test@mail.com",
}
yield
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client():
async with authenticated_client_ctx():
yield
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client2():
async with authenticated_client2_ctx():
yield
@pytest.mark.asyncio
async def test_transcripts_list_authenticated(authenticated_client):
# XXX this test is a bit fragile, as it depends on the storage which
@@ -228,3 +261,67 @@ async def test_transcript_mark_reviewed():
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["reviewed"] is True
@asynccontextmanager
async def patch_migrate_user():
with patch(
"reflector.db.migrate_user.users_to_migrate",
[["test@mail.com", "randomuserid", None]],
):
yield
@pytest.mark.asyncio
async def test_transcripts_list_authenticated_migration():
# XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests
from reflector.app import app
testx1 = "testmigration1"
testx2 = "testmigration2"
async with patch_migrate_user(), AsyncClient(
app=app, base_url="http://test/v1"
) as ac:
# first ensure client 2 does not have any transcripts related to this test
async with authenticated_client2_ctx():
response = await ac.get("/transcripts")
assert response.status_code == 200
# assert len(response.json()["items"]) == 0
names = [t["name"] for t in response.json()["items"]]
assert testx1 not in names
assert testx2 not in names
# create 2 transcripts with client 1
async with authenticated_client_ctx():
response = await ac.post("/transcripts", json={"name": testx1})
assert response.status_code == 200
assert response.json()["name"] == testx1
response = await ac.post("/transcripts", json={"name": testx2})
assert response.status_code == 200
assert response.json()["name"] == testx2
response = await ac.get("/transcripts")
assert response.status_code == 200
assert len(response.json()["items"]) >= 2
names = [t["name"] for t in response.json()["items"]]
assert testx1 in names
assert testx2 in names
# now going back to client 2, migration should happen
async with authenticated_client2_ctx():
response = await ac.get("/transcripts")
assert response.status_code == 200
names = [t["name"] for t in response.json()["items"]]
assert testx1 in names
assert testx2 in names
# and client 1 should have nothing now
async with authenticated_client_ctx():
response = await ac.get("/transcripts")
assert response.status_code == 200
names = [t["name"] for t in response.json()["items"]]
assert testx1 not in names
assert testx2 not in names