Websocket fixes.
This commit is contained in:
@ -11,6 +11,7 @@ from app.models import User, SportEvent, SpreadBet, Wallet, Transaction, AdminSe
|
||||
from app.models import EventStatus, SpreadBetStatus, TeamSide, TransactionType, TransactionStatus
|
||||
from app.schemas.spread_bet import SpreadBet as SpreadBetSchema, SpreadBetCreate, SpreadBetDetail
|
||||
from app.routers.auth import get_current_user
|
||||
from app.routers.websocket import broadcast_to_event
|
||||
|
||||
router = APIRouter(prefix="/api/v1/spread-bets", tags=["spread-bets"])
|
||||
|
||||
@ -69,6 +70,20 @@ async def create_spread_bet(
|
||||
db.add(bet)
|
||||
await db.commit()
|
||||
await db.refresh(bet)
|
||||
|
||||
# Broadcast bet created event to all subscribers of this event
|
||||
await broadcast_to_event(
|
||||
bet.event_id,
|
||||
"bet_created",
|
||||
{
|
||||
"bet_id": bet.id,
|
||||
"spread": float(bet.spread),
|
||||
"team": bet.team.value,
|
||||
"stake_amount": float(bet.stake_amount),
|
||||
"creator_username": current_user.username,
|
||||
}
|
||||
)
|
||||
|
||||
return bet
|
||||
|
||||
|
||||
@ -162,6 +177,20 @@ async def take_spread_bet(
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(bet)
|
||||
|
||||
# Broadcast bet taken event to all subscribers of this event
|
||||
await broadcast_to_event(
|
||||
bet.event_id,
|
||||
"bet_taken",
|
||||
{
|
||||
"bet_id": bet.id,
|
||||
"spread": float(bet.spread),
|
||||
"team": bet.team.value,
|
||||
"stake_amount": float(bet.stake_amount),
|
||||
"taker_username": current_user.username,
|
||||
}
|
||||
)
|
||||
|
||||
return bet
|
||||
|
||||
|
||||
@ -284,6 +313,19 @@ async def cancel_spread_bet(
|
||||
if bet.status != SpreadBetStatus.OPEN:
|
||||
raise HTTPException(status_code=400, detail="Can only cancel open bets")
|
||||
|
||||
event_id = bet.event_id
|
||||
bet.status = SpreadBetStatus.CANCELLED
|
||||
await db.commit()
|
||||
|
||||
# Broadcast bet cancelled event to all subscribers of this event
|
||||
await broadcast_to_event(
|
||||
event_id,
|
||||
"bet_cancelled",
|
||||
{
|
||||
"bet_id": bet_id,
|
||||
"spread": float(bet.spread),
|
||||
"team": bet.team.value,
|
||||
}
|
||||
)
|
||||
|
||||
return {"message": "Bet cancelled"}
|
||||
|
||||
@ -1,43 +1,138 @@
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Dict
|
||||
from typing import Dict, Set, Optional
|
||||
from jose import JWTError
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from app.utils.security import decode_token
|
||||
|
||||
router = APIRouter(tags=["websocket"])
|
||||
|
||||
# Store active connections
|
||||
active_connections: Dict[int, WebSocket] = {}
|
||||
# Store active connections by connection_id (unique per connection)
|
||||
active_connections: Dict[str, WebSocket] = {}
|
||||
|
||||
# Store connections subscribed to specific events
|
||||
event_subscriptions: Dict[int, Set[str]] = {} # event_id -> set of connection_ids
|
||||
|
||||
# Map connection_id to websocket
|
||||
connection_websockets: Dict[str, WebSocket] = {}
|
||||
|
||||
|
||||
@router.websocket("/api/v1/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(...),
|
||||
event_id: Optional[int] = Query(None)
|
||||
):
|
||||
await websocket.accept()
|
||||
|
||||
# In a real implementation, you would validate the token here
|
||||
# For MVP, we'll accept all connections
|
||||
user_id = 1 # Placeholder
|
||||
# Generate unique connection ID
|
||||
connection_id = str(uuid.uuid4())
|
||||
|
||||
active_connections[user_id] = websocket
|
||||
# Try to decode token to get user_id (for logging purposes)
|
||||
user_id = None
|
||||
if token and token != 'guest':
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
user_id = payload.get("sub")
|
||||
except (JWTError, Exception):
|
||||
pass # Invalid token, treat as guest
|
||||
|
||||
print(f"[WebSocket] New connection: {connection_id}, user_id: {user_id}, event_id: {event_id}")
|
||||
|
||||
# Store connection
|
||||
active_connections[connection_id] = websocket
|
||||
connection_websockets[connection_id] = websocket
|
||||
|
||||
# Subscribe to event if specified
|
||||
if event_id:
|
||||
if event_id not in event_subscriptions:
|
||||
event_subscriptions[event_id] = set()
|
||||
event_subscriptions[event_id].add(connection_id)
|
||||
print(f"[WebSocket] Subscribed {connection_id} to event {event_id}. Total subscribers: {len(event_subscriptions[event_id])}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
# Handle incoming messages if needed
|
||||
# Handle incoming messages - could be used to subscribe/unsubscribe
|
||||
try:
|
||||
msg = json.loads(data)
|
||||
if msg.get('action') == 'subscribe' and msg.get('event_id'):
|
||||
eid = msg['event_id']
|
||||
if eid not in event_subscriptions:
|
||||
event_subscriptions[eid] = set()
|
||||
event_subscriptions[eid].add(connection_id)
|
||||
print(f"[WebSocket] {connection_id} subscribed to event {eid}")
|
||||
elif msg.get('action') == 'unsubscribe' and msg.get('event_id'):
|
||||
eid = msg['event_id']
|
||||
if eid in event_subscriptions:
|
||||
event_subscriptions[eid].discard(connection_id)
|
||||
print(f"[WebSocket] {connection_id} unsubscribed from event {eid}")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except WebSocketDisconnect:
|
||||
if user_id in active_connections:
|
||||
del active_connections[user_id]
|
||||
print(f"[WebSocket] Disconnected: {connection_id}")
|
||||
# Clean up connection
|
||||
if connection_id in active_connections:
|
||||
del active_connections[connection_id]
|
||||
if connection_id in connection_websockets:
|
||||
del connection_websockets[connection_id]
|
||||
|
||||
# Clean up event subscriptions
|
||||
for eid, subs in event_subscriptions.items():
|
||||
if connection_id in subs:
|
||||
subs.discard(connection_id)
|
||||
print(f"[WebSocket] Removed {connection_id} from event {eid} subscriptions")
|
||||
|
||||
|
||||
async def broadcast_to_event(event_id: int, event_type: str, data: dict):
|
||||
"""Broadcast a message to all connections subscribed to an event"""
|
||||
message = json.dumps({
|
||||
"type": event_type,
|
||||
"data": {"event_id": event_id, **data}
|
||||
})
|
||||
|
||||
print(f"[WebSocket] Broadcasting {event_type} to event {event_id}")
|
||||
|
||||
if event_id not in event_subscriptions:
|
||||
print(f"[WebSocket] No subscribers for event {event_id}")
|
||||
return
|
||||
|
||||
subscribers = event_subscriptions[event_id].copy()
|
||||
print(f"[WebSocket] Found {len(subscribers)} subscribers for event {event_id}")
|
||||
|
||||
disconnected = set()
|
||||
for conn_id in subscribers:
|
||||
ws = connection_websockets.get(conn_id)
|
||||
if ws:
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
print(f"[WebSocket] Sent message to {conn_id}")
|
||||
except Exception as e:
|
||||
print(f"[WebSocket] Failed to send to {conn_id}: {e}")
|
||||
disconnected.add(conn_id)
|
||||
else:
|
||||
print(f"[WebSocket] Connection {conn_id} not found in websockets map")
|
||||
disconnected.add(conn_id)
|
||||
|
||||
# Clean up disconnected connections
|
||||
for conn_id in disconnected:
|
||||
event_subscriptions[event_id].discard(conn_id)
|
||||
if conn_id in active_connections:
|
||||
del active_connections[conn_id]
|
||||
if conn_id in connection_websockets:
|
||||
del connection_websockets[conn_id]
|
||||
|
||||
|
||||
async def broadcast_event(event_type: str, data: dict, user_ids: list[int] = None):
|
||||
"""Broadcast an event to specific users or all connected users"""
|
||||
"""Broadcast an event to all connected users"""
|
||||
message = json.dumps({
|
||||
"type": event_type,
|
||||
"data": data
|
||||
})
|
||||
|
||||
if user_ids:
|
||||
for user_id in user_ids:
|
||||
if user_id in active_connections:
|
||||
await active_connections[user_id].send_text(message)
|
||||
else:
|
||||
for connection in active_connections.values():
|
||||
await connection.send_text(message)
|
||||
for conn_id, ws in list(connection_websockets.items()):
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user