{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mediation analysis with duration data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates mediation analysis when the\n",
    "mediator and outcome are duration variables, modeled\n",
    "using proportional hazards regression.  These examples\n",
    "are based on simulated data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
 
 
 
 
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import statsmodels.api as sm\n",
    "from statsmodels.stats.mediation import Mediation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make the notebook reproducible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
 
 
 
 
    }
   },
   "outputs": [],
   "source": [
    "np.random.seed(3424)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Specify a sample size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
 
 
 
 
    }
   },
   "outputs": [],
   "source": [
    "n = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate an exposure variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
 
 
 
 
    },
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "exp = np.random.normal(size=n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate a mediator variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
 
 
 
 
    },
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "def gen_mediator():\n",
    "    mn = np.exp(exp)\n",
    "    mtime0 = -mn * np.log(np.random.uniform(size=n))\n",
    "    ctime = -2 * mn * np.log(np.random.uniform(size=n))\n",
    "    mstatus = (ctime >= mtime0).astype(int)\n",
    "    mtime = np.where(mtime0 <= ctime, mtime0, ctime)\n",
    "    return mtime0, mtime, mstatus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generate an outcome variable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
 
 
 
 
    },
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "def gen_outcome(otype, mtime0):\n",
    "    if otype == \"full\":\n",
    "        lp = 0.5 * mtime0\n",
    "    elif otype == \"no\":\n",
    "        lp = exp\n",
    "    else:\n",
    "        lp = exp + mtime0\n",
    "    mn = np.exp(-lp)\n",
    "    ytime0 = -mn * np.log(np.random.uniform(size=n))\n",
    "    ctime = -2 * mn * np.log(np.random.uniform(size=n))\n",
    "    ystatus = (ctime >= ytime0).astype(int)\n",
    "    ytime = np.where(ytime0 <= ctime, ytime0, ctime)\n",
    "    return ytime, ystatus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build a dataframe containing all the relevant variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
 
 
 
 
    },
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "def build_df(ytime, ystatus, mtime0, mtime, mstatus):\n",
    "    df = pd.DataFrame(\n",
    "        {\n",
    "            \"ytime\": ytime,\n",
    "            \"ystatus\": ystatus,\n",
    "            \"mtime\": mtime,\n",
    "            \"mstatus\": mstatus,\n",
    "            \"exp\": exp,\n",
    "        }\n",
    "    )\n",
    "    return df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the full simulation and analysis, under a particular\n",
    "population structure of mediation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
 
 
 
 
    },
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "def run(otype):\n",
    "\n",
    "    mtime0, mtime, mstatus = gen_mediator()\n",
    "    ytime, ystatus = gen_outcome(otype, mtime0)\n",
    "    df = build_df(ytime, ystatus, mtime0, mtime, mstatus)\n",
    "\n",
    "    outcome_model = sm.PHReg.from_formula(\n",
    "        \"ytime ~ exp + mtime\", status=\"ystatus\", data=df\n",
    "    )\n",
    "    mediator_model = sm.PHReg.from_formula(\"mtime ~ exp\", status=\"mstatus\", data=df)\n",
    "\n",
    "    med = Mediation(\n",
    "        outcome_model,\n",
    "        mediator_model,\n",
    "        \"exp\",\n",
    "        \"mtime\",\n",
    "        outcome_predict_kwargs={\"pred_only\": True},\n",
    "    )\n",
    "    med_result = med.fit(n_rep=20)\n",
    "    print(med_result.summary())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the example with full mediation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
 
 
 
 
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                          Estimate  Lower CI bound  Upper CI bound  P-value\n",
      "ACME (control)            0.742427        0.643339        0.862745      0.0\n",
      "ACME (treated)            0.742427        0.643339        0.862745      0.0\n",
      "ADE (control)             0.073017       -0.016189        0.155321      0.1\n",
      "ADE (treated)             0.073017       -0.016189        0.155321      0.1\n",
      "Total effect              0.815444        0.675214        0.919580      0.0\n",
      "Prop. mediated (control)  0.912695        0.814965        1.025747      0.0\n",
      "Prop. mediated (treated)  0.912695        0.814965        1.025747      0.0\n",
      "ACME (average)            0.742427        0.643339        0.862745      0.0\n",
      "ADE (average)             0.073017       -0.016189        0.155321      0.1\n",
      "Prop. mediated (average)  0.912695        0.814965        1.025747      0.0\n"
     ]
    }
   ],
   "source": [
    "run(\"full\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the example with partial mediation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
 
 
 
 
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                          Estimate  Lower CI bound  Upper CI bound  P-value\n",
      "ACME (control)            0.987067        0.801560        1.192019      0.0\n",
      "ACME (treated)            0.987067        0.801560        1.192019      0.0\n",
      "ADE (control)             1.071734        0.964214        1.150352      0.0\n",
      "ADE (treated)             1.071734        0.964214        1.150352      0.0\n",
      "Total effect              2.058801        1.862231        2.288170      0.0\n",
      "Prop. mediated (control)  0.481807        0.417501        0.533773      0.0\n",
      "Prop. mediated (treated)  0.481807        0.417501        0.533773      0.0\n",
      "ACME (average)            0.987067        0.801560        1.192019      0.0\n",
      "ADE (average)             1.071734        0.964214        1.150352      0.0\n",
      "Prop. mediated (average)  0.481807        0.417501        0.533773      0.0\n"
     ]
    }
   ],
   "source": [
    "run(\"partial\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the example with no mediation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
 
 
 
 
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                          Estimate  Lower CI bound  Upper CI bound  P-value\n",
      "ACME (control)            0.010200       -0.039434        0.065176      1.0\n",
      "ACME (treated)            0.010200       -0.039434        0.065176      1.0\n",
      "ADE (control)             0.902295        0.824526        0.984934      0.0\n",
      "ADE (treated)             0.902295        0.824526        0.984934      0.0\n",
      "Total effect              0.912495        0.834728        1.009958      0.0\n",
      "Prop. mediated (control)  0.003763       -0.044186        0.065520      1.0\n",
      "Prop. mediated (treated)  0.003763       -0.044186        0.065520      1.0\n",
      "ACME (average)            0.010200       -0.039434        0.065176      1.0\n",
      "ADE (average)             0.902295        0.824526        0.984934      0.0\n",
      "Prop. mediated (average)  0.003763       -0.044186        0.065520      1.0\n"
     ]
    }
   ],
   "source": [
    "run(\"no\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
