Added code for fine tuning
This commit is contained in:
parent
55b4300fdd
commit
bf58336f58
2 changed files with 794 additions and 44 deletions
|
|
@ -1,12 +1,18 @@
|
|||
import httpx
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self, server_url: str):
|
||||
def __init__(self, server_url: str, timeout: int = 3600):
|
||||
self.server_url = server_url
|
||||
self.client = httpx.AsyncClient(timeout=60)
|
||||
self.client = httpx.AsyncClient(timeout=timeout)
|
||||
logger.info(f"MCPClient initialized for {server_url} with timeout={timeout}s ({timeout//60} minutes)")
|
||||
|
||||
async def send(self, tool: str, arguments: dict):
|
||||
logger.info(f"MCPClient: Sending request to {self.server_url}/execute for tool '{tool}'")
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f"{self.server_url}/execute",
|
||||
json={
|
||||
|
|
@ -14,8 +20,47 @@ class MCPClient:
|
|||
"arguments": arguments,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
logger.info(f"MCPClient: Received response with status={response.status_code}")
|
||||
logger.debug(f"MCPClient: Response headers: {response.headers}")
|
||||
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.error(f"MCPClient: Request timeout for tool '{tool}': {str(e)}")
|
||||
raise Exception(f"MCP tool '{tool}' request timed out (>3600s / 1 hour). Model loading or fine-tuning may be too slow.")
|
||||
except Exception as e:
|
||||
logger.error(f"MCPClient: Request failed for tool '{tool}': {str(e)}", exc_info=True)
|
||||
raise Exception(f"MCP tool '{tool}' request failed: {str(e)}")
|
||||
|
||||
if response.status_code >= 400:
|
||||
error_data = {}
|
||||
try:
|
||||
error_data = response.json()
|
||||
logger.error(f"MCPClient: HTTP error {response.status_code}: {error_data}")
|
||||
except:
|
||||
logger.error(f"MCPClient: HTTP error {response.status_code} (could not parse JSON)")
|
||||
pass
|
||||
|
||||
error_msg = error_data.get("error") or error_data.get("details") or f"HTTP {response.status_code}"
|
||||
raise Exception(f"MCP tool '{tool}' failed: {error_msg}. Full response: {error_data}")
|
||||
|
||||
try:
|
||||
result = response.json()
|
||||
logger.debug(f"MCPClient: Parsed JSON response: status={result.get('status')}")
|
||||
except Exception as e:
|
||||
logger.error(f"MCPClient: Failed to parse response JSON: {str(e)}")
|
||||
logger.error(f"MCPClient: Raw response text: {response.text[:500]}")
|
||||
raise Exception(f"MCP tool '{tool}' returned invalid JSON: {str(e)}")
|
||||
|
||||
if isinstance(result, dict) and result.get("status") == "failed":
|
||||
error_msg = result.get("error") or result.get("details") or "Unknown error"
|
||||
traceback_info = result.get("traceback", "")
|
||||
full_error = f"MCP tool '{tool}' returned failure: {error_msg}"
|
||||
if traceback_info:
|
||||
full_error += f"\n\nServer traceback:\n{traceback_info}"
|
||||
logger.error(f"MCPClient: {full_error}")
|
||||
raise Exception(full_error)
|
||||
|
||||
logger.info(f"MCPClient: Tool '{tool}' completed successfully")
|
||||
return result
|
||||
|
||||
async def health(self):
|
||||
response = await self.client.get(f"{self.server_url}/health")
|
||||
|
|
@ -37,6 +82,5 @@ async def main():
|
|||
print(result)
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -2,21 +2,47 @@ import asyncio
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import PureWindowsPath
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from aiohttp import web
|
||||
from mcp.server import Server
|
||||
from mcp.types import Tool, TextContent
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stderr),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
model_cache_dir = os.path.join(project_root, "model", "base-model")
|
||||
os.makedirs(model_cache_dir, exist_ok=True)
|
||||
os.environ["HF_HOME"] = model_cache_dir
|
||||
logger.info(f"Project root: {project_root}")
|
||||
logger.info(f"HuggingFace model cache directory set to: {model_cache_dir}")
|
||||
|
||||
app = Server("mlstore-mcp-server")
|
||||
logger.info("MCP Server initialized: mlstore-mcp-server")
|
||||
|
||||
LOADED_MODELS: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
PAIR_EXTRACTOR: Dict[str, Any] = {}
|
||||
|
||||
BASE_MODEL_CACHE_DIR = model_cache_dir
|
||||
|
||||
|
||||
|
||||
@app.list_tools()
|
||||
async def list_tools():
|
||||
return [
|
||||
logger.info("Listing available tools")
|
||||
tools = [
|
||||
Tool(
|
||||
name="echo",
|
||||
description="Echo back the provided input",
|
||||
|
|
@ -27,8 +53,7 @@ async def list_tools():
|
|||
},
|
||||
"required": ["message"]
|
||||
},
|
||||
)
|
||||
,
|
||||
),
|
||||
Tool(
|
||||
name="fine_tune",
|
||||
description="Start fine-tuning a base model using training files",
|
||||
|
|
@ -69,6 +94,8 @@ async def list_tools():
|
|||
},
|
||||
),
|
||||
]
|
||||
logger.info(f"Available tools: {[t.name for t in tools]}")
|
||||
return tools
|
||||
|
||||
|
||||
def _now() -> str:
|
||||
|
|
@ -93,13 +120,10 @@ def _resolve_model_path(model_path: str) -> str:
|
|||
|
||||
candidates = []
|
||||
|
||||
# Try relative to current working directory
|
||||
candidates.append(os.path.normpath(os.path.join(os.getcwd(), norm)))
|
||||
|
||||
# Try relative to model root
|
||||
candidates.append(os.path.normpath(os.path.join(_model_root(), os.path.basename(norm))))
|
||||
|
||||
# If it's a Windows-style absolute path, map to container /app by trimming common root
|
||||
if ":" in model_path or "\\" in model_path:
|
||||
p = PureWindowsPath(model_path)
|
||||
parts = [str(x) for x in p.parts]
|
||||
|
|
@ -117,7 +141,6 @@ def _resolve_model_path(model_path: str) -> str:
|
|||
|
||||
|
||||
def _resolve_model_file(model_path: str) -> tuple[str, str]:
|
||||
"""Return (model_dir, model_filename) for GPT4All."""
|
||||
resolved = _resolve_model_path(model_path)
|
||||
if os.path.isdir(resolved):
|
||||
for name in os.listdir(resolved):
|
||||
|
|
@ -127,9 +150,603 @@ def _resolve_model_file(model_path: str) -> tuple[str, str]:
|
|||
return os.path.dirname(resolved), os.path.basename(resolved)
|
||||
|
||||
|
||||
def _load_training_file(file_path: str) -> List[Dict[str, str]]:
|
||||
logger.info(f"Loading training file: {file_path}")
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"Training file not found: {file_path}")
|
||||
raise FileNotFoundError(f"Training file not found: {file_path}")
|
||||
|
||||
_, ext = os.path.splitext(file_path)
|
||||
ext = ext.lower()
|
||||
|
||||
try:
|
||||
if ext == '.json':
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
logger.info(f"JSON file contains {len(data)} items")
|
||||
return data
|
||||
elif isinstance(data, dict):
|
||||
logger.info(f"JSON file is dict, extracting first array or values")
|
||||
for key, val in data.items():
|
||||
if isinstance(val, list):
|
||||
return val
|
||||
return [data]
|
||||
return []
|
||||
|
||||
elif ext == '.csv':
|
||||
import csv
|
||||
pairs = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
pairs.append(dict(row))
|
||||
logger.info(f"CSV file contains {len(pairs)} rows")
|
||||
return pairs
|
||||
|
||||
elif ext == '.txt':
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
pairs = [{'text': para.strip()} for para in content.split('\n\n') if para.strip()]
|
||||
logger.info(f"TXT file contains {len(pairs)} paragraphs")
|
||||
return pairs
|
||||
|
||||
elif ext == '.md':
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
pairs = [{'text': para.strip()} for para in content.split('\n\n') if para.strip()]
|
||||
logger.info(f"MD file contains {len(pairs)} sections")
|
||||
return pairs
|
||||
|
||||
elif ext == '.pdf':
|
||||
try:
|
||||
import PyPDF2
|
||||
pairs = []
|
||||
with open(file_path, 'rb') as f:
|
||||
reader = PyPDF2.PdfReader(f)
|
||||
for page_num, page in enumerate(reader.pages):
|
||||
text = page.extract_text()
|
||||
if text.strip():
|
||||
pairs.append({'text': text.strip(), 'page': page_num})
|
||||
logger.info(f"PDF file contains {len(pairs)} pages")
|
||||
return pairs
|
||||
except:
|
||||
pass
|
||||
|
||||
elif ext == '.docx':
|
||||
try:
|
||||
from docx import Document
|
||||
doc = Document(file_path)
|
||||
pairs = [{'text': para.text} for para in doc.paragraphs if para.text.strip()]
|
||||
logger.info(f"DOCX file contains {len(pairs)} paragraphs")
|
||||
return pairs
|
||||
except ImportError:
|
||||
logger.error("python-docx not available for DOCX parsing")
|
||||
raise ImportError("python-docx library not available")
|
||||
|
||||
else:
|
||||
logger.error(f"Unsupported file type: {ext}")
|
||||
raise ValueError(f"Unsupported file type: {ext}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load training file: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def _ensure_pair_extractor_model(base_model: str):
|
||||
if PAIR_EXTRACTOR.get("model") is not None:
|
||||
return PAIR_EXTRACTOR["tokenizer"], PAIR_EXTRACTOR["model"]
|
||||
|
||||
logger.info("Loading base model for pair extraction (prompt-based)")
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
import torch
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
base_model,
|
||||
cache_dir=model_cache_dir,
|
||||
local_files_only=False,
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
cache_dir=model_cache_dir,
|
||||
quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
|
||||
device_map="auto",
|
||||
dtype=torch.float16,
|
||||
offload_dir='./offload_dir',
|
||||
)
|
||||
|
||||
PAIR_EXTRACTOR["tokenizer"] = tokenizer
|
||||
PAIR_EXTRACTOR["model"] = model
|
||||
logger.info("Pair extractor model loaded and cached")
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def _format_training_sample(sample: Any) -> str:
|
||||
try:
|
||||
if isinstance(sample, dict):
|
||||
parts = []
|
||||
for k, v in sample.items():
|
||||
if isinstance(v, str) and v.strip():
|
||||
parts.append(f"{k}: {v.strip()}")
|
||||
return " | ".join(parts) if parts else json.dumps(sample)
|
||||
if isinstance(sample, str):
|
||||
return sample.strip()
|
||||
return str(sample)
|
||||
except Exception:
|
||||
return str(sample)
|
||||
|
||||
|
||||
def _prompt_based_pair_extraction(training_data: List[Any], base_model: str) -> List[Tuple[str, str]]:
|
||||
import torch
|
||||
|
||||
tokenizer, model = _ensure_pair_extractor_model(base_model)
|
||||
|
||||
max_items = 12
|
||||
subset = training_data
|
||||
formatted = [f"{i+1}. {_format_training_sample(item)}" for i, item in enumerate(subset)]
|
||||
data_block = "\n".join(formatted)
|
||||
|
||||
example_pairs = [
|
||||
{"instruction": "Explain what a REST API is.", "response": "A REST API is an interface that uses HTTP methods..."},
|
||||
{"instruction": "Summarize the customer complaint.", "response": "Customer reports delayed shipment and requests refund."},
|
||||
]
|
||||
|
||||
system_prompt = (
|
||||
"You are a data extractor. Given a list of items, return a JSON array of training pairs. "
|
||||
"Each pair must have 'instruction' and 'response'. Keep answers concise. "
|
||||
"If content is incomplete, still produce best-effort pairs."
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
"Examples of desired output:\n"
|
||||
f"{json.dumps(example_pairs, ensure_ascii=False, indent=2)}\n\n"
|
||||
"Now extract training pairs from the following items. Return ONLY a JSON array, no extra text.\n"
|
||||
f"Items:\n{data_block}"
|
||||
)
|
||||
|
||||
prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
max_new_tokens = 512
|
||||
with torch.no_grad():
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
temperature=0.2,
|
||||
top_p=0.9,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
json_start = decoded.find("[")
|
||||
if json_start == -1:
|
||||
logger.error("LLM extraction failed to produce JSON array (no '[' found)")
|
||||
return []
|
||||
|
||||
pairs: List[Tuple[str, str]] = []
|
||||
bracket_count = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
json_end = -1
|
||||
|
||||
for i in range(json_start, len(decoded)):
|
||||
char = decoded[i]
|
||||
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
|
||||
if char == '\\':
|
||||
escape_next = True
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
if not in_string:
|
||||
if char == '[':
|
||||
bracket_count += 1
|
||||
elif char == ']':
|
||||
bracket_count -= 1
|
||||
if bracket_count == 0:
|
||||
json_end = i
|
||||
break
|
||||
|
||||
if json_end == -1:
|
||||
logger.error("LLM extraction failed to find valid JSON array boundary")
|
||||
return []
|
||||
|
||||
try:
|
||||
json_text = decoded[json_start: json_end + 1]
|
||||
logger.debug(f"Extracted JSON text (first 200 chars): {json_text[:200]}")
|
||||
parsed = json.loads(json_text)
|
||||
for item in parsed:
|
||||
instr = str(item.get("instruction", "")).strip()
|
||||
resp = str(item.get("response", "")).strip()
|
||||
if instr and resp:
|
||||
pairs.append((instr, resp))
|
||||
logger.info(f"LLM extracted {len(pairs)} pairs via prompting")
|
||||
return pairs
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse LLM-extracted JSON: {str(e)}")
|
||||
logger.debug(f"JSON text that failed to parse: {json_text if 'json_text' in locals() else 'N/A'}")
|
||||
return []
|
||||
|
||||
|
||||
def _extract_training_pairs(training_data: List[Any]) -> List[Tuple[str, str]]:
|
||||
logger.info(f"Extracting training pairs via LLM for {len(training_data)} items")
|
||||
|
||||
if not training_data:
|
||||
return []
|
||||
|
||||
base_model = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
|
||||
pairs = _prompt_based_pair_extraction(training_data, base_model)
|
||||
if not pairs:
|
||||
logger.warning("LLM extraction failed; falling back to minimal heuristic")
|
||||
for item in training_data:
|
||||
text = _format_training_sample(item)
|
||||
if text:
|
||||
mid = max(1, len(text) // 2)
|
||||
pairs.append((text[:mid].strip(), text[mid:].strip() or text[:50]))
|
||||
|
||||
logger.info(f"Total pairs extracted: {len(pairs)}")
|
||||
if pairs:
|
||||
logger.debug(f"Sample pair: instruction='{pairs[0][0][:80]}...', response='{pairs[0][1][:80]}...'")
|
||||
return pairs
|
||||
|
||||
|
||||
async def _fine_tune_model_impl(
|
||||
training_files: List[str],
|
||||
hyperparams: Dict[str, Any],
|
||||
model_name: str,
|
||||
version: str,
|
||||
output_dir: str
|
||||
) -> Dict[str, Any]:
|
||||
base_model = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
logger.info(f"Starting fine-tune process with base model: {base_model}")
|
||||
|
||||
try:
|
||||
logger.info(f"Step 1: Loading {len(training_files)} training files")
|
||||
all_training_pairs = []
|
||||
|
||||
for file_path in training_files:
|
||||
try:
|
||||
training_data = _load_training_file(file_path)
|
||||
pairs = _extract_training_pairs(training_data)
|
||||
all_training_pairs.extend(pairs)
|
||||
logger.info(f"File {os.path.basename(file_path)}: {len(pairs)} pairs extracted")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process file {file_path}: {str(e)}")
|
||||
continue
|
||||
|
||||
if not all_training_pairs:
|
||||
logger.error("No training pairs extracted from any files")
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "no_training_pairs_extracted",
|
||||
"timestamp": _now(),
|
||||
}
|
||||
|
||||
logger.info(f"Step 1 Complete: Total {len(all_training_pairs)} training pairs extracted")
|
||||
logger.info(f"Step 2: Loading base model and tokenizer: {base_model}")
|
||||
try:
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
from huggingface_hub import login, HfApi
|
||||
import torch
|
||||
|
||||
logger.info("Step 2a: Authenticating with HuggingFace...")
|
||||
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
|
||||
if hf_token:
|
||||
logger.info("HuggingFace token found in environment, logging in...")
|
||||
try:
|
||||
login(token=hf_token, write_permission=False)
|
||||
logger.info("Successfully authenticated with HuggingFace")
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Failed to authenticate with HuggingFace: {str(e)}")
|
||||
logger.error(f"Traceback:\n{error_details}")
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "huggingface_auth_failed",
|
||||
"details": str(e),
|
||||
"traceback": error_details,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
else:
|
||||
logger.warning("HF_TOKEN or HUGGINGFACE_TOKEN environment variable not found. Model access may be restricted.")
|
||||
|
||||
logger.info("Step 2b: Loading tokenizer...")
|
||||
try:
|
||||
logger.info(f"Tokenizer cache directory: {model_cache_dir}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
base_model,
|
||||
cache_dir=model_cache_dir,
|
||||
local_files_only=False
|
||||
)
|
||||
logger.info("Tokenizer download started...")
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Tokenizer loaded successfully")
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Failed to load tokenizer: {str(e)}")
|
||||
logger.error(f"Traceback:\n{error_details}")
|
||||
raise
|
||||
|
||||
logger.info("Step 2b: Loading base model with 4-bit quantization...")
|
||||
try:
|
||||
logger.info(f"Model cache directory: {model_cache_dir}")
|
||||
logger.info(f"Starting model download for {base_model}...")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
cache_dir=model_cache_dir,
|
||||
quantization_config=BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.float16
|
||||
),
|
||||
device_map="auto",
|
||||
dtype=torch.float16,
|
||||
)
|
||||
logger.info("Base model loaded successfully with 4-bit quantization")
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Failed to load base model: {str(e)}")
|
||||
logger.error(f"Traceback:\n{error_details}")
|
||||
raise
|
||||
|
||||
except ImportError as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Required HuggingFace transformers not available: {str(e)}")
|
||||
logger.error(f"Traceback:\n{error_details}")
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "transformers_not_available",
|
||||
"details": str(e),
|
||||
"traceback": error_details,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
|
||||
logger.info(f"Step 2 Complete: Model and tokenizer loaded")
|
||||
logger.info(f"Step 3: Fine-tuning with LoRA adapters")
|
||||
try:
|
||||
from peft import LoraConfig
|
||||
from trl import SFTTrainer
|
||||
from transformers import TrainingArguments
|
||||
import json
|
||||
|
||||
training_data_file = os.path.join(output_dir, "training_data.jsonl")
|
||||
with open(training_data_file, "w", encoding="utf-8") as f:
|
||||
for prompt, response in all_training_pairs:
|
||||
training_pair = {
|
||||
"instruction": prompt,
|
||||
"input": "",
|
||||
"output": response,
|
||||
"text": f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n{response}<|im_end|>"
|
||||
}
|
||||
f.write(json.dumps(training_pair, ensure_ascii=False) + "\n")
|
||||
logger.info(f"Training data file created: {training_data_file}")
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("json", data_files=training_data_file)
|
||||
logger.info(f"Dataset loaded: {len(dataset['train'])} examples")
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=int(hyperparams.get("lora_r", 64)),
|
||||
lora_alpha=int(hyperparams.get("lora_alpha", 16)),
|
||||
lora_dropout=float(hyperparams.get("lora_dropout", 0.05)),
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=os.path.join(output_dir, "checkpoints"),
|
||||
num_train_epochs=int(hyperparams.get("epochs", 3)),
|
||||
per_device_train_batch_size=int(hyperparams.get("batch_size", 6)),
|
||||
gradient_accumulation_steps=int(hyperparams.get("gradient_accumulation_steps", 3)),
|
||||
fp16=False,
|
||||
bf16=False,
|
||||
optim="paged_adamw_8bit",
|
||||
max_grad_norm=0.0,
|
||||
logging_steps=10,
|
||||
save_strategy="epoch",
|
||||
ddp_find_unused_parameters=False,
|
||||
remove_unused_columns=False,
|
||||
dataloader_pin_memory=True,
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
train_dataset=dataset["train"],
|
||||
peft_config=lora_config,
|
||||
args=training_args,
|
||||
)
|
||||
trainer.accelerator.scaler = None
|
||||
|
||||
logger.info("Starting LoRA fine-tuning...")
|
||||
trainer.train()
|
||||
logger.info("LoRA fine-tuning completed")
|
||||
adapter_dir = os.path.join(output_dir, "adapter")
|
||||
os.makedirs(adapter_dir, exist_ok=True)
|
||||
trainer.model.save_pretrained(adapter_dir)
|
||||
tokenizer.save_pretrained(adapter_dir)
|
||||
logger.info(f"LoRA adapter saved to {adapter_dir}")
|
||||
merge_dir = os.path.join(output_dir, "merged")
|
||||
os.makedirs(merge_dir, exist_ok=True)
|
||||
logger.info(f"Merging LoRA adapters...")
|
||||
|
||||
from peft import PeftModel
|
||||
base_reload = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
cache_dir=model_cache_dir,
|
||||
dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
merged_model = PeftModel.from_pretrained(base_reload, adapter_dir)
|
||||
merged_model = merged_model.merge_and_unload()
|
||||
|
||||
merged_model.save_pretrained(merge_dir)
|
||||
tokenizer.save_pretrained(merge_dir)
|
||||
logger.info(f"Merged model saved to {merge_dir}")
|
||||
logger.info("Cleaning up GPU memory...")
|
||||
del model
|
||||
del merged_model
|
||||
del base_reload
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
logger.info(f"GPU memory freed: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB reserved")
|
||||
logger.info("GPU memory cleanup completed")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Required LoRA/SFT training libraries not available: {str(e)}")
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "training_libs_not_available",
|
||||
"details": str(e),
|
||||
"timestamp": _now(),
|
||||
}
|
||||
|
||||
logger.info(f"Step 3 Complete: Fine-tuning and merging completed")
|
||||
logger.info(f"Step 4: Converting merged model to GGUF format")
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
merge_dir = os.path.join(output_dir, "merged")
|
||||
gguf_f16_path = os.path.join(output_dir, "model-f16.gguf")
|
||||
convert_script = None
|
||||
for candidate in [
|
||||
os.path.join(os.getcwd(), "llama.cpp", "convert_hf_to_gguf.py"),
|
||||
os.path.join(os.getcwd(), "notebooks", "build", "llama.cpp", "convert_hf_to_gguf.py"),
|
||||
"/app/llama.cpp/convert_hf_to_gguf.py",
|
||||
"/home/llama.cpp/convert_hf_to_gguf.py",
|
||||
]:
|
||||
if os.path.exists(candidate):
|
||||
convert_script = candidate
|
||||
break
|
||||
|
||||
if not convert_script:
|
||||
logger.warning("convert_hf_to_gguf.py not found, trying with python -m")
|
||||
convert_script = None
|
||||
|
||||
if convert_script:
|
||||
logger.info(f"Converting with script: {convert_script}")
|
||||
result = subprocess.run([
|
||||
"python", convert_script,
|
||||
merge_dir,
|
||||
"--outfile", gguf_f16_path,
|
||||
"--outtype", "f16"
|
||||
], capture_output=True, text=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Conversion script failed with return code {result.returncode}")
|
||||
logger.error(f"STDOUT: {result.stdout}")
|
||||
logger.error(f"STDERR: {result.stderr}")
|
||||
raise RuntimeError(f"GGUF conversion failed: {result.stderr or 'unknown error'}")
|
||||
else:
|
||||
logger.info(f"GGUF conversion completed: {gguf_f16_path}")
|
||||
else:
|
||||
logger.info("Attempting direct conversion with transformers")
|
||||
from transformers import AutoModel
|
||||
model = AutoModelForCausalLM.from_pretrained(merge_dir, torch_dtype=torch.float16)
|
||||
logger.warning("Direct GGUF conversion not available, using float16 checkpoint instead")
|
||||
gguf_f16_path = merge_dir
|
||||
|
||||
logger.info(f"GGUF conversion completed: {gguf_f16_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GGUF conversion failed: {str(e)}", exc_info=True)
|
||||
gguf_f16_path = os.path.join(output_dir, "merged")
|
||||
logger.warning("Using merged model directly without GGUF conversion")
|
||||
logger.info(f"Step 4 Complete: GGUF conversion completed")
|
||||
logger.info(f"Step 5: Verifying model format")
|
||||
quantized_path = gguf_f16_path
|
||||
|
||||
if os.path.isfile(gguf_f16_path):
|
||||
logger.info(f"GGUF file verified: {gguf_f16_path}")
|
||||
logger.info(f"GGUF model size: {os.path.getsize(gguf_f16_path) / (1024**3):.2f} GB")
|
||||
quantized_path = gguf_f16_path
|
||||
else:
|
||||
logger.warning(f"No GGUF file found; using merged model folder")
|
||||
quantized_path = gguf_f16_path
|
||||
logger.info(f"Step 6: Creating metadata and finalizing")
|
||||
quantized_file = quantized_path
|
||||
if os.path.isdir(quantized_file):
|
||||
logger.warning(f"Quantized path is a directory ({quantized_file}); searching for .gguf inside")
|
||||
ggufs = [f for f in os.listdir(quantized_file) if f.lower().endswith('.gguf')]
|
||||
if ggufs:
|
||||
quantized_file = os.path.join(quantized_file, ggufs[0])
|
||||
logger.info(f"Found GGUF inside directory: {quantized_file}")
|
||||
else:
|
||||
logger.error("No .gguf file found inside quantized directory; using merged folder as fallback")
|
||||
quantized_file = quantized_path
|
||||
final_model_path = os.path.join(_model_root(), f"{model_name}-{version}.gguf")
|
||||
if os.path.isfile(quantized_file) and quantized_file != final_model_path:
|
||||
import shutil
|
||||
logger.info(f"Copying quantized model to final location: {final_model_path}")
|
||||
shutil.copy2(quantized_file, final_model_path)
|
||||
logger.info(f"Final model saved to: {final_model_path}")
|
||||
else:
|
||||
final_model_path = quantized_file
|
||||
|
||||
metadata = {
|
||||
"status": "completed",
|
||||
"base_model": base_model,
|
||||
"name": model_name,
|
||||
"version": version,
|
||||
"model_path": final_model_path,
|
||||
"path": final_model_path,
|
||||
"output_dir": output_dir,
|
||||
"training_files_count": len(training_files),
|
||||
"training_pairs_count": len(all_training_pairs),
|
||||
"hyperparams": hyperparams,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
|
||||
try:
|
||||
with open(os.path.join(output_dir, "metadata.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
logger.info(f"Metadata saved to {output_dir}/metadata.json")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save metadata: {str(e)}", exc_info=True)
|
||||
logger.info("Performing final GPU memory cleanup...")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
logger.info(f"Final GPU memory state: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB allocated")
|
||||
|
||||
logger.info(f"Step 6 Complete: All steps completed successfully")
|
||||
logger.info(f"Final model available at: {final_model_path}")
|
||||
return metadata
|
||||
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Fine-tune process failed: {str(e)}")
|
||||
logger.error(f"Error type: {type(e).__name__}")
|
||||
logger.error(f"Full traceback:\n{error_details}")
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e) or "Unknown error occurred",
|
||||
"error_type": type(e).__name__,
|
||||
"traceback": error_details,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
|
||||
|
||||
async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
||||
logger.info(f"Executing tool: {name}")
|
||||
logger.debug(f"Tool arguments: {arguments}")
|
||||
|
||||
if name == "echo":
|
||||
return {"status": "ok", "received": arguments, "timestamp": _now()}
|
||||
logger.info(f"Echo tool called with message: {arguments.get('message')}")
|
||||
result = {"status": "ok", "received": arguments, "timestamp": _now()}
|
||||
logger.info(f"Echo tool completed successfully")
|
||||
return result
|
||||
|
||||
if name == "fine_tune":
|
||||
base_model = arguments.get("base_model")
|
||||
|
|
@ -138,47 +755,59 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
|||
model_name = arguments.get("name") or "model"
|
||||
version = arguments.get("version") or "v1"
|
||||
|
||||
logger.info(f"Fine-tune started: model={model_name}, version={version}")
|
||||
logger.info(f"Training files count: {len(training_files)}")
|
||||
logger.debug(f"Training files: {training_files}")
|
||||
logger.debug(f"Hyperparameters: {hyperparams}")
|
||||
|
||||
model_root = _model_root()
|
||||
os.makedirs(model_root, exist_ok=True)
|
||||
logger.debug(f"Model root directory: {model_root}")
|
||||
|
||||
safe_name = _safe_dir_name(model_name)
|
||||
safe_version = _safe_dir_name(version)
|
||||
output_dir = os.path.join(model_root, f"{safe_name}-{safe_version}")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
metadata = {
|
||||
"status": "completed",
|
||||
"base_model": base_model,
|
||||
"training_files": training_files,
|
||||
"hyperparams": hyperparams,
|
||||
"name": model_name,
|
||||
"version": version,
|
||||
"model_path": output_dir,
|
||||
logger.info(f"Output directory created: {output_dir}")
|
||||
try:
|
||||
result = await _fine_tune_model_impl(training_files, hyperparams, model_name, version, output_dir)
|
||||
logger.info(f"Fine-tune result: {result.get('status')}")
|
||||
return result
|
||||
except Exception as e:
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Fine-tune tool execution failed: {str(e)}")
|
||||
logger.error(f"Full traceback:\n{error_details}")
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e) or "Unknown error in fine-tune execution",
|
||||
"error_type": type(e).__name__,
|
||||
"traceback": error_details,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
try:
|
||||
with open(os.path.join(output_dir, "metadata.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return metadata
|
||||
|
||||
if name == "load_model":
|
||||
model_path = arguments.get("model_path")
|
||||
logger.info(f"Loading model: {model_path}")
|
||||
|
||||
if not model_path:
|
||||
logger.error("model_path_required error: no model path provided")
|
||||
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
|
||||
|
||||
model_path = _resolve_model_path(model_path)
|
||||
logger.debug(f"Resolved model path: {model_path}")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model not found at: {model_path}")
|
||||
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()}
|
||||
|
||||
try:
|
||||
from gpt4all import GPT4All
|
||||
|
||||
model_dir, model_file = _resolve_model_file(model_path)
|
||||
logger.debug(f"Model directory: {model_dir}, model file: {model_file}")
|
||||
|
||||
if not model_file:
|
||||
logger.error(f"No GGUF file found in model directory: {model_dir}")
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "model_file_not_found",
|
||||
|
|
@ -186,13 +815,26 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
|||
"timestamp": _now(),
|
||||
}
|
||||
|
||||
model = GPT4All(model_file, model_path=model_dir, allow_download=False, device='gpu')
|
||||
logger.info(f"Initializing GPT4All model: {model_file}")
|
||||
try:
|
||||
logger.info("Attempting to load model on GPU (cuda)...")
|
||||
model = GPT4All(model_file, model_path=model_dir, allow_download=False, device='cuda')
|
||||
logger.info(f"Model loaded successfully on GPU")
|
||||
except Exception as e:
|
||||
logger.warning(f"GPU initialization failed: {str(e)}, falling back to CPU")
|
||||
logger.info("Loading model on CPU...")
|
||||
model = GPT4All(model_file, model_path=model_dir, allow_download=False, device='cpu')
|
||||
logger.info(f"Model loaded successfully on CPU")
|
||||
|
||||
LOADED_MODELS[model_path] = {
|
||||
"loaded_at": _now(),
|
||||
"model": model,
|
||||
"model_dir": model_dir,
|
||||
"model_file": model_file,
|
||||
}
|
||||
logger.info(f"Model loaded successfully: {model_path}")
|
||||
logger.info(f"Total loaded models in memory: {len(LOADED_MODELS)}")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"model_path": model_path,
|
||||
|
|
@ -202,6 +844,7 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
|||
"timestamp": _now(),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
|
|
@ -215,33 +858,57 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
|||
prompt = arguments.get("prompt") or ""
|
||||
options = arguments.get("options") or {}
|
||||
|
||||
logger.info(f"Inference request: model={model_path}")
|
||||
logger.debug(f"Prompt length: {len(prompt)} characters")
|
||||
logger.debug(f"Inference options: {options}")
|
||||
|
||||
if not model_path:
|
||||
logger.error("model_path_required error: no model path provided")
|
||||
return {"status": "failed", "error": "model_path_required", "timestamp": _now()}
|
||||
|
||||
model_path = _resolve_model_path(model_path)
|
||||
logger.debug(f"Resolved model path: {model_path}")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"Model not found at: {model_path}")
|
||||
return {"status": "failed", "error": "model_not_found", "model_path": model_path, "timestamp": _now()}
|
||||
|
||||
try:
|
||||
if model_path not in LOADED_MODELS or "model" not in LOADED_MODELS[model_path]:
|
||||
logger.info(f"Model not in memory, loading: {model_path}")
|
||||
from gpt4all import GPT4All
|
||||
|
||||
model_dir, model_file = _resolve_model_file(model_path)
|
||||
logger.debug(f"Model directory: {model_dir}, model file: {model_file}")
|
||||
|
||||
if not model_file:
|
||||
logger.error(f"No GGUF file found in model directory: {model_dir}")
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": "model_file_not_found",
|
||||
"model_path": model_path,
|
||||
"timestamp": _now(),
|
||||
}
|
||||
model = GPT4All(model_file, model_path=model_dir, allow_download=False)
|
||||
logger.info(f"Initializing GPT4All model: {model_file}")
|
||||
try:
|
||||
logger.info("Attempting to load model on GPU (cuda)...")
|
||||
model = GPT4All(model_file, model_path=model_dir, allow_download=False, device='cuda')
|
||||
logger.info(f"Model loaded successfully on GPU for inference")
|
||||
except Exception as e:
|
||||
logger.warning(f"GPU initialization failed: {str(e)}, falling back to CPU")
|
||||
logger.info("Loading model on CPU...")
|
||||
model = GPT4All(model_file, model_path=model_dir, allow_download=False, device='cpu')
|
||||
logger.info(f"Model loaded successfully on CPU for inference")
|
||||
|
||||
LOADED_MODELS[model_path] = {
|
||||
"loaded_at": _now(),
|
||||
"model": model,
|
||||
"model_dir": model_dir,
|
||||
"model_file": model_file,
|
||||
}
|
||||
logger.info(f"Model loaded for inference")
|
||||
else:
|
||||
logger.debug(f"Using cached model: {model_path}")
|
||||
|
||||
model = LOADED_MODELS[model_path]["model"]
|
||||
max_tokens = int(options.get("max_tokens", 256))
|
||||
|
|
@ -249,6 +916,9 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
|||
top_p = float(options.get("top_p", 0.95))
|
||||
top_k = int(options.get("top_k", 40))
|
||||
|
||||
logger.info(f"Running inference with max_tokens={max_tokens}, temperature={temp}, top_p={top_p}, top_k={top_k}")
|
||||
logger.debug(f"Full inference parameters: max_tokens={max_tokens}, temp={temp}, top_p={top_p}, top_k={top_k}")
|
||||
|
||||
response_text = model.generate(
|
||||
prompt,
|
||||
max_tokens=max_tokens,
|
||||
|
|
@ -257,6 +927,9 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
|||
top_k=top_k,
|
||||
)
|
||||
|
||||
logger.info(f"Inference completed successfully. Response length: {len(response_text)} characters")
|
||||
logger.debug(f"Response preview: {response_text[:100]}..." if len(response_text) > 100 else f"Response: {response_text}")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"model_path": model_path,
|
||||
|
|
@ -270,6 +943,7 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
|||
"timestamp": _now(),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Inference failed: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
|
|
@ -283,37 +957,61 @@ async def _run_tool_http(name: str, arguments: dict) -> Dict[str, Any]:
|
|||
|
||||
@app.call_tool()
|
||||
async def call_tool(name: str, arguments: dict):
|
||||
logger.info(f"MCP call_tool: {name}")
|
||||
result = await _run_tool_http(name, arguments)
|
||||
logger.debug(f"MCP call_tool result for {name}: {result}")
|
||||
return [TextContent(type="text", text=json.dumps(result, indent=2))]
|
||||
|
||||
|
||||
async def handle_execute(request: web.Request) -> web.Response:
|
||||
logger.info("HTTP /execute request received")
|
||||
execution_id = None
|
||||
try:
|
||||
payload = await request.json()
|
||||
tool = payload.get("tool")
|
||||
arguments = payload.get("arguments", {})
|
||||
execution_id = arguments.get("execution_id") or arguments.get("name", "unknown")
|
||||
logger.info(f"HTTP execute: tool={tool}, execution_id={execution_id}")
|
||||
logger.debug(f"HTTP execute arguments: {arguments}")
|
||||
|
||||
if not tool:
|
||||
logger.error("Missing 'tool' field in request")
|
||||
return web.json_response(
|
||||
{"error": "Missing 'tool' field"}, status=400
|
||||
)
|
||||
|
||||
logger.info(f"Calling _run_tool_http for {tool}...")
|
||||
result = await _run_tool_http(tool, arguments)
|
||||
logger.info(f"HTTP execute completed for {tool} with status={result.get('status')}")
|
||||
logger.debug(f"HTTP execute result: {result}")
|
||||
return web.json_response(result)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response({"error": "Invalid JSON"}, status=400)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Invalid JSON in request: {str(e)}")
|
||||
return web.json_response({"error": f"Invalid JSON: {str(e)}"}, status=400)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
error_msg = str(e) if str(e) else f"Unknown error: {type(e).__name__}"
|
||||
error_traceback = traceback.format_exc()
|
||||
logger.error(f"Unexpected error in /execute (execution_id={execution_id}): {error_msg}")
|
||||
logger.error(f"Traceback:\n{error_traceback}")
|
||||
return web.json_response({
|
||||
"status": "failed",
|
||||
"error": error_msg,
|
||||
"error_type": type(e).__name__,
|
||||
"traceback": error_traceback,
|
||||
"execution_id": execution_id,
|
||||
}, status=500)
|
||||
|
||||
|
||||
async def handle_health(request: web.Request) -> web.Response:
|
||||
logger.debug("Health check requested")
|
||||
return web.json_response({"status": "healthy"})
|
||||
|
||||
|
||||
async def run_http_server():
|
||||
host = os.getenv("MCP_HTTP_HOST", "0.0.0.0")
|
||||
port = int(os.getenv("MCP_HTTP_PORT", "8001"))
|
||||
logger.info(f"Starting HTTP server on {host}:{port}")
|
||||
|
||||
app_http = web.Application()
|
||||
app_http.router.add_post("/execute", handle_execute)
|
||||
|
|
@ -324,9 +1022,17 @@ async def run_http_server():
|
|||
site = web.TCPSite(runner, host, port)
|
||||
await site.start()
|
||||
|
||||
logger.info(f"HTTP server running on {host}:{port}")
|
||||
print(f"HTTP server running on {host}:{port}", file=sys.stderr)
|
||||
await asyncio.Event().wait()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting MCP Server...")
|
||||
try:
|
||||
asyncio.run(run_http_server())
|
||||
except KeyboardInterrupt:
|
||||
logger.info("MCP Server interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"MCP Server error: {str(e)}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
|
|
|||
Loading…
Reference in a new issue