Websocket fixes.

This commit is contained in:
2026-01-11 00:46:49 -06:00
parent d4855040d8
commit 174abb7f56
32 changed files with 2770 additions and 32 deletions

View File

@ -11,6 +11,7 @@ from app.models import User, SportEvent, SpreadBet, Wallet, Transaction, AdminSe
from app.models import EventStatus, SpreadBetStatus, TeamSide, TransactionType, TransactionStatus
from app.schemas.spread_bet import SpreadBet as SpreadBetSchema, SpreadBetCreate, SpreadBetDetail
from app.routers.auth import get_current_user
from app.routers.websocket import broadcast_to_event
router = APIRouter(prefix="/api/v1/spread-bets", tags=["spread-bets"])
@ -69,6 +70,20 @@ async def create_spread_bet(
db.add(bet)
await db.commit()
await db.refresh(bet)
# Broadcast bet created event to all subscribers of this event
await broadcast_to_event(
bet.event_id,
"bet_created",
{
"bet_id": bet.id,
"spread": float(bet.spread),
"team": bet.team.value,
"stake_amount": float(bet.stake_amount),
"creator_username": current_user.username,
}
)
return bet
@ -162,6 +177,20 @@ async def take_spread_bet(
await db.commit()
await db.refresh(bet)
# Broadcast bet taken event to all subscribers of this event
await broadcast_to_event(
bet.event_id,
"bet_taken",
{
"bet_id": bet.id,
"spread": float(bet.spread),
"team": bet.team.value,
"stake_amount": float(bet.stake_amount),
"taker_username": current_user.username,
}
)
return bet
@ -284,6 +313,19 @@ async def cancel_spread_bet(
if bet.status != SpreadBetStatus.OPEN:
raise HTTPException(status_code=400, detail="Can only cancel open bets")
event_id = bet.event_id
bet.status = SpreadBetStatus.CANCELLED
await db.commit()
# Broadcast bet cancelled event to all subscribers of this event
await broadcast_to_event(
event_id,
"bet_cancelled",
{
"bet_id": bet_id,
"spread": float(bet.spread),
"team": bet.team.value,
}
)
return {"message": "Bet cancelled"}

View File

@ -1,43 +1,138 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Dict
from typing import Dict, Set, Optional
from jose import JWTError
import json
import uuid
from app.utils.security import decode_token
router = APIRouter(tags=["websocket"])
# Store active connections
active_connections: Dict[int, WebSocket] = {}
# Store active connections by connection_id (unique per connection)
active_connections: Dict[str, WebSocket] = {}
# Store connections subscribed to specific events
event_subscriptions: Dict[int, Set[str]] = {} # event_id -> set of connection_ids
# Map connection_id to websocket
connection_websockets: Dict[str, WebSocket] = {}
@router.websocket("/api/v1/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
async def websocket_endpoint(
websocket: WebSocket,
token: str = Query(...),
event_id: Optional[int] = Query(None)
):
await websocket.accept()
# In a real implementation, you would validate the token here
# For MVP, we'll accept all connections
user_id = 1 # Placeholder
# Generate unique connection ID
connection_id = str(uuid.uuid4())
active_connections[user_id] = websocket
# Try to decode token to get user_id (for logging purposes)
user_id = None
if token and token != 'guest':
try:
payload = decode_token(token)
user_id = payload.get("sub")
except (JWTError, Exception):
pass # Invalid token, treat as guest
print(f"[WebSocket] New connection: {connection_id}, user_id: {user_id}, event_id: {event_id}")
# Store connection
active_connections[connection_id] = websocket
connection_websockets[connection_id] = websocket
# Subscribe to event if specified
if event_id:
if event_id not in event_subscriptions:
event_subscriptions[event_id] = set()
event_subscriptions[event_id].add(connection_id)
print(f"[WebSocket] Subscribed {connection_id} to event {event_id}. Total subscribers: {len(event_subscriptions[event_id])}")
try:
while True:
data = await websocket.receive_text()
# Handle incoming messages if needed
# Handle incoming messages - could be used to subscribe/unsubscribe
try:
msg = json.loads(data)
if msg.get('action') == 'subscribe' and msg.get('event_id'):
eid = msg['event_id']
if eid not in event_subscriptions:
event_subscriptions[eid] = set()
event_subscriptions[eid].add(connection_id)
print(f"[WebSocket] {connection_id} subscribed to event {eid}")
elif msg.get('action') == 'unsubscribe' and msg.get('event_id'):
eid = msg['event_id']
if eid in event_subscriptions:
event_subscriptions[eid].discard(connection_id)
print(f"[WebSocket] {connection_id} unsubscribed from event {eid}")
except json.JSONDecodeError:
pass
except WebSocketDisconnect:
if user_id in active_connections:
del active_connections[user_id]
print(f"[WebSocket] Disconnected: {connection_id}")
# Clean up connection
if connection_id in active_connections:
del active_connections[connection_id]
if connection_id in connection_websockets:
del connection_websockets[connection_id]
# Clean up event subscriptions
for eid, subs in event_subscriptions.items():
if connection_id in subs:
subs.discard(connection_id)
print(f"[WebSocket] Removed {connection_id} from event {eid} subscriptions")
async def broadcast_to_event(event_id: int, event_type: str, data: dict):
"""Broadcast a message to all connections subscribed to an event"""
message = json.dumps({
"type": event_type,
"data": {"event_id": event_id, **data}
})
print(f"[WebSocket] Broadcasting {event_type} to event {event_id}")
if event_id not in event_subscriptions:
print(f"[WebSocket] No subscribers for event {event_id}")
return
subscribers = event_subscriptions[event_id].copy()
print(f"[WebSocket] Found {len(subscribers)} subscribers for event {event_id}")
disconnected = set()
for conn_id in subscribers:
ws = connection_websockets.get(conn_id)
if ws:
try:
await ws.send_text(message)
print(f"[WebSocket] Sent message to {conn_id}")
except Exception as e:
print(f"[WebSocket] Failed to send to {conn_id}: {e}")
disconnected.add(conn_id)
else:
print(f"[WebSocket] Connection {conn_id} not found in websockets map")
disconnected.add(conn_id)
# Clean up disconnected connections
for conn_id in disconnected:
event_subscriptions[event_id].discard(conn_id)
if conn_id in active_connections:
del active_connections[conn_id]
if conn_id in connection_websockets:
del connection_websockets[conn_id]
async def broadcast_event(event_type: str, data: dict, user_ids: list[int] = None):
"""Broadcast an event to specific users or all connected users"""
"""Broadcast an event to all connected users"""
message = json.dumps({
"type": event_type,
"data": data
})
if user_ids:
for user_id in user_ids:
if user_id in active_connections:
await active_connections[user_id].send_text(message)
else:
for connection in active_connections.values():
await connection.send_text(message)
for conn_id, ws in list(connection_websockets.items()):
try:
await ws.send_text(message)
except Exception:
pass