Submitting Jobs with SLURM

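At this stage, you should already have your input data (for example, a Parquet file with ICD codes and disease names), a working Python environment, and an LLM server with a downloaded model available on the cluster.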

Now, you need to run your analysis efficiently on the VSC cluster using SLURM.
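In this section, you will:

  • Write the Python script that queries the LLM for each disease
  • Prepare the SLURM script that starts the LLM server and runs the Python script
  • Submit and monitor your job
  • Inspect an example of the output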

1. Write the Python Script

You need a Python script that will:

  • Read ICD codes and disease names from your input data (e.g., a Parquet file)
  • Prompt the LLM for each disease to generate symptom lists
  • Save the results in JSON and Parquet formats for easy downstream analysis

Use the template below as a starting point (customize paths, model name, and prompt as needed):

import os
import pandas as pd
import requests
import time
import logging
from fastparquet import write as fp_write

# ----------- CONFIGURATION -----------

# set the model name you want to use
MODEL_NAME = "your-llm-model-name"

# update these paths for your environment
PARQUET_PATH = "/path/to/your/input_data.parquet"
RESULTS_DIR = "/path/to/your/results"
# The log file below lives in RESULTS_DIR, so the directory must exist
# before logging.basicConfig() opens it.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Model names may contain a path separator (e.g. "namespace/model"), which
# would otherwise turn the file names below into paths inside a
# non-existent sub-directory.
_MODEL_TAG = MODEL_NAME.replace(os.sep, "_")

# output files
JSON_OUTPUT = os.path.join(RESULTS_DIR, f"symptom_results_{_MODEL_TAG}.json")
PARQUET_OUTPUT = os.path.join(RESULTS_DIR, f"symptom_results_{_MODEL_TAG}.parquet")
LOG_FILE = os.path.join(RESULTS_DIR, f"extract_symptoms_{_MODEL_TAG}.log")

# LLM API endpoint
LLM_API_URL = "http://localhost:37712/api/generate"

# list of models and their settings
LLM_MODELS = [MODEL_NAME]
# Keyed by MODEL_NAME (not a repeated string literal) so that changing the
# model name above cannot leave the settings lookup without a matching entry.
LLM_MODEL_SETTINGS = {
    MODEL_NAME: {
        "temperature": 0.5,
        "max_tokens": 200,
        "top_k": 3,
        "top_n": 1,
    }
}

logging.basicConfig(
    filename=LOG_FILE,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# ----------- Environment Check -----------
def check_environment():
    """Check the input file and make sure the results directory is present.

    Raises:
        FileNotFoundError: if nothing exists at PARQUET_PATH.
    """
    input_present = os.path.exists(PARQUET_PATH)
    if not input_present:
        raise FileNotFoundError(f"Input Parquet file not found at {PARQUET_PATH}")

    # Nothing more to do when the results directory is already there.
    if os.path.exists(RESULTS_DIR):
        return

    os.makedirs(RESULTS_DIR)
    logging.info(f"Created results directory at {RESULTS_DIR}")

# ----------- Data Loading -----------
def load_data(parquet_path):
    """Load the distinct (icd_code, long_title) pairs from a Parquet file.

    Rows where either column is NULL are excluded.

    Args:
        parquet_path: Path of the Parquet file to scan.

    Returns:
        A pandas DataFrame with the columns ``icd_code`` and ``long_title``.
    """
    import duckdb

    # The path is embedded in a SQL string literal; doubling single quotes
    # keeps a path that contains "'" from terminating the literal early.
    quoted_path = str(parquet_path).replace("'", "''")
    query = """
        SELECT DISTINCT icd_code, long_title
        FROM parquet_scan('{}')
        WHERE icd_code IS NOT NULL AND long_title IS NOT NULL
    """.format(quoted_path)

    conn = duckdb.connect()
    try:
        return conn.execute(query).fetch_df()
    finally:
        # Close the connection even when the query raises.
        conn.close()

# ----------- Prompt Creation -----------
def create_prompt(icd_code, description):
    """Build the symptom-extraction prompt for one ICD code.

    Args:
        icd_code: The ICD code to mention in the prompt.
        description: The disease description, inserted between double quotes.

    Returns:
        The prompt as a single string.
    """
    context = f'Given the ICD-9CM code {icd_code} with description "{description}", '
    task = "identify the 10 most common symptoms typically associated with this condition. "
    answer_format = (
        "Return the symptoms as a comma-separated list, "
        "and only include the symptom names."
    )
    return "".join((context, task, answer_format))

# ----------- LLM Query Function -----------
def query_llm(prompt, model_name, temperature=0.5, max_tokens=150, top_k=3, top_n=1):
    """Send one prompt to the LLM server and return the generated text.

    Args:
        prompt: The prompt text to send.
        model_name: Name of the model the server should use.
        temperature: Sampling temperature.
        max_tokens: Upper bound on the number of generated tokens.
        top_k: Top-k sampling cut-off.
        top_n: Sent as a top-level field only (see the note in the payload).

    Returns:
        The stripped ``response`` field of the server's JSON reply, or an
        empty string when the request or the decoding fails (the error is
        logged).
    """
    payload = {
        "model": model_name,
        "prompt": prompt,
        # Top-level sampling fields are kept exactly as before for any
        # server that reads them from here.
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_k": top_k,
        "top_n": top_n,
        # The endpoint shape (/api/generate, "stream", "response") matches
        # Ollama's generate API, which takes sampling parameters from a
        # nested "options" object and calls the token limit "num_predict".
        # It has no "top_n" option, so that one is not repeated here.
        "options": {
            "temperature": temperature,
            "num_predict": max_tokens,
            "top_k": top_k,
        },
        "stream": False
    }
    try:
        response = requests.post(LLM_API_URL, json=payload, timeout=60)
        response.raise_for_status()
        # "or ''" also covers a reply whose "response" field is null.
        return (response.json().get("response") or "").strip()
    except Exception as e:
        # Deliberately broad: one failed request must not abort a long
        # batch run; the caller records an empty response instead.
        logging.error(f"LLM request failed for {model_name}: {e}")
        return ""

# ----------- Main Extraction -----------
def extract_symptoms(parquet_path):
    """Query every configured model for every ICD code in the input file.

    Args:
        parquet_path: Path of the Parquet file passed on to ``load_data``.

    Returns:
        A pandas DataFrame with one row per (model, ICD code) and the
        columns ``icd_code``, ``title``, ``model``, ``response`` and
        ``duration`` (wall-clock seconds spent in the LLM call).
    """
    # tqdm only draws a progress bar; fall back to plain iteration when it
    # is not installed rather than failing with a NameError.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    df = load_data(parquet_path)
    results = []

    for model_name in LLM_MODELS:
        logging.info(f"Processing model: {model_name}")
        params = LLM_MODEL_SETTINGS.get(model_name)
        if params is None:
            # No settings entry for this model: use query_llm's own defaults.
            logging.warning(
                f"No settings found for model {model_name}; using query_llm defaults"
            )
            params = {}
        # Forward only the settings query_llm accepts; absent keys fall back
        # to its defaults.
        llm_kwargs = {
            key: params[key]
            for key in ("temperature", "max_tokens", "top_k", "top_n")
            if key in params
        }
        for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Model {model_name}"):
            icd_code = row["icd_code"]
            description = row["long_title"]
            prompt = create_prompt(icd_code, description)
            start = time.time()
            response = query_llm(prompt, model_name, **llm_kwargs)
            duration = time.time() - start
            results.append({
                "icd_code": icd_code,
                "title": description,
                "model": model_name,
                "response": response,
                "duration": duration
            })

    return pd.DataFrame(results)

# ----------- Main Execution -----------
if __name__ == "__main__":
    # Entry point: validate the environment, run the extraction, write the
    # results as JSON and Parquet, and report the elapsed time.
    try:
        check_environment()
        logging.info("Starting symptom extraction process...")
        start_time = time.time()

        symptom_df = extract_symptoms(PARQUET_PATH)
        symptom_df.to_json(JSON_OUTPUT, orient="records", indent=2)
        fp_write(PARQUET_OUTPUT, symptom_df, write_index=False)

        elapsed = time.time() - start_time
        logging.info(f"Extraction completed in {elapsed:.2f} seconds. Results saved to {RESULTS_DIR}")
        print(f"Extraction completed in {elapsed:.2f} seconds.")

    except Exception as e:
        # logging.exception logs at ERROR level and appends the traceback,
        # so the log file shows where the failure happened.
        logging.exception(f"Fatal error: {e}")
        print(f"Fatal error: {e}")
        # Re-raise so the process exits with a non-zero status.
        raise

2. Prepare the SLURM Script

Write a SLURM script (for example, llm_symptoms.slurm) that:

  • Requests the required resources on the cluster
  • Sets up the LLM server and environment
  • Runs your Python script

Use the template below as a guide (adjust the account, paths, and port for your environment):


#!/bin/bash -l

# Batch job: start a local LLM server on the allocated GPU node, wait until
# it answers, then run the Python extraction script against it.

#SBATCH --job-name=ollama_icd9cm_symptoms
#SBATCH --cluster=wice
#SBATCH --partition=gpu_h100
#SBATCH --gpus-per-node=1
#SBATCH --time=02:00:00
#SBATCH --mem=128G
#SBATCH --account=xxxxx
#SBATCH --output=/path/to/your/results/icd_9cm_symptoms_deepseek70b%j.out

# Port the LLM server listens on. It must match the port in LLM_API_URL in
# the Python script; every URL below is derived from it.
export OLLAMA_PORT=37712

# Set environment variables for Ollama and model storage
export OLLAMA_MODELS=/path/to/your/llm_models
export OLLAMA_HOST=127.0.0.1:${OLLAMA_PORT}

# Move to the data/code directory (stop if it does not exist)
cd /path/to/your/working_directory || exit 1

# Activate the Python conda environment
source /path/to/your/miniconda3/bin/activate your_env_name

# Set GPU usage
export CUDA_VISIBLE_DEVICES=0

# Start Ollama server (logging for debugging).
# "%j" is only expanded inside #SBATCH directives; in shell commands the
# job id is available as $SLURM_JOB_ID.
/path/to/your/llm_server/bin/llm_server_command serve > "/path/to/your/results/llm_server_${SLURM_JOB_ID}.log" 2>&1 &
SERVER_PID=$!

# Stop the background server whenever the script exits
trap 'kill "$SERVER_PID" 2>/dev/null' EXIT

# Wait for Ollama server to be ready (simple loop, up to ~30 seconds)
for i in {1..15}; do
    if curl -sf "http://${OLLAMA_HOST}/api/tags" > /dev/null; then
        echo "Ollama server is up!"
        break
    fi
    sleep 2
done

# Double-check server is up before proceeding
if ! curl -sf "http://${OLLAMA_HOST}/api/tags" > /dev/null; then
    echo "Ollama server did not start correctly." >&2
    exit 1
fi

# Run the Python script and remember its exit status
python /path/to/your/python_scripts/deepseek70b.py
STATUS=$?

echo "SLURM job finished at $(date)"

# Report the Python script's status as the job's status so that a failed
# extraction does not show up as a successful job.
exit "$STATUS"

3. Submit and Monitor Your Job

Submit your SLURM job from the terminal:

sbatch llm_symptoms.slurm

Monitor your job status:

squeue -u $USER

Check your logs and results in the output directories you specified in your SLURM script and Python script.


4. Example: Output Result Snapshot

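An illustrative snapshot of the extracted symptoms is shown below. Note that the Python script above stores one record per ICD code with the fields icd_code, title, model, response (the comma-separated symptom list as plain text), and duration, so a small parsing step is needed to reach this shape: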

[
  {
    "code": "250.00",
    "disease": "Diabetes mellitus without mention of complication",
    "symptoms": [
      "increased thirst",
      "frequent urination",
      "blurred vision",
      "fatigue",
      "slow-healing sores",
      "unintended weight loss",
      "hunger",
      "dry mouth",
      "headaches",
      "recurrent infections"
    ]
  },
  {
    "code": "401.9",
    "disease": "Essential hypertension, unspecified",
    "symptoms": [
      "headache",
      "dizziness",
      "nosebleeds",
      "shortness of breath",
      "chest pain",
      "blurred vision",
      "fatigue",
      "nausea",
      "palpitations",
      "confusion"
    ]
  }
]