{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "#Dict approach\n",
    "import glob\n",
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "#from typing import Any, Optional\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import h5py\n",
    "from sklearn.metrics import r2_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load genome loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make genome-loader importable on the Salk cluster. Guarded so re-running the cell does not\n",
    "# append duplicates and clusters where this checkout is absent are unaffected.\n",
    "GENOME_LOADER_SALK = '/iblm/netapp/home/jjaureguy/genome_loader/genome-loader'\n",
    "if os.path.isdir(GENOME_LOADER_SALK) and GENOME_LOADER_SALK not in sys.path:\n",
    "    sys.path.append(GENOME_LOADER_SALK)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make genome-loader importable on the runAI cluster. Guarded so re-running the cell does not\n",
    "# append duplicates and clusters where this checkout is absent are unaffected.\n",
    "GENOME_LOADER_RUNAI = '/home/jovyan/home/jjaureguy/genome_loader/genome-loader'\n",
    "if os.path.isdir(GENOME_LOADER_RUNAI) and GENOME_LOADER_RUNAI not in sys.path:\n",
    "    sys.path.append(GENOME_LOADER_RUNAI)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import genome_loader.write_h5\n",
    "import genome_loader.encode_data\n",
    "import genome_loader.get_data\n",
    "import genome_loader.get_encoded\n",
    "import genome_loader.load_data\n",
    "import genome_loader.load_h5\n",
    "\n",
    "from genome_loader.write_h5 import write_encoded_genome\n",
    " "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load in dataframes for Train/valid/test on runAI cluster\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths on the runAI cluster; edit DATA_DIR to run elsewhere.\n",
    "DATA_DIR = Path('/home/jovyan/data4')\n",
    "FRAG_DF_DIR = DATA_DIR / 'jjaureguy/jupyter/data_frames/runai_df'\n",
    "BED_DIR = DATA_DIR / 'UCSC_browser_lab_data/jjaureguy/ml_proj/bed_files/pos_training_regression'\n",
    "BED_COLUMNS = ['chrom', 'start', 'end']\n",
    "\n",
    "# OHE for AGCT: one-hot encoded genome, opened read-only and shared by the Dataset objects below.\n",
    "one_hot_enc_genome = h5py.File(str(DATA_DIR / 'jjaureguy/out_dir/genome_onehot.h5'), 'r')\n",
    "\n",
    "# Per-patient tables of Tn5 insertion-count h5 paths, one table per split.\n",
    "frag_tn5_train_df = pd.read_csv(FRAG_DF_DIR / 'train_frag_tn5_h5_df_runai.txt', sep='\\t')\n",
    "frag_tn5_valid_df = pd.read_csv(FRAG_DF_DIR / 'valid_frag_tn5_h5_df_runai.txt', sep='\\t')\n",
    "frag_tn5_test_df = pd.read_csv(FRAG_DF_DIR / 'test_frag_tn5_h5_df_runai.txt', sep='\\t')\n",
    "\n",
    "# Peak regions (headerless bed files), one per split, named consistently for ATACDataModule.\n",
    "train_bed_file_df = pd.read_csv(BED_DIR / 'fikt_3_train_df_filtered_500_final_pos.txt', names=BED_COLUMNS, sep='\\t')\n",
    "valid_bed_file_df = pd.read_csv(BED_DIR / 'fikt_3_valid_df_filtered_500_final_pos.txt', names=BED_COLUMNS, sep='\\t')\n",
    "test_bed_file_df = pd.read_csv(BED_DIR / 'fikt_3_test_df_filtered_500_final_pos.txt', names=BED_COLUMNS, sep='\\t')\n",
    "# Alias kept because other cells refer to the validation peaks as `valid_df`.\n",
    "valid_df = valid_bed_file_df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of the training Tn5 table that belong to patient HPSI0114i-fikt_3.\n",
    "fikt_3 = frag_tn5_train_df.loc[frag_tn5_train_df['patient_id'].eq('HPSI0114i-fikt_3')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Renumber the filtered rows 0..n-1 so positional (.iloc) and label indices agree.\n",
    "fikt_3 = fikt_3.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>patient_id</th>\n",
       "      <th>diff_frag_h5_path</th>\n",
       "      <th>int_gamma_diff_frag_h5_path</th>\n",
       "      <th>salm_int_gamma_diff_frag_h5_path</th>\n",
       "      <th>salm_diff_frag_h5_path</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>HPSI0114i-fikt_3</td>\n",
       "      <td>/home/jovyan/data3/jjaureguy/PRJEB18997/10_gen...</td>\n",
       "      <td>/home/jovyan/data3/jjaureguy/PRJEB18997/10_gen...</td>\n",
       "      <td>/home/jovyan/data3/jjaureguy/PRJEB18997/10_gen...</td>\n",
       "      <td>/home/jovyan/data3/jjaureguy/PRJEB18997/10_gen...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         patient_id                                  diff_frag_h5_path  \\\n",
       "0  HPSI0114i-fikt_3  /home/jovyan/data3/jjaureguy/PRJEB18997/10_gen...   \n",
       "\n",
       "                         int_gamma_diff_frag_h5_path  \\\n",
       "0  /home/jovyan/data3/jjaureguy/PRJEB18997/10_gen...   \n",
       "\n",
       "                    salm_int_gamma_diff_frag_h5_path  \\\n",
       "0  /home/jovyan/data3/jjaureguy/PRJEB18997/10_gen...   \n",
       "\n",
       "                              salm_diff_frag_h5_path  \n",
       "0  /home/jovyan/data3/jjaureguy/PRJEB18997/10_gen...  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Single-patient table: patient_id plus one Tn5 h5 path column per treatment.\n",
    "fikt_3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataframe"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dataset Custom Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "from enum import Enum\n",
    "from pathlib import Path\n",
    "from typing import Any, Optional\n",
    "from torch import nn\n",
    "import pytorch_lightning as pl\n",
    "import h5py\n",
    "from tqdm.notebook import tqdm\n",
    "#from .transforms import KmerShuffle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DataSet Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py \n",
    "import random\n",
    "from collections import defaultdict\n",
    "\n",
    "class Dataset_class(Dataset):\n",
    "    \"\"\"(one-hot DNA window, normalized Tn5 signal) pairs for every (peak, patient row).\n",
    "\n",
    "    Flat index ``i`` maps to peak ``i // n_rows`` of ``unified_bed_file_df`` and row\n",
    "    ``i % n_rows`` of ``frag_tn5_h5_df``, so ``len(self) == n_peaks * n_rows``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    one_hot_enc_genome : h5py.File\n",
    "        One-hot encoded genome `.h5` file (e.g. from genome-loader), read as\n",
    "        ``one_hot_enc_genome[chrom]['onehot']``.\n",
    "    unified_bed_file_df : pandas.DataFrame\n",
    "        Unified peak set with ``chrom``, ``start``, ``end`` columns.\n",
    "    frag_tn5_h5_df : pandas.DataFrame\n",
    "        One row per patient: ``patient_id`` plus one column per treatment holding the path of\n",
    "        an h5 file of Tn5 insertion counts (``[chrom]['depth']`` and ``attrs['total_sum']``).\n",
    "    bin_size : int\n",
    "        Number of bp summed into each bin of the 512 bp target window (must be > 0).\n",
    "    dtype : torch.dtype\n",
    "        dtype of the returned tensors.\n",
    "    treatment_columns : str or sequence of str\n",
    "        Columns of ``frag_tn5_h5_df`` to read. The normalized bins of each treatment are\n",
    "        concatenated in this order to form y. Default: only ``'diff_frag_h5_path'``.\n",
    "    \"\"\"\n",
    "\n",
    "    # Half-widths (bp) of the x (1024 bp) and y (512 bp) windows around the peak midpoint.\n",
    "    X_HALF_WIDTH = 512\n",
    "    Y_HALF_WIDTH = 256\n",
    "\n",
    "    def __init__(self, one_hot_enc_genome, unified_bed_file_df, frag_tn5_h5_df, bin_size,\n",
    "                 dtype=torch.float32, treatment_columns=('diff_frag_h5_path',)):\n",
    "        super().__init__()\n",
    "        if bin_size <= 0:\n",
    "            raise ValueError(f'bin_size must be a positive integer, got {bin_size}')\n",
    "        if isinstance(treatment_columns, str):\n",
    "            treatment_columns = (treatment_columns,)\n",
    "        treatment_columns = tuple(treatment_columns)\n",
    "        if not treatment_columns:\n",
    "            raise ValueError('treatment_columns must name at least one column')\n",
    "        self.bin_size = bin_size\n",
    "        self.one_hot_enc_genome = one_hot_enc_genome\n",
    "        self.unified_bed_file_df = unified_bed_file_df\n",
    "        # The four attributes alias one DataFrame and are kept for backward compatibility;\n",
    "        # the treatment(s) actually read are chosen by column name via treatment_columns.\n",
    "        self.frag_tn5_h5_df_1 = frag_tn5_h5_df\n",
    "        self.frag_tn5_h5_df_2 = frag_tn5_h5_df\n",
    "        self.frag_tn5_h5_df_3 = frag_tn5_h5_df\n",
    "        self.frag_tn5_h5_df_4 = frag_tn5_h5_df\n",
    "        self.treatment_columns = treatment_columns\n",
    "        self.dtype = dtype\n",
    "        # Cache of opened Tn5 h5 files keyed by path. Files are opened lazily from __getitem__,\n",
    "        # so handles are normally created in the process that reads them (e.g. a DataLoader worker).\n",
    "        self.h5_files = {}\n",
    "\n",
    "    def get_h5(self, h5_path):\n",
    "        \"\"\"Return the h5py file at ``h5_path``, opening it read-only and caching it on first use.\"\"\"\n",
    "        h5_file = self.h5_files.get(h5_path)\n",
    "        if h5_file is None:\n",
    "            h5_file = h5py.File(h5_path, 'r')\n",
    "            self.h5_files[h5_path] = h5_file\n",
    "        return h5_file\n",
    "\n",
    "    def get_mid_point(self, start, end):\n",
    "        \"\"\"Return the integer midpoint of the peak [start, end), rounded down.\"\"\"\n",
    "        return int((start + end) // 2)\n",
    "\n",
    "    def get_x_window(self, mid_point, d):\n",
    "        \"\"\"Return positions of the 1024 bp x window [mid_point - 512, mid_point + 512) as a list.\n",
    "\n",
    "        The (start, end) bounds are also stored in ``d['1024_bp_start_end']``.\n",
    "        \"\"\"\n",
    "        start, end = mid_point - self.X_HALF_WIDTH, mid_point + self.X_HALF_WIDTH\n",
    "        d['1024_bp_start_end'] = (start, end)\n",
    "        return list(range(start, end))\n",
    "\n",
    "    def get_y_window(self, mid_point):\n",
    "        \"\"\"Return positions of the 512 bp y window [mid_point - 256, mid_point + 256) as a list.\n",
    "\n",
    "        The y window lies inside the 1024 bp x window around the same midpoint.\n",
    "        \"\"\"\n",
    "        return list(range(mid_point - self.Y_HALF_WIDTH, mid_point + self.Y_HALF_WIDTH))\n",
    "\n",
    "    def normalize_atac_reads(self, Tn5_counts_window, T, d):\n",
    "        \"\"\"ATAC-normalize binned Tn5 counts: C * 1e9 / (T * bin_length) for every bin.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        Tn5_counts_window : sequence of array-like\n",
    "            Bins of per-bp Tn5 cut-site counts (e.g. from ``bin_window``).\n",
    "        T : number\n",
    "            Total Tn5 count of the h5 file (``attrs['total_sum']``).\n",
    "        d : object\n",
    "            Unused; kept for interface compatibility.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        (list, number)\n",
    "            The normalized value of each bin, and the raw summed count C of the LAST bin.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If ``Tn5_counts_window`` contains no bins.\n",
    "        \"\"\"\n",
    "        if len(Tn5_counts_window) == 0:\n",
    "            raise ValueError('Tn5_counts_window must contain at least one bin')\n",
    "        counts = []\n",
    "        for bin_counts in Tn5_counts_window:\n",
    "            C = np.sum(bin_counts)\n",
    "            counts.append(C * 10**9 / (T * len(bin_counts)))\n",
    "        return counts, C\n",
    "\n",
    "    def bin_window(self, window, bin_size):\n",
    "        \"\"\"Split ``window`` into consecutive chunks of ``bin_size`` (the last may be shorter).\n",
    "\n",
    "        Does not modify the dataset's own ``bin_size``.\n",
    "        \"\"\"\n",
    "        return [window[i:i + bin_size] for i in range(0, len(window), bin_size)]\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of (peak, patient row) pairs.\"\"\"\n",
    "        return len(self.unified_bed_file_df) * len(self.frag_tn5_h5_df_1)\n",
    "\n",
    "    @staticmethod\n",
    "    def _read_window(dataset, start, end, what):\n",
    "        \"\"\"Return ``dataset[start:end]``; raise IndexError if the window leaves the dataset.\n",
    "\n",
    "        A contiguous slice is typically much cheaper than h5py fancy (list) indexing, and the\n",
    "        explicit check stops slicing from silently clipping or wrapping at chromosome ends.\n",
    "        \"\"\"\n",
    "        if start < 0 or end > dataset.shape[0]:\n",
    "            raise IndexError(f'{what} window [{start}, {end}) lies outside [0, {dataset.shape[0]})')\n",
    "        return dataset[start:end]\n",
    "\n",
    "    def _treatment_target(self, row, column, chrom, mid_point):\n",
    "        \"\"\"Return the normalized, binned Tn5 signal of one treatment column around ``mid_point``.\"\"\"\n",
    "        h5_file = self.get_h5(row[column])\n",
    "        total_counts = h5_file.attrs['total_sum']\n",
    "        depth = self._read_window(h5_file[chrom]['depth'],\n",
    "                                  mid_point - self.Y_HALF_WIDTH, mid_point + self.Y_HALF_WIDTH, 'y')\n",
    "        bins = self.bin_window(window=depth, bin_size=self.bin_size)\n",
    "        normalized, _ = self.normalize_atac_reads(bins, total_counts, None)\n",
    "        return normalized\n",
    "\n",
    "    def __getitem__(self, index: int):\n",
    "        \"\"\"Return ``(x, y)`` for flat ``index``.\n",
    "\n",
    "        x : tensor of shape (1024, n_channels), the one-hot genome centred on the peak midpoint\n",
    "            (n_channels is 4 in the outputs shown in this notebook).\n",
    "        y : 1-D tensor with ceil(512 / bin_size) normalized Tn5 bins per treatment column,\n",
    "            concatenated in ``treatment_columns`` order.\n",
    "        \"\"\"\n",
    "        n_rows = len(self.frag_tn5_h5_df_1)\n",
    "        peak_index, bam_file_index = divmod(index, n_rows)\n",
    "        coords = self.unified_bed_file_df.iloc[peak_index]  # chrom, start, end\n",
    "        row = self.frag_tn5_h5_df_1.iloc[bam_file_index]     # patient_id + h5 path columns\n",
    "        mid_point = self.get_mid_point(coords.start, coords.end)\n",
    "\n",
    "        onehot = self.one_hot_enc_genome[coords.chrom]['onehot']\n",
    "        x = self._read_window(onehot, mid_point - self.X_HALF_WIDTH, mid_point + self.X_HALF_WIDTH, 'x')\n",
    "        x = torch.as_tensor(x, dtype=self.dtype)\n",
    "\n",
    "        y = np.concatenate([\n",
    "            np.asarray(self._treatment_target(row, column, coords.chrom, mid_point), dtype=np.float64)\n",
    "            for column in self.treatment_columns\n",
    "        ])\n",
    "        y = torch.as_tensor(y, dtype=self.dtype)\n",
    "\n",
    "        # Tuple of x (1024, channels) and y (bins,)\n",
    "        return x, y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experimental plots (none yet): the next cells sanity-check Dataset_class on the validation peaks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation peaks x patient fikt_3; bin_size=512 yields a single bin over the 512 bp target window.\n",
    "DS_valid = Dataset_class(one_hot_enc_genome, valid_df, fikt_3,bin_size = 512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0., 1., 0., 0.],\n",
       "         [0., 0., 0., 1.],\n",
       "         [0., 0., 1., 0.],\n",
       "         ...,\n",
       "         [0., 0., 0., 1.],\n",
       "         [0., 1., 0., 0.],\n",
       "         [0., 1., 0., 0.]]),\n",
       " tensor([5.9415]))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First (x, y) example; indexing dispatches to Dataset_class.__getitem__.\n",
    "DS_valid[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports for pytorch lightning and wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "# Pytorch modules\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch import nn\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from pytorch_lightning import LightningDataModule, LightningModule, Trainer\n",
    "import pytorch_lightning as pl\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining our model (naive CNN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchmetrics\n",
    "\n",
    "\n",
    "class LitNN(LightningModule):\n",
    "    \"\"\"1-D CNN that regresses normalized Tn5 signal from a one-hot DNA window.\n",
    "\n",
    "    Architecture: 3 x (Conv1d(k=3) -> BatchNorm1d -> ReLU -> MaxPool1d(4)) -> flatten -> Linear.\n",
    "\n",
    "    Args:\n",
    "        seq_len: input window length in bp; sizes the final Linear layer\n",
    "            (1024 bp -> 15 positions x n_layer_3 = 480 features with the defaults).\n",
    "        channels: number of input channels (4 for one-hot DNA).\n",
    "        num_classes: number of regression outputs.\n",
    "        n_layer_1, n_layer_2, n_layer_3: output channels of the three convolutions.\n",
    "        lr: Adam learning rate used by configure_optimizers.\n",
    "    \"\"\"\n",
    "\n",
    "    KERNEL_SIZE = 3\n",
    "    POOL_SIZE = 4\n",
    "    N_STAGES = 3\n",
    "\n",
    "    def __init__(self, seq_len=1024, channels=4, num_classes=1, n_layer_1=128, n_layer_2=64, n_layer_3=32, lr=1e-3):\n",
    "        super().__init__()\n",
    "\n",
    "        self.conv1 = torch.nn.Conv1d(channels, n_layer_1, kernel_size=self.KERNEL_SIZE)\n",
    "        self.batch1 = nn.BatchNorm1d(n_layer_1)\n",
    "        self.relu = torch.nn.ReLU()\n",
    "        self.conv2 = torch.nn.Conv1d(n_layer_1, n_layer_2, kernel_size=self.KERNEL_SIZE)\n",
    "        self.batch2 = nn.BatchNorm1d(n_layer_2)\n",
    "        self.conv3 = torch.nn.Conv1d(n_layer_2, n_layer_3, kernel_size=self.KERNEL_SIZE)\n",
    "        self.batch3 = nn.BatchNorm1d(n_layer_3)\n",
    "        self.pool = torch.nn.MaxPool1d(self.POOL_SIZE)\n",
    "        # Flattened feature size derived from seq_len (480 for the 1024 bp default).\n",
    "        self.fc1 = torch.nn.Linear(n_layer_3 * self._feature_length(seq_len), num_classes)\n",
    "\n",
    "        # optimizer parameters\n",
    "        self.lr = lr\n",
    "\n",
    "        # metrics (the flipped R2 swaps the roles of prediction and target)\n",
    "        self.R2score_normal = torchmetrics.R2Score(num_outputs=1)\n",
    "        self.R2score_flipped = torchmetrics.R2Score(num_outputs=1)\n",
    "        self.MSE = torchmetrics.MeanSquaredError()\n",
    "        self.pearson_cor = torchmetrics.PearsonCorrCoef()\n",
    "\n",
    "        # optional - save hyper-parameters to self.hparams\n",
    "        # they will also be automatically logged as config parameters in W&B\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "    @classmethod\n",
    "    def _feature_length(cls, seq_len):\n",
    "        \"\"\"Sequence length left after the N_STAGES (valid conv -> max-pool) stages.\"\"\"\n",
    "        length = seq_len\n",
    "        for _ in range(cls.N_STAGES):\n",
    "            length = (length - cls.KERNEL_SIZE + 1) // cls.POOL_SIZE\n",
    "        if length <= 0:\n",
    "            raise ValueError(f'seq_len={seq_len} is too short for {cls.N_STAGES} conv/pool stages')\n",
    "        return length\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a batch of one-hot windows to predictions.\n",
    "\n",
    "        Args:\n",
    "            x: tensor of shape (batch, seq_len, channels).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, num_classes).\n",
    "        \"\"\"\n",
    "        # Conv1d expects (batch, channels, seq_len).\n",
    "        x = x.permute(0, 2, 1)\n",
    "        for conv, batch_norm in ((self.conv1, self.batch1), (self.conv2, self.batch2), (self.conv3, self.batch3)):\n",
    "            x = self.pool(self.relu(batch_norm(conv(x))))\n",
    "        # Flatten (batch, n_layer_3, length) for the fully connected layer.\n",
    "        x = torch.flatten(x, 1)\n",
    "        return self.fc1(x)\n",
    "\n",
    "    def _shared_step(self, batch):\n",
    "        \"\"\"Run the model on an (x, y) batch; return (predictions, targets, MSE loss).\"\"\"\n",
    "        x, y = batch\n",
    "        logits = self(x)\n",
    "        return logits, y, F.mse_loss(logits, y)\n",
    "\n",
    "    def _log_metrics(self, logits, y, flipped_name, normal_name, mse_name, pearson_name):\n",
    "        \"\"\"Log flipped R2, R2, MSE and Pearson r of one batch under the given names.\"\"\"\n",
    "        # (N, 1) -> (N,): single-output torchmetrics metrics (notably PearsonCorrCoef in some\n",
    "        # releases) reject 2-D inputs. squeeze(-1) keeps a batch of one as shape (1,).\n",
    "        preds, target = logits.squeeze(-1), y.squeeze(-1)\n",
    "        self.log(flipped_name, self.R2score_flipped(target, preds))\n",
    "        self.log(normal_name, self.R2score_normal(preds, target))\n",
    "        self.log(mse_name, self.MSE(preds, target))\n",
    "        self.log(pearson_name, self.pearson_cor(preds, target))\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Return the MSE loss of one training batch and log training metrics.\"\"\"\n",
    "        logits, y, loss = self._shared_step(batch)\n",
    "        self.log('train_loss', loss)\n",
    "        self._log_metrics(logits, y, 'train_flipped_R2', 'train_normal_R2', 'train_mse', 'train_pearson_corr')\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        \"\"\"Return the MSE loss of one validation batch and log validation metrics.\"\"\"\n",
    "        logits, y, loss = self._shared_step(batch)\n",
    "        # Log validation loss (will be automatically averaged over an epoch)\n",
    "        self.log('valid_loss', loss)\n",
    "        self._log_metrics(logits, y, 'valid_R2_flipped_R2', 'valid_R2_normal_R2', 'valid_mse', 'valid_pearson_corr')\n",
    "        return loss\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        \"\"\"Return the MSE loss of one test batch and log test metrics.\"\"\"\n",
    "        logits, y, loss = self._shared_step(batch)\n",
    "        self.log('test_loss', loss)\n",
    "        self._log_metrics(logits, y, 'test_R2_flipped_R2', 'test_R2_normal_R2', 'test_mse', 'test_pearson_corr')\n",
    "        return loss\n",
    "\n",
    "    def predict_step(self, batch, batch_idx, dataloader_idx=0):\n",
    "        \"\"\"Return predictions for a batch given either as bare inputs or as an (x, y) pair.\"\"\"\n",
    "        # Dataset_class yields (x, y); feed only x to the model.\n",
    "        x = batch[0] if isinstance(batch, (tuple, list)) else batch\n",
    "        return self(x)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Adam over all parameters with learning rate ``self.lr``.\"\"\"\n",
    "        return Adam(self.parameters(), lr=self.lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# Data Loader functions(train, valid, test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the ATAC LightningDataModule (train/valid/test dataloaders)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data.sampler import WeightedRandomSampler\n",
    "import torch\n",
    "\n",
    "class ATACDataModule(pl.LightningDataModule):\n",
    "    \"\"\"LightningDataModule building Dataset_class datasets for the train/valid/test splits.\n",
    "\n",
    "    Args:\n",
    "        one_hot_enc_genome: one-hot encoded genome h5 file shared by all splits.\n",
    "        train_bed_file_df, valid_bed_file_df, test_bed_file_df: peak regions (chrom, start, end) per split.\n",
    "        frag_tn5_train_df, frag_tn5_valid_df, frag_tn5_test_df: per-patient Tn5 h5-path tables per split.\n",
    "        batch_size: batch size of every dataloader.\n",
    "        bin_size: forwarded to Dataset_class.\n",
    "        num_workers: DataLoader worker processes.\n",
    "        num_train_samples: training examples drawn (uniformly, with replacement) per epoch.\n",
    "        dtype: tensor dtype forwarded to Dataset_class.\n",
    "    \"\"\"\n",
    "    def __init__(self, one_hot_enc_genome, train_bed_file_df, valid_bed_file_df, test_bed_file_df,\n",
    "                 frag_tn5_train_df, frag_tn5_valid_df, frag_tn5_test_df, batch_size=128,\n",
    "                 bin_size=512, num_workers=4, num_train_samples=6400, dtype=torch.float32):\n",
    "        super().__init__()\n",
    "        self.one_hot_enc_genome = one_hot_enc_genome\n",
    "        self.train_bed_file_df = train_bed_file_df\n",
    "        self.valid_bed_file_df = valid_bed_file_df\n",
    "        self.test_bed_file_df = test_bed_file_df\n",
    "        self.frag_tn5_train_df = frag_tn5_train_df\n",
    "        self.frag_tn5_valid_df = frag_tn5_valid_df\n",
    "        self.frag_tn5_test_df = frag_tn5_test_df\n",
    "        self.batch_size = batch_size\n",
    "        self.bin_size = bin_size\n",
    "        self.num_workers = num_workers\n",
    "        self.num_train_samples = num_train_samples\n",
    "        self.dtype = dtype\n",
    "\n",
    "    def _make_dataset(self, bed_file_df, frag_tn5_df):\n",
    "        \"\"\"Build a Dataset_class over one split with this module's genome, bin_size and dtype.\"\"\"\n",
    "        return Dataset_class(self.one_hot_enc_genome, bed_file_df, frag_tn5_df,\n",
    "                             bin_size=self.bin_size, dtype=self.dtype)\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "        \"\"\"Create the datasets needed for ``stage`` ('fit', 'validate', 'test', or None for all).\n",
    "\n",
    "        Uses the DataFrames passed to the constructor, not same-named notebook globals.\n",
    "        \"\"\"\n",
    "        if stage in ('fit', None):\n",
    "            self.train_ds = self._make_dataset(self.train_bed_file_df, self.frag_tn5_train_df)\n",
    "        if stage in ('fit', 'validate', None):\n",
    "            self.valid_ds = self._make_dataset(self.valid_bed_file_df, self.frag_tn5_valid_df)\n",
    "        if stage in ('test', None):\n",
    "            self.test_ds = self._make_dataset(self.test_bed_file_df, self.frag_tn5_test_df)\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        \"\"\"Return the training dataloader (num_train_samples random draws per epoch).\"\"\"\n",
    "        # Uniform weights with replacement: each epoch is a random sample of the\n",
    "        # (peak x patient) pairs rather than a full pass over all of them.\n",
    "        sampler = WeightedRandomSampler(weights=torch.ones(len(self.train_ds)), num_samples=self.num_train_samples)\n",
    "        return DataLoader(self.train_ds, sampler=sampler, batch_size=self.batch_size, num_workers=self.num_workers)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        \"\"\"Return the validation dataloader (fixed order).\"\"\"\n",
    "        return DataLoader(self.valid_ds, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        \"\"\"Return the test dataloader (fixed order).\"\"\"\n",
    "        return DataLoader(self.test_ds, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# WandB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Authenticate with Weights & Biases before creating the logger.\n",
    "wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logger handed to the Trainer; values passed to self.log(...) in LitNN are sent to W&B.\n",
    "wandb_logger = WandbLogger()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchmetrics  # LitNN.__init__ builds torchmetrics metrics\n",
    "\n",
    "# Model: 1024 bp one-hot input (4 channels), one regression output, conv widths 128/64/32.\n",
    "# lr=1e-1 is only the starting value; model.lr is reassigned after the LR finder below.\n",
    "model = LitNN(\n",
    "    seq_len=1024,\n",
    "    channels=4,\n",
    "    num_classes=1,\n",
    "    n_layer_1=128,\n",
    "    n_layer_2=64,\n",
    "    n_layer_3=32,\n",
    "    lr=1e-1,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'train_bed_file_df' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Input \u001b[0;32mIn [39]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# Instantiate ATAC Data module\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m atac \u001b[38;5;241m=\u001b[39m ATACDataModule(one_hot_enc_genome, \u001b[43mtrain_bed_file_df\u001b[49m,valid_bed_file_df,test_bed_file_df,fikt_3,fikt_3,fikt_3, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'train_bed_file_df' is not defined"
     ]
    }
   ],
   "source": [
    "# Instantiate ATAC Data module.\n",
    "# Validation peaks are loaded as `valid_df`; fikt_3 supplies the Tn5 tracks for every split.\n",
    "atac = ATACDataModule(one_hot_enc_genome, train_bed_file_df, valid_df, test_bed_file_df, fikt_3, fikt_3, fikt_3, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional early stopping on valid_loss (disabled). To enable, uncomment this cell and the\n",
    "# `callbacks=[early_stop_callback]` line in the Trainer cell.\n",
    "# from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
    "# # Early stopping \n",
    "# early_stop_callback = EarlyStopping(monitor=\"valid_loss\", patience=10, mode=\"min\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lightning Trainer: up to 100 epochs, gradient clipping at 1, metrics logged to W&B.\n",
    "trainer = pl.Trainer(\n",
    "    max_epochs=100,         # number of epochs\n",
    "    accelerator='auto',     # let Lightning choose the available accelerator\n",
    "    logger=wandb_logger,    # W&B integration\n",
    "    gradient_clip_val=1,\n",
    "    # Disabled options kept for quick experiments:\n",
    "    # callbacks=[early_stop_callback],\n",
    "    # auto_lr_find=True,\n",
    "    # overfit_batches=0.1,\n",
    "    # fast_dev_run=True,\n",
    "    # devices=1 if torch.cuda.is_available() else None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose best LR automatically\n",
    "# NOTE(review): auto_lr_find is commented out in the Trainer cell, so tune() may not run the\n",
    "# LR finder here; confirm. The LR finder is also invoked explicitly in the next cell.\n",
    "trainer.tune(model,atac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the learning-rate range test explicitly; its result is plotted in the next cell.\n",
    "lr_finder = trainer.tuner.lr_find(model, atac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot loss vs. learning rate from the LR finder (suggest=True marks its suggested LR).\n",
    "fig = lr_finder.plot(suggest=True)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the LR with the best chosen LR\n",
    "# NOTE(review): hard-coded, presumably read off the plot above; lr_finder.suggestion() may be preferable.\n",
    "model.lr = 1e-7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confirm the learning rate configure_optimizers() will use.\n",
    "model.lr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model (metrics are logged to W&B through wandb_logger)\n",
    "trainer.fit(model, atac)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record the installed torch version alongside the results.\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test model: evaluate on the test dataloader and log the test_* metrics\n",
    "trainer.test(model, atac)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the W&B run\n",
    "wandb.finish()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "70048b412bdf50bb9639879095d5b7e9588630cc3326e9b869d915719d9eeab2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
