diff --git a/cubbi/session.py b/cubbi/session.py index b459102..ced4cb4 100644 --- a/cubbi/session.py +++ b/cubbi/session.py @@ -2,7 +2,9 @@ Session storage management for Cubbi Container Tool. """ +import fcntl import os +from contextlib import contextmanager from pathlib import Path from typing import Dict, Optional @@ -11,6 +13,31 @@ import yaml DEFAULT_SESSIONS_FILE = Path.home() / ".config" / "cubbi" / "sessions.yaml" +@contextmanager +def _file_lock(file_path: Path): + """Context manager for file locking. + + Args: + file_path: Path to the file to lock + + Yields: + File descriptor with exclusive lock + """ + # Ensure the file exists + file_path.parent.mkdir(parents=True, exist_ok=True) + if not file_path.exists(): + file_path.touch(mode=0o600) + + # Open file and acquire exclusive lock + fd = open(file_path, "r+") + try: + fcntl.flock(fd.fileno(), fcntl.LOCK_EX) + yield fd + finally: + fcntl.flock(fd.fileno(), fcntl.LOCK_UN) + fd.close() + + class SessionManager: """Manager for container sessions.""" @@ -42,9 +69,26 @@ class SessionManager: return sessions def save(self) -> None: - """Save the sessions to file.""" - with open(self.sessions_path, "w") as f: - yaml.safe_dump(self.sessions, f) + """Save the sessions to file. + + Note: This method acquires a file lock and merges with existing data + to prevent concurrent write issues. + """ + with _file_lock(self.sessions_path) as fd: + # Reload sessions from disk to get latest state + fd.seek(0) + sessions = yaml.safe_load(fd) or {} + + # Merge current in-memory sessions with disk state + sessions.update(self.sessions) + + # Write back to file + fd.seek(0) + fd.truncate() + yaml.safe_dump(sessions, fd) + + # Update in-memory cache + self.sessions = sessions def add_session(self, session_id: str, session_data: dict) -> None: """Add a session to storage. @@ -53,8 +97,21 @@ class SessionManager: session_id: The unique session ID session_data: The session data (Session model dump as dict) """ - self.sessions[session_id] = session_data - self.save() + with _file_lock(self.sessions_path) as fd: + # Reload sessions from disk to get latest state + fd.seek(0) + sessions = yaml.safe_load(fd) or {} + + # Apply the modification + sessions[session_id] = session_data + + # Write back to file + fd.seek(0) + fd.truncate() + yaml.safe_dump(sessions, fd) + + # Update in-memory cache + self.sessions = sessions def get_session(self, session_id: str) -> Optional[dict]: """Get a session by ID. @@ -81,6 +138,19 @@ class SessionManager: Args: session_id: The session ID to remove """ - if session_id in self.sessions: - del self.sessions[session_id] - self.save() + with _file_lock(self.sessions_path) as fd: + # Reload sessions from disk to get latest state + fd.seek(0) + sessions = yaml.safe_load(fd) or {} + + # Apply the modification + if session_id in sessions: + del sessions[session_id] + + # Write back to file + fd.seek(0) + fd.truncate() + yaml.safe_dump(sessions, fd) + + # Update in-memory cache + self.sessions = sessions