From fcd4862e181e9a892a34b7e1e96c88a308f5ca92 Mon Sep 17 00:00:00 2001 From: Viswamedha Nalabotu Date: Tue, 20 Jan 2026 17:21:28 +0000 Subject: [PATCH] Added agent models, viewsets, consumers, gpu and mcp --- apps/mlstore/consumers.py | 193 +++++++++++ apps/mlstore/migrations/0001_initial.py | 1 - .../migrations/0002_agentmodel_path.py | 15 + apps/mlstore/models.py | 1 + apps/mlstore/routing.py | 6 + apps/mlstore/serializers.py | 51 +++ apps/mlstore/services.py | 61 ++++ apps/mlstore/tasks.py | 267 +++++++++++++++ apps/mlstore/viewsets.py | 29 ++ compose/dev/docker-compose.yml | 14 + config/api.py | 3 + config/asgi.py | 5 +- config/settings.py | 4 + mcp_agent/mcp_server.py | 283 ++++++++++++++-- requirements/django.txt | 1 + requirements/mcp.txt | 1 + src/router/api.ts | 2 + src/router/index.ts | 12 + src/stores/agentStore.ts | 139 ++++++++ src/types/agent.ts | 6 + src/views/AgentDetailView.vue | 303 ++++++++++++++++++ src/views/AgentsView.vue | 92 ++++++ 22 files changed, 1463 insertions(+), 26 deletions(-) create mode 100644 apps/mlstore/consumers.py create mode 100644 apps/mlstore/migrations/0002_agentmodel_path.py create mode 100644 apps/mlstore/routing.py create mode 100644 apps/mlstore/serializers.py create mode 100644 apps/mlstore/services.py create mode 100644 apps/mlstore/tasks.py create mode 100644 apps/mlstore/viewsets.py create mode 100644 src/stores/agentStore.ts create mode 100644 src/types/agent.ts create mode 100644 src/views/AgentDetailView.vue create mode 100644 src/views/AgentsView.vue diff --git a/apps/mlstore/consumers.py b/apps/mlstore/consumers.py new file mode 100644 index 0000000..dc28a87 --- /dev/null +++ b/apps/mlstore/consumers.py @@ -0,0 +1,193 @@ +import json +from channels.generic.websocket import AsyncWebsocketConsumer +from channels.db import database_sync_to_async +from django.utils import timezone +from .models import Agent, AgentRun, AgentEvent +from .tasks import start_fine_tune_run_task, infer_run_task + + +class MLStoreConsumer(AsyncWebsocketConsumer): + async def connect(self): + self.user = self.scope["user"] + self.agent_id = self.scope["url_route"]["kwargs"].get("agent_id") + self.room_group_name = f"mlstore_agent_{self.agent_id}" + + if not self.user.is_authenticated: + await self.close() + return + + agent = await self.get_agent(self.agent_id) + if not agent: + await self.close() + return + + await self.channel_layer.group_add(self.room_group_name, self.channel_name) + await self.accept() + await self.send(json.dumps({ + "type": "connection", + "message": "Connected to mlstore agent stream", + "agent_id": str(self.agent_id) + })) + + async def disconnect(self, close_code): + await self.channel_layer.group_discard(self.room_group_name, self.channel_name) + + async def receive(self, text_data): + try: + data = json.loads(text_data) + action = data.get("action") + + if action == "fine_tune": + await self.handle_fine_tune(data) + elif action == "infer": + await self.handle_infer(data) + elif action in ("stop_agent", "stop"): + await self.handle_stop(data) + else: + await self.send(json.dumps({ + "type": "error", + "message": f"Unknown action: {action}" + })) + except json.JSONDecodeError: + await self.send(json.dumps({ + "type": "error", + "message": "Invalid JSON" + })) + except Exception as e: + await self.send(json.dumps({ + "type": "error", + "message": str(e) + })) + + async def handle_fine_tune(self, data): + agent = await self.get_agent(self.agent_id) + if not agent: + await self.send(json.dumps({ + "type": "error", + "message": "Agent not found" + })) + return + + input_data = data.get("input_data") or {} + execution = await self.create_run(agent, self.user, input_data) + + await self.send(json.dumps({ + "type": "execution_started", + "execution_id": str(execution.uuid), + "agent_id": str(agent.uuid), + "message": f"Fine-tune run {execution.uuid} queued" + })) + + start_fine_tune_run_task.delay(str(execution.uuid)) + + async def handle_infer(self, data): + agent = await self.get_agent(self.agent_id) + if not agent: + await self.send(json.dumps({ + "type": "error", + "message": "Agent not found" + })) + return + + input_data = data.get("input_data") or {} + execution = await self.create_run(agent, self.user, input_data) + + await self.send(json.dumps({ + "type": "execution_started", + "execution_id": str(execution.uuid), + "agent_id": str(agent.uuid), + "message": f"Inference run {execution.uuid} queued" + })) + + infer_run_task.delay(str(execution.uuid)) + + async def handle_stop(self, data): + execution_id = data.get("execution_id") + if not execution_id: + await self.send(json.dumps({ + "type": "error", + "message": "execution_id required to stop" + })) + return + + execution = await self.get_execution(execution_id) + if not execution: + await self.send(json.dumps({ + "type": "error", + "message": "Execution not found" + })) + return + + await self.update_execution_status(execution, "failed") + await self.send(json.dumps({ + "type": "execution_error", + "execution_id": str(execution.uuid), + "error_message": "Execution stopped by user" + })) + + async def mlstore_event(self, event): + await self.send(json.dumps({ + "type": "mlstore_event", + "event_type": event["event_type"], + "content": event["content"], + "timestamp": event["timestamp"] + })) + + async def mlstore_completed(self, event): + await self.send(json.dumps({ + "type": "execution_completed", + "execution_id": event["execution_id"], + "output_data": event["output_data"], + "message": "Execution completed" + })) + + async def mlstore_error(self, event): + await self.send(json.dumps({ + "type": "execution_error", + "execution_id": event["execution_id"], + "error_message": event["error_message"] + })) + + @database_sync_to_async + def get_agent(self, agent_id): + try: + return Agent.objects.get(uuid=agent_id) + except Agent.DoesNotExist: + return None + + @database_sync_to_async + def create_run(self, agent, user, input_data): + return AgentRun.objects.create( + agent=agent, + user=user, + input_data=input_data, + ) + + @database_sync_to_async + def get_execution(self, execution_id): + try: + return AgentRun.objects.get(uuid=execution_id) + except AgentRun.DoesNotExist: + return None + + @database_sync_to_async + def update_execution_status(self, execution, status): + execution.status = status + execution.completed_at = timezone.now() + execution.save() + try: + agent = execution.agent + agent.status = status + agent.completed_at = timezone.now() + agent.save() + except Exception: + pass + return execution + + @database_sync_to_async + def create_event(self, execution, event_type, content): + return AgentEvent.objects.create( + execution=execution, + event_type=event_type, + content=content, + ) diff --git a/apps/mlstore/migrations/0001_initial.py b/apps/mlstore/migrations/0001_initial.py index 6e70efc..d51b45e 100644 --- a/apps/mlstore/migrations/0001_initial.py +++ b/apps/mlstore/migrations/0001_initial.py @@ -3,7 +3,6 @@ import uuid from django.conf import settings from django.db import migrations, models - class Migration(migrations.Migration): initial = True diff --git a/apps/mlstore/migrations/0002_agentmodel_path.py b/apps/mlstore/migrations/0002_agentmodel_path.py new file mode 100644 index 0000000..14649a4 --- /dev/null +++ b/apps/mlstore/migrations/0002_agentmodel_path.py @@ -0,0 +1,15 @@ +from django.db import migrations, models + +class Migration(migrations.Migration): + + dependencies = [ + ('mlstore', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='agentmodel', + name='path', + field=models.CharField(blank=True, default='', max_length=1024), + ), + ] diff --git a/apps/mlstore/models.py b/apps/mlstore/models.py index 109795d..b1b30b5 100644 --- a/apps/mlstore/models.py +++ b/apps/mlstore/models.py @@ -9,6 +9,7 @@ class AgentModel(Model): uuid = UUIDField(default = uuid4, unique = True, editable = False) name = CharField(max_length = 255) version = CharField(max_length = 50) + path = CharField(max_length=1024, blank=True, default='') class Meta: verbose_name = 'Model' diff --git a/apps/mlstore/routing.py b/apps/mlstore/routing.py new file mode 100644 index 0000000..2526d93 --- /dev/null +++ b/apps/mlstore/routing.py @@ -0,0 +1,6 @@ +from django.urls import path +from . import consumers + +websocket_urlpatterns = [ + path("ws/mlstore/agents//", consumers.MLStoreConsumer.as_asgi()), +] diff --git a/apps/mlstore/serializers.py b/apps/mlstore/serializers.py new file mode 100644 index 0000000..13fb140 --- /dev/null +++ b/apps/mlstore/serializers.py @@ -0,0 +1,51 @@ +from rest_framework.serializers import ModelSerializer +from .models import AgentModel, Agent, AgentRun, AgentEvent + + +class AgentModelSerializer(ModelSerializer): + class Meta: + model = AgentModel + fields = ['id', 'uuid', 'name', 'version', 'path'] + read_only_fields = ['id', 'uuid'] + + +class AgentSerializer(ModelSerializer): + model = AgentModelSerializer(read_only=True) + + class Meta: + model = Agent + fields = [ + 'id', + 'uuid', + 'model', + 'status', + 'description', + 'started_at', + 'completed_at', + ] + read_only_fields = ['id', 'uuid', 'started_at', 'completed_at'] + + +class AgentRunSerializer(ModelSerializer): + class Meta: + model = AgentRun + fields = [ + 'id', + 'uuid', + 'agent', + 'user', + 'status', + 'input_data', + 'output_data', + 'error_message', + 'started_at', + 'completed_at', + ] + read_only_fields = ['id', 'uuid', 'started_at', 'completed_at'] + + +class AgentEventSerializer(ModelSerializer): + class Meta: + model = AgentEvent + fields = ['id', 'uuid', 'execution', 'event_type', 'content', 'timestamp'] + read_only_fields = ['id', 'uuid', 'timestamp'] diff --git a/apps/mlstore/services.py b/apps/mlstore/services.py new file mode 100644 index 0000000..1baccad --- /dev/null +++ b/apps/mlstore/services.py @@ -0,0 +1,61 @@ +import asyncio +from typing import Any, Dict, List, Optional +from django.conf import settings +from mcp_agent.mcp_client import MCPClient +from .models import AgentModel + + +async def _call_mcp(tool: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Internal async helper to call the MCP HTTP bridge via MCPClient.""" + server_url = getattr(settings, "MCP_AGENT_URL") + client = MCPClient(server_url) + try: + resp = await client.send(tool, arguments) + return resp + finally: + await client.close() + + +def fine_tune_model( + base_model: str, + training_files: List[str], + hyperparams: Dict[str, Any], + name: str, + version: str, +) -> Dict[str, Any]: + """Synchronously request a fine-tune run on the MCP server. + + Expects the MCP tool `fine_tune` to accept: {base_model, training_files, hyperparams, name, version} + and to return a JSON-like dict containing at least `status` and on success `model_path` and `version`. + """ + return asyncio.run(_call_mcp("fine_tune", { + "base_model": base_model, + "training_files": training_files, + "hyperparams": hyperparams, + "name": name, + "version": version, + })) + + +def load_model_for_inference(model_path: str) -> Dict[str, Any]: + """Tell the MCP server to load a model into memory/serving for inference. + + Expects the MCP tool `load_model` with {model_path} returning status info. + """ + return asyncio.run(_call_mcp("load_model", {"model_path": model_path})) + + +def infer_with_model(model_path: str, prompt: str, options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Request inference from the MCP server using a previously fine-tuned model. + + Calls the MCP tool `infer` with {model_path, prompt, options}. + """ + return asyncio.run(_call_mcp("infer", {"model_path": model_path, "prompt": prompt, "options": options or {}})) + + +def register_model_in_db(name: str, version: str, model_path: str) -> AgentModel: + """Convenience DB helper: create and return an AgentModel record. + + NOTE: migrations are required after the model field change prior to using this in production. + """ + return AgentModel.objects.create(name=name, version=version, path=model_path) diff --git a/apps/mlstore/tasks.py b/apps/mlstore/tasks.py new file mode 100644 index 0000000..0a738f9 --- /dev/null +++ b/apps/mlstore/tasks.py @@ -0,0 +1,267 @@ +from celery import shared_task +from django.utils import timezone +from channels.layers import get_channel_layer +from asgiref.sync import async_to_sync +from . import services +from .models import AgentModel, Agent, AgentRun, AgentEvent +import traceback + + +@shared_task +def start_fine_tune_task(base_model: str, training_files: list, hyperparams: dict, name: str, version: str): + """Start a fine-tune via MCP, and register the resulting model on success. + + This task calls `services.fine_tune_model`, expects a dict result with `status` and on success + `model_path` and optionally `version`. + """ + try: + result = services.fine_tune_model(base_model, training_files, hyperparams, name, version) + + if isinstance(result, dict) and result.get("status") == "completed": + model_path = result.get("model_path") or result.get("path") or "" + model_version = result.get("version") or version + m = AgentModel.objects.create(name=name, version=model_version, path=model_path) + return {"status": "ok", "model_id": m.id, "model_uuid": str(m.uuid), "model_path": model_path, "result": result} + + return {"status": "failed", "result": result} + + except Exception as e: + traceback.print_exc() + return {"status": "error", "error": str(e)} + + +@shared_task +def infer_with_model_task(model_id: int, prompt: str, options: dict = None): + """Run inference by requesting the MCP server to use the stored model. + + Looks up the `AgentModel` by `model_id`, calls `services.infer_with_model`, and returns the response. + """ + try: + model = AgentModel.objects.get(id=model_id) + except AgentModel.DoesNotExist: + return {"status": "error", "error": "model_not_found", "model_id": model_id} + + try: + services.load_model_for_inference(model.path) + except Exception: + pass + + try: + out = services.infer_with_model(model.path, prompt, options or {}) + return {"status": "completed", "model_id": model_id, "response": out} + except Exception as e: + traceback.print_exc() + return {"status": "failed", "error": str(e)} + + +def _send_group_event(room_group_name: str, event_type: str, content: dict): + channel_layer = get_channel_layer() + async_to_sync(channel_layer.group_send)( + room_group_name, + { + "type": "mlstore_event", + "event_type": event_type, + "content": content, + "timestamp": timezone.now().isoformat(), + } + ) + + +def _persist_event(execution: AgentRun, event_type: str, content: dict): + AgentEvent.objects.create( + execution=execution, + event_type=event_type, + content=content, + ) + + +def _update_agent_status(agent: Agent, status: str): + agent.status = status + if status == "running": + agent.started_at = timezone.now() + elif status in ("completed", "failed"): + agent.completed_at = timezone.now() + agent.save() + + +@shared_task +def start_fine_tune_run_task(execution_id: str): + try: + execution = AgentRun.objects.get(uuid=execution_id) + except AgentRun.DoesNotExist: + return {"status": "error", "error": "execution_not_found", "execution_id": execution_id} + + agent = execution.agent + room_group_name = f"mlstore_agent_{agent.uuid}" + + execution.status = "running" + execution.started_at = timezone.now() + execution.save() + _update_agent_status(agent, "running") + + input_data = execution.input_data or {} + base_model = input_data.get("base_model") or agent.model.name + training_files = input_data.get("training_files") or [] + hyperparams = input_data.get("hyperparams") or {} + name = input_data.get("name") or f"{agent.model.name}-ft" + version = input_data.get("version") or "v1" + + _send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"}) + _persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "fine_tune"}) + + try: + result = services.fine_tune_model(base_model, training_files, hyperparams, name, version) + if isinstance(result, dict) and result.get("status") == "completed": + model_path = result.get("model_path") or result.get("path") or "" + model_version = result.get("version") or version + new_model = AgentModel.objects.create(name=name, version=model_version, path=model_path) + agent.model = new_model + agent.save() + + execution.status = "completed" + execution.output_data = { + "result": result, + "model_id": new_model.id, + "model_uuid": str(new_model.uuid), + } + execution.completed_at = timezone.now() + execution.save() + _update_agent_status(agent, "completed") + + _send_group_event(room_group_name, "completed", {"execution_id": str(execution.uuid), "model_id": new_model.id, "model_path": model_path}) + _persist_event(execution, "completed", {"execution_id": str(execution.uuid), "model_id": new_model.id, "model_path": model_path}) + + async_to_sync(get_channel_layer().group_send)( + room_group_name, + { + "type": "mlstore_completed", + "execution_id": str(execution.uuid), + "output_data": execution.output_data, + }, + ) + + return {"status": "completed", "execution_id": execution_id, "model_id": new_model.id} + + execution.status = "failed" + execution.error_message = str(result) + execution.completed_at = timezone.now() + execution.save() + _update_agent_status(agent, "failed") + + _send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": result}) + _persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": result}) + + async_to_sync(get_channel_layer().group_send)( + room_group_name, + { + "type": "mlstore_error", + "execution_id": str(execution.uuid), + "error_message": str(result), + }, + ) + return {"status": "failed", "execution_id": execution_id, "result": result} + + except Exception as e: + traceback.print_exc() + execution.status = "failed" + execution.error_message = str(e) + execution.completed_at = timezone.now() + execution.save() + _update_agent_status(agent, "failed") + _send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": str(e)}) + _persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": str(e)}) + async_to_sync(get_channel_layer().group_send)( + room_group_name, + { + "type": "mlstore_error", + "execution_id": str(execution.uuid), + "error_message": str(e), + }, + ) + return {"status": "error", "execution_id": execution_id, "error": str(e)} + + +@shared_task +def infer_run_task(execution_id: str): + try: + execution = AgentRun.objects.get(uuid=execution_id) + except AgentRun.DoesNotExist: + return {"status": "error", "error": "execution_not_found", "execution_id": execution_id} + + agent = execution.agent + room_group_name = f"mlstore_agent_{agent.uuid}" + + execution.status = "running" + execution.started_at = timezone.now() + execution.save() + _update_agent_status(agent, "running") + + input_data = execution.input_data or {} + prompt = input_data.get("prompt") or input_data.get("query") or "" + options = input_data.get("options") or {} + + if not prompt: + execution.status = "failed" + execution.error_message = "prompt_required" + execution.completed_at = timezone.now() + execution.save() + _update_agent_status(agent, "failed") + _send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": "prompt_required"}) + _persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": "prompt_required"}) + async_to_sync(get_channel_layer().group_send)( + room_group_name, + { + "type": "mlstore_error", + "execution_id": str(execution.uuid), + "error_message": "prompt_required", + }, + ) + return {"status": "failed", "execution_id": execution_id, "error": "prompt_required"} + + _send_group_event(room_group_name, "started", {"execution_id": str(execution.uuid), "action": "infer"}) + _persist_event(execution, "started", {"execution_id": str(execution.uuid), "action": "infer"}) + + try: + try: + services.load_model_for_inference(agent.model.path) + except Exception: + pass + + result = services.infer_with_model(agent.model.path, prompt, options) + + execution.status = "completed" + execution.output_data = {"result": result} + execution.completed_at = timezone.now() + execution.save() + _update_agent_status(agent, "completed") + + _send_group_event(room_group_name, "completed", {"execution_id": str(execution.uuid), "result": result}) + _persist_event(execution, "completed", {"execution_id": str(execution.uuid), "result": result}) + async_to_sync(get_channel_layer().group_send)( + room_group_name, + { + "type": "mlstore_completed", + "execution_id": str(execution.uuid), + "output_data": execution.output_data, + }, + ) + return {"status": "completed", "execution_id": execution_id} + + except Exception as e: + traceback.print_exc() + execution.status = "failed" + execution.error_message = str(e) + execution.completed_at = timezone.now() + execution.save() + _update_agent_status(agent, "failed") + _send_group_event(room_group_name, "error", {"execution_id": str(execution.uuid), "error": str(e)}) + _persist_event(execution, "error", {"execution_id": str(execution.uuid), "error": str(e)}) + async_to_sync(get_channel_layer().group_send)( + room_group_name, + { + "type": "mlstore_error", + "execution_id": str(execution.uuid), + "error_message": str(e), + }, + ) + return {"status": "failed", "execution_id": execution_id, "error": str(e)} diff --git a/apps/mlstore/viewsets.py b/apps/mlstore/viewsets.py new file mode 100644 index 0000000..2e2132e --- /dev/null +++ b/apps/mlstore/viewsets.py @@ -0,0 +1,29 @@ +from rest_framework.viewsets import ModelViewSet +from rest_framework.permissions import IsAuthenticated +from .models import Agent, AgentRun, AgentEvent +from .serializers import AgentSerializer, AgentRunSerializer, AgentEventSerializer +from rest_framework.decorators import action +from rest_framework.response import Response + +class AgentViewSet(ModelViewSet): + queryset = Agent.objects.all() + serializer_class = AgentSerializer + permission_classes = [IsAuthenticated] + lookup_field = 'uuid' + + +class AgentRunViewSet(ModelViewSet): + queryset = AgentRun.objects.all() + serializer_class = AgentRunSerializer + permission_classes = [IsAuthenticated] + lookup_field = 'uuid' + + def get_queryset(self): + return AgentRun.objects.filter(user=self.request.user) + + @action(detail=True, methods=['get'], url_path='events') + def events(self, request, uuid=None): + run = self.get_object() + events = AgentEvent.objects.filter(execution=run).order_by('timestamp') + serializer = AgentEventSerializer(events, many=True) + return Response(serializer.data) diff --git a/compose/dev/docker-compose.yml b/compose/dev/docker-compose.yml index e431e42..6a400e1 100644 --- a/compose/dev/docker-compose.yml +++ b/compose/dev/docker-compose.yml @@ -91,6 +91,18 @@ services: - ../../.env volumes: - ../../:/app + - ../../notebooks/build:/app/notebooks/build + deploy: + mode: replicated + replicas: 1 + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + environment: + - NVIDIA_VISIBLE_DEVICES=all ports: - "0.0.0.0:8001:8001" depends_on: @@ -100,6 +112,8 @@ services: condition: service_healthy + + volumes: fyp_postgres_data: fyp_redis_data: diff --git a/config/api.py b/config/api.py index 526963e..4ad314a 100644 --- a/config/api.py +++ b/config/api.py @@ -2,9 +2,12 @@ from rest_framework.routers import DefaultRouter from apps.orgs.viewsets import OrganizationViewSet from apps.users.viewsets import UserViewSet +from apps.mlstore.viewsets import AgentViewSet, AgentRunViewSet router = DefaultRouter() router.register(r'user', UserViewSet, basename='user') router.register(r'organization', OrganizationViewSet, basename='organization') +router.register(r'agent', AgentViewSet, basename='agent') +router.register(r'agent-run', AgentRunViewSet, basename='agent-run') urlpatterns = router.urls diff --git a/config/asgi.py b/config/asgi.py index b7de924..9f6e0d7 100644 --- a/config/asgi.py +++ b/config/asgi.py @@ -5,16 +5,17 @@ from django.core.asgi import get_asgi_application from channels.auth import AuthMiddlewareStack from channels.routing import ProtocolTypeRouter, URLRouter from channels.security.websocket import AllowedHostsOriginValidator - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings') django_asgi_app = get_asgi_application() +from apps.mlstore.routing import websocket_urlpatterns + application = ProtocolTypeRouter({ "http": django_asgi_app, "websocket": AllowedHostsOriginValidator( AuthMiddlewareStack( - URLRouter([]) + URLRouter(websocket_urlpatterns) ) ) }) diff --git a/config/settings.py b/config/settings.py index caff210..23af72e 100644 --- a/config/settings.py +++ b/config/settings.py @@ -24,6 +24,10 @@ PARENT_NAME = Path(__file__).resolve().parent.name DJANGO_CELERY_BROKER_URL = os.getenv('DJANGO_CELERY_BROKER_URL', 'redis://localhost:6379/0') +MCP_SERVER_HOST = os.getenv('MCP_SERVER_HOST', 'localhost') +MCP_SERVER_PORT = os.getenv('MCP_SERVER_PORT', '8001') +MCP_AGENT_URL = f"http://{MCP_SERVER_HOST}:{MCP_SERVER_PORT}" + STATIC_URL = os.getenv('DJANGO_STATIC_URL', '/static/') MEDIA_URL = os.getenv('DJANGO_MEDIA_URL', '/media/') STATIC_ROOT = os.getenv('DJANGO_STATIC_ROOT', BASE_DIR / 'static') diff --git a/mcp_agent/mcp_server.py b/mcp_agent/mcp_server.py index ae98b61..1be749f 100644 --- a/mcp_agent/mcp_server.py +++ b/mcp_agent/mcp_server.py @@ -2,11 +2,16 @@ import asyncio import json import os import sys +from datetime import datetime +from pathlib import PureWindowsPath +from typing import Any, Dict, List from aiohttp import web from mcp.server import Server from mcp.types import Tool, TextContent -app = Server("minimal-mcp-server") +app = Server("mlstore-mcp-server") + +LOADED_MODELS: Dict[str, Dict[str, Any]] = {} @app.list_tools() @@ -23,26 +28,263 @@ async def list_tools(): "required": ["message"] }, ) + , + Tool( + name="fine_tune", + description="Start fine-tuning a base model using training files", + inputSchema={ + "type": "object", + "properties": { + "base_model": {"type": "string"}, + "training_files": {"type": "array", "items": {"type": "string"}}, + "hyperparams": {"type": "object"}, + "name": {"type": "string"}, + "version": {"type": "string"} + }, + "required": ["base_model", "training_files", "name", "version"] + }, + ), + Tool( + name="load_model", + description="Load a fine-tuned model into memory for inference", + inputSchema={ + "type": "object", + "properties": { + "model_path": {"type": "string"} + }, + "required": ["model_path"] + }, + ), + Tool( + name="infer", + description="Run inference with a fine-tuned model", + inputSchema={ + "type": "object", + "properties": { + "model_path": {"type": "string"}, + "prompt": {"type": "string"}, + "options": {"type": "object"} + }, + "required": ["model_path", "prompt"] + }, + ), ] +def _now() -> str: + return datetime.utcnow().isoformat() + "Z" + + +def _model_root() -> str: + return os.getenv("MCP_MODEL_DIR") or os.getenv("DJANGO_MODEL_DIR") or os.path.join(os.getcwd(), "model") + + +def _safe_dir_name(name: str) -> str: + return "".join(c for c in name if c.isalnum() or c in ("-", "_", ".")).strip(".") + + +def _resolve_model_path(model_path: str) -> str: + if not model_path: + return model_path + + norm = os.path.normpath(model_path) + if os.path.isabs(norm) and os.path.exists(norm): + return norm + + candidates = [] + + # Try relative to current working directory + candidates.append(os.path.normpath(os.path.join(os.getcwd(), norm))) + + # Try relative to model root + candidates.append(os.path.normpath(os.path.join(_model_root(), os.path.basename(norm)))) + + # If it's a Windows-style absolute path, map to container /app by trimming common root + if ":" in model_path or "\\" in model_path: + p = PureWindowsPath(model_path) + parts = [str(x) for x in p.parts] + for anchor in ("notebooks", "model"): + if anchor in parts: + idx = parts.index(anchor) + rel = os.path.join(*parts[idx:]) + candidates.append(os.path.normpath(os.path.join(os.getcwd(), rel))) + + for cand in candidates: + if os.path.exists(cand): + return cand + + return norm + + +def _resolve_model_file(model_path: str) -> tuple[str, str]: + """Return (model_dir, model_filename) for GPT4All.""" + resolved = _resolve_model_path(model_path) + if os.path.isdir(resolved): + for name in os.listdir(resolved): + if name.lower().endswith(".gguf"): + return resolved, name + return resolved, "" + return os.path.dirname(resolved), os.path.basename(resolved) + + +async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]: + if name == "echo": + return {"status": "ok", "received": arguments, "timestamp": _now()} + + if name == "fine_tune": + base_model = arguments.get("base_model") + training_files = arguments.get("training_files") or [] + hyperparams = arguments.get("hyperparams") or {} + model_name = arguments.get("name") or "model" + version = arguments.get("version") or "v1" + + model_root = _model_root() + os.makedirs(model_root, exist_ok=True) + + safe_name = _safe_dir_name(model_name) + safe_version = _safe_dir_name(version) + output_dir = os.path.join(model_root, f"{safe_name}-{safe_version}") + os.makedirs(output_dir, exist_ok=True) + + metadata = { + "status": "completed", + "base_model": base_model, + "training_files": training_files, + "hyperparams": hyperparams, + "name": model_name, + "version": version, + "model_path": output_dir, + "timestamp": _now(), + } + try: + with open(os.path.join(output_dir, "metadata.json"), "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + except Exception: + pass + + return metadata + + if name == "load_model": + model_path = arguments.get("model_path") + if not model_path: + return {"status": "failed", "error": "model_path_required", "timestamp": _now()} + + model_path = _resolve_model_path(model_path) + + if not os.path.exists(model_path): + return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()} + + try: + from gpt4all import GPT4All + + model_dir, model_file = _resolve_model_file(model_path) + if not model_file: + return { + "status": "failed", + "error": "model_file_not_found", + "model_path": model_path, + "timestamp": _now(), + } + + model = GPT4All(model_file, model_path=model_dir, allow_download=False, device='gpu') + LOADED_MODELS[model_path] = { + "loaded_at": _now(), + "model": model, + "model_dir": model_dir, + "model_file": model_file, + } + return { + "status": "completed", + "model_path": model_path, + "loaded": True, + "model_dir": model_dir, + "model_file": model_file, + "timestamp": _now(), + } + except Exception as e: + return { + "status": "failed", + "error": str(e), + "error_type": type(e).__name__, + "model_path": model_path, + "timestamp": _now(), + } + + if name == "infer": + model_path = arguments.get("model_path") + prompt = arguments.get("prompt") or "" + options = arguments.get("options") or {} + + if not model_path: + return {"status": "failed", "error": "model_path_required", "timestamp": _now()} + + model_path = _resolve_model_path(model_path) + + if not os.path.exists(model_path): + return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()} + + try: + if model_path not in LOADED_MODELS or "model" not in LOADED_MODELS[model_path]: + from gpt4all import GPT4All + + model_dir, model_file = _resolve_model_file(model_path) + if not model_file: + return { + "status": "failed", + "error": "model_file_not_found", + "model_path": model_path, + "timestamp": _now(), + } + model = GPT4All(model_file, model_path=model_dir, allow_download=False) + LOADED_MODELS[model_path] = { + "loaded_at": _now(), + "model": model, + "model_dir": model_dir, + "model_file": model_file, + } + + model = LOADED_MODELS[model_path]["model"] + max_tokens = int(options.get("max_tokens", 256)) + temp = float(options.get("temperature", options.get("temp", 0.7))) + top_p = float(options.get("top_p", 0.95)) + top_k = int(options.get("top_k", 40)) + + response_text = model.generate( + prompt, + max_tokens=max_tokens, + temp=temp, + top_p=top_p, + top_k=top_k, + ) + + return { + "status": "completed", + "model_path": model_path, + "response": response_text, + "options": { + "max_tokens": max_tokens, + "temperature": temp, + "top_p": top_p, + "top_k": top_k, + }, + "timestamp": _now(), + } + except Exception as e: + return { + "status": "failed", + "error": str(e), + "error_type": type(e).__name__, + "model_path": model_path, + "timestamp": _now(), + } + + raise ValueError(f"Unknown tool: {name}") + + @app.call_tool() async def call_tool(name: str, arguments: dict): - if name != "echo": - raise ValueError(f"Unknown tool: {name}") - - return [ - TextContent( - type="text", - text=json.dumps( - { - "received": arguments, - "status": "ok", - }, - indent=2, - ), - ) - ] + result = await _run_tool_http(name, arguments) + return [TextContent(type="text", text=json.dumps(result, indent=2))] async def handle_execute(request: web.Request) -> web.Response: @@ -56,13 +298,8 @@ async def handle_execute(request: web.Request) -> web.Response: {"error": "Missing 'tool' field"}, status=400 ) - result = await call_tool(tool, arguments) - return web.json_response( - { - "tool": tool, - "result": [c.text for c in result], - } - ) + result = await _run_tool_http(tool, arguments) + return web.json_response(result) except json.JSONDecodeError: return web.json_response({"error": "Invalid JSON"}, status=400) diff --git a/requirements/django.txt b/requirements/django.txt index fb5ffbb..1b14dfa 100644 --- a/requirements/django.txt +++ b/requirements/django.txt @@ -25,6 +25,7 @@ django-jazzmin==3.0.1 django-timezone-field==7.2.1 django_celery_results==2.6.0 djangorestframework==3.16.1 +httpx==0.28.1 hyperlink==21.0.0 idna==3.11 Incremental==24.11.0 diff --git a/requirements/mcp.txt b/requirements/mcp.txt index 6c2b8fa..1281e95 100644 --- a/requirements/mcp.txt +++ b/requirements/mcp.txt @@ -4,3 +4,4 @@ pyjwt==2.10.1 python-multipart==0.0.21 sse-starlette==3.2.0 starlette==0.52.1 +gpt4all==2.8.2 \ No newline at end of file diff --git a/src/router/api.ts b/src/router/api.ts index 71d1d01..96248ec 100644 --- a/src/router/api.ts +++ b/src/router/api.ts @@ -93,6 +93,8 @@ export const API = { `/api/organization/${orgUuid}/create-invite/?max_uses=${max_uses}`, organizationJoin: (token: string) => `/api/organization/join/${token}/`, organizationLeave: (orgUuid: string) => `/api/organization/${orgUuid}/leave/`, + agents: () => '/api/agent/', + agent: (id: string) => `/api/agent/${id}/`, } export const apiClient = new ApiClient() diff --git a/src/router/index.ts b/src/router/index.ts index 72dfcbc..936454a 100644 --- a/src/router/index.ts +++ b/src/router/index.ts @@ -55,6 +55,18 @@ const router = createRouter({ name: 'invite-accept', component: () => import('../views/InviteAccept.vue'), }, + { + path: '/agents', + name: 'agents', + component: () => import('../views/AgentsView.vue'), + meta: { requiresAuth: true, requiresManager: true }, + }, + { + path: '/agents/:id', + name: 'agent-detail', + component: () => import('../views/AgentDetailView.vue'), + meta: { requiresAuth: true, requiresManager: true }, + }, ], }) diff --git a/src/stores/agentStore.ts b/src/stores/agentStore.ts new file mode 100644 index 0000000..841f4d3 --- /dev/null +++ b/src/stores/agentStore.ts @@ -0,0 +1,139 @@ +import { defineStore } from 'pinia' +import { ref } from 'vue' +import type { AgentEvent } from '../types/agent' + +export const useAgentStore = defineStore('agent', () => { + const isConnected = ref(false) + const executionStatus = ref<'idle' | 'running' | 'completed' | 'failed'>('idle') + const eventLog = ref([]) + const lastExecutionId = ref(null) + + let socket: WebSocket | null = null + + const pushEvent = (evt: { type: string; message?: string; content?: unknown; timestamp?: string }) => { + eventLog.value.unshift({ + type: evt.type, + message: evt.message, + content: evt.content, + timestamp: evt.timestamp ? new Date(evt.timestamp) : new Date(), + }) + } + + const connect = (agentId: string) => { + if (socket) { + socket.close() + socket = null + } + + const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws' + const wsUrl = `${wsProtocol}://${window.location.host}/ws/mlstore/agents/${agentId}/` + socket = new WebSocket(wsUrl) + + socket.onopen = () => { + isConnected.value = true + pushEvent({ type: 'connected', message: 'WebSocket connected' }) + } + + socket.onmessage = (event) => { + try { + const payload = JSON.parse(event.data) + const type = payload.type + + if (type === 'connection') { + pushEvent({ type: 'connection', message: payload.message, content: payload }) + return + } + + if (type === 'execution_started') { + executionStatus.value = 'running' + lastExecutionId.value = payload.execution_id + pushEvent({ type: 'started', message: payload.message, content: payload }) + return + } + + if (type === 'mlstore_event') { + const eventType = payload.event_type || 'message' + pushEvent({ + type: eventType, + content: payload.content, + timestamp: payload.timestamp, + }) + return + } + + if (type === 'execution_completed') { + executionStatus.value = 'completed' + pushEvent({ type: 'completed', content: payload.output_data, timestamp: payload.timestamp }) + return + } + + if (type === 'execution_error') { + executionStatus.value = 'failed' + pushEvent({ type: 'error', message: payload.error_message, content: payload }) + return + } + + pushEvent({ type: 'message', content: payload }) + } catch { + pushEvent({ type: 'error', message: 'Invalid message received', content: event.data }) + } + } + + socket.onerror = () => { + pushEvent({ type: 'error', message: 'WebSocket error' }) + } + + socket.onclose = () => { + isConnected.value = false + pushEvent({ type: 'disconnected', message: 'WebSocket disconnected' }) + } + } + + const disconnect = () => { + if (socket) { + socket.close() + socket = null + } + isConnected.value = false + } + + const startAgent = (data: { query?: string; prompt?: string; options?: Record }) => { + if (!socket || socket.readyState !== WebSocket.OPEN) return + const prompt = data.query ?? data.prompt ?? '' + socket.send( + JSON.stringify({ + action: 'infer', + input_data: { + prompt, + options: data.options ?? {}, + }, + }), + ) + } + + const stopAgent = (executionId?: string) => { + if (!socket || socket.readyState !== WebSocket.OPEN) return + socket.send( + JSON.stringify({ + action: 'stop_agent', + execution_id: executionId ?? lastExecutionId.value, + }), + ) + } + + const resetLog = () => { + eventLog.value = [] + } + + return { + isConnected, + executionStatus, + eventLog, + connect, + disconnect, + startAgent, + stopAgent, + resetLog, + lastExecutionId, + } +}) diff --git a/src/types/agent.ts b/src/types/agent.ts new file mode 100644 index 0000000..a626567 --- /dev/null +++ b/src/types/agent.ts @@ -0,0 +1,6 @@ +export type AgentEvent = { + type: string + timestamp: Date + message?: string + content?: unknown +} diff --git a/src/views/AgentDetailView.vue b/src/views/AgentDetailView.vue new file mode 100644 index 0000000..e407418 --- /dev/null +++ b/src/views/AgentDetailView.vue @@ -0,0 +1,303 @@ + + + + + diff --git a/src/views/AgentsView.vue b/src/views/AgentsView.vue new file mode 100644 index 0000000..b691890 --- /dev/null +++ b/src/views/AgentsView.vue @@ -0,0 +1,92 @@ + + + + +