#!/usr/bin/python3

import sys
import os
import re
import fcntl
import socket
import json
import datetime
import traceback
import logging
import logging.handlers
import locale
import subprocess
import threading
import time
import pymysql
from operator import itemgetter
from GRMConst import MySQL_URL, DISCONNECT_TIMER
from dbutils.pooled_db import PooledDB

# All log output goes to /var/log/locker.log; the file is rotated at
# midnight and 31 rotated files are kept.
_log_handler = logging.handlers.TimedRotatingFileHandler(
  '/var/log/locker.log',
  when='MIDNIGHT',
  backupCount=31,
  encoding='utf-8',
)
logging.basicConfig(
  level=logging.INFO,
  format='%(asctime)s %(levelname)-8s %(message)s',
  datefmt='%Y-%m-%d %H:%M:%S',
  handlers=[_log_handler],
)

# Register PyMySQL under the MySQLdb name and use it as the DB-API module
# the pool creates connections with.
pymysql.install_as_MySQLdb()
MySQLdb = pymysql

# Process-wide connection pool. At most 50 connections; with blocking=True,
# pool.connection() waits for a free slot instead of failing. Every session
# runs at READ COMMITTED isolation.
pool = PooledDB(
  creator=MySQLdb,
  blocking=True,
  maxconnections=50,
  setsession=['set session transaction isolation level read committed'],
  **MySQL_URL
)

class DB:
  """Context manager: a cursor on a pooled connection inside one transaction.

  ``with DB() as cur:`` takes a connection from the module-level ``pool``,
  opens a cursor and begins a transaction. On leaving the block:

  * the transaction is committed if the body completed normally;
  * it is rolled back if the body raised.

  The cursor and connection are then always handed back to the pool, even
  if commit, rollback or cursor close fails. The pool is created with
  blocking=True, so a leaked slot would eventually stall every caller.
  Exceptions raised by the body are never suppressed.
  """

  def __enter__(self):
    self.db = pool.connection()
    try:
      self.cur = self.db.cursor()
      self.db.begin()
    except Exception:
      # __exit__ is not called when __enter__ fails, so release the slot here.
      self.db.close()
      self.db = None
      raise
    return self.cur

  def __exit__(self, exc_type, exc_value, exc_tb):
    try:
      if exc_type is None:
        self.db.commit()
      else:
        try:
          self.db.rollback()
        except Exception:
          # Never let a failed rollback hide the exception already propagating.
          logging.getLogger(__name__).exception('rollback failed')
    finally:
      try:
        self.cur.close()
      finally:
        self.db.close()
        self.cur = None
        self.db = None

# The process locale is left at its default; a setlocale(LC_ALL, 'ja_JP')
# call was disabled here.

# Module logger. Its own level is DEBUG, so debug() records are created even
# though basicConfig set the root level to INFO. Records propagated to the
# root's handlers are not re-filtered by the root logger's level.
logger = logging.getLogger(name=__name__)
logger.setLevel('DEBUG')

def debug(*args):
  """Log ``args`` at DEBUG level: each is str()-converted and joined with no separator."""
  logger.debug(''.join(str(arg) for arg in args))

def error(*args):
  """Log ``args`` at ERROR level: each is str()-converted and joined with no separator."""
  message = ''.join([str(part) for part in args])
  logger.error(message)

def get_conn(connId, mode):
  """Return the conn_id of the ``conn`` row that represents this client.

  If ``connId`` is non-zero and a ``conn`` row with that id still exists,
  that row is reused: its ``disconnected`` column is cleared and the same id
  is returned. Otherwise a new ``conn`` row is inserted with ``mode`` and
  its auto-increment id (LAST_INSERT_ID()) is returned.

  Database errors are logged and re-raised. Because the exception reaches
  the ``with`` block, nothing is committed. Re-raising replaces sys.exit():
  on a per-client worker thread, sys.exit() only raises SystemExit in that
  thread and can never stop the server.
  """
  with DB() as cur:
    try:
      if connId != 0:
        cur.execute("SELECT conn_id FROM conn WHERE conn_id=%s", (connId,))
        rec = cur.fetchone()
        if rec is not None:
          # Reconnecting client: mark its existing row as connected again.
          cur.execute("UPDATE conn SET disconnected=NULL WHERE conn_id=%s", (connId,))
          return rec[0]
      cur.execute("INSERT INTO conn (conn_id,mode) VALUES (NULL,%s)", (mode,))
      cur.execute("SELECT LAST_INSERT_ID()")
      return cur.fetchone()[0]
    except Exception:
      error(traceback.format_exc())
      raise

# Current holders of one inode: (sum of exclusive flags, sum of shared flags).
_LOCK_STATE_SQL = "SELECT IFNULL(sum(ex),0),IFNULL(sum(sh),0) FROM lock_file WHERE inode=%(inode)s ORDER BY id FOR UPDATE"
_INSERT_LOCK_SQL = ("INSERT INTO lock_file (top,inode,path,ex,sh,lock_id,conn_id,ins_log,ins_date) "
                    "VALUES (%(top)s,%(inode)s,%(path)s,%(ex)s,%(sh)s,%(lock)s,%(conn)s,'locker.py',NOW())")


def _acquire_locks(cur, sqlparams, exclusive):
  """Insert one lock_file row per requested inode, all sharing one lock id.

  ``sqlparams`` must hold 'conn', 'top', 'inodes' and 'files'. The 'files'
  entry must have at least as many entries as 'inodes'; extra names are
  ignored. This function writes the named SQL parameters 'lock', 'ex',
  'sh', 'inode' and 'path' into the same dict.

  Conflict rules:

  * a shared request conflicts with any exclusive holder;
  * an exclusive request conflicts with any holder at all.

  On the first conflict, the transaction is rolled back through the
  underlying PyMySQL connection (private ``cursor._get_db()``) and (1, 0)
  is returned.

  On success, the id of the first inserted row becomes the lock id for
  every row, and (0, lock_id) is returned.

  Raises ValueError, before any SQL runs, if fewer file names than inodes
  are supplied.

  NOTE(review): the session runs at READ COMMITTED, where InnoDB takes no
  gap locks. When no row exists yet for an inode, two concurrent requests
  may both read (0, 0) and both insert. TODO: confirm whether a constraint
  on lock_file prevents this.
  """
  inodes = sqlparams['inodes']
  files = sqlparams['files']
  if len(files) < len(inodes):
    raise ValueError('lock request has {} inodes but only {} file names'.format(len(inodes), len(files)))
  sqlparams['lock'] = 0
  sqlparams['ex'] = 1 if exclusive else 0
  sqlparams['sh'] = 0 if exclusive else 1
  for inode, name in zip(inodes, files):
    sqlparams['inode'] = inode
    sqlparams['path'] = name.encode('utf-8')
    cur.execute(_LOCK_STATE_SQL, sqlparams)
    held_ex, held_sh = cur.fetchone()
    if held_ex != 0 or (exclusive and held_sh != 0):
      cur._get_db().rollback()
      return (1, 0)
    cur.execute(_INSERT_LOCK_SQL, sqlparams)
    if sqlparams['lock'] == 0:
      # The first row's auto-increment id becomes the lock id; rows inserted
      # after it carry that id directly.
      cur.execute("SELECT LAST_INSERT_ID()")
      sqlparams['lock'] = cur.fetchone()[0]
  # The first row was inserted with lock_id 0, so point it at itself.
  cur.execute("UPDATE lock_file SET lock_id=%(lock)s WHERE id=%(lock)s", sqlparams)
  return (0, sqlparams['lock'])


def shared_lock(cur, sqlparams):
  """Take a shared lock on every inode in ``sqlparams``; see _acquire_locks."""
  return _acquire_locks(cur, sqlparams, exclusive=False)


def exclusive_lock(cur, sqlparams):
  """Take an exclusive lock on every inode in ``sqlparams``; see _acquire_locks."""
  return _acquire_locks(cur, sqlparams, exclusive=True)

def update(params):
  """Handle a '!' request: refresh the ftpquota row for the top directory.

  ``params[1]`` is the top directory path. When it ends in ``/<name>/``,
  the ftpquota row with ``id = <name>`` gets last_update and upd_date set
  to NOW().

  Returns "0" on success, including when the path does not match, and "1"
  if the UPDATE raised (the error is logged). Errors obtaining the pooled
  connection or committing propagate to the caller.
  """
  top = params[1]
  with DB() as cur:
    try:
      m = re.search(r'/([^/]*)/$', top)
      if m is not None:
        cur.execute("UPDATE ftpquota SET last_update=NOW(),upd_log='locker.py',upd_date=NOW() WHERE id=%s", (m.group(1),))
      return "0"
    except Exception:
      error(traceback.format_exc())
  return "1"

def lock(connId, params):
  """Handle a '+' request and return the reply string.

  ``params`` is the ':'-split request:

  * params[1] is the top directory;
  * params[2] is a ','-separated list of inodes;
  * params[3] is the mode ('0' shared, '1' exclusive);
  * params[4] holds the matching '|'-separated file names.

  Returns "0 <lock_id>" when the lock was taken. Returns "1 0" on a
  conflict, an unknown mode, or any database error (which is logged).

  The try wraps the whole ``with`` block. Any error, including one from the
  commit done when the block exits, therefore leaves no half-inserted lock
  rows committed, and the client is told the lock failed.
  """
  ret = (1, 0)
  top = params[1]
  inodes = params[2].split(',')
  mode = params[3]
  files = params[4].split('|')
  sqlparams = {'conn': connId, 'top': top, 'inodes': inodes, 'files': files}
  try:
    with DB() as cur:
      if mode == '0':  # shared lock
        ret = shared_lock(cur, sqlparams)
      elif mode == '1':  # exclusive lock
        ret = exclusive_lock(cur, sqlparams)
  except Exception:
    error(traceback.format_exc())
    ret = (1, 0)
  return "{0} {1}".format(*ret)

def unlock(params):
  """Handle a '-' request: delete every lock_file row carrying one lock id.

  ``params[1]`` is the top directory. ``params[3]`` is the lock id,
  presumably the one returned by an earlier lock request. All rows with
  that lock_id are deleted. Then, if the top directory ends in
  ``/<name>/``, the ftpquota row with ``id = <name>`` is refreshed.

  Returns "0" on success and "1" if a statement raised (the error is
  logged). A non-integer lock id, pool errors and commit errors propagate
  to the caller.

  NOTE(review): the DELETE and the ftpquota UPDATE share one transaction.
  It is still committed when the UPDATE fails, so "1" can be returned after
  the locks were already released. Rolling back instead would leave the
  locks held, so this is left as is.
  """
  top = params[1]
  lock_id = int(params[3])
  sqlparams = {'id': lock_id}
  with DB() as cur:
    try:
      cur.execute("DELETE FROM lock_file WHERE lock_id=%(id)s", sqlparams)
      m = re.search(r'/([^/]*)/$', top)
      if m is not None:
        cur.execute("UPDATE ftpquota SET last_update=NOW(),upd_log='locker.py',upd_date=NOW() WHERE id=%s", (m.group(1),))
      return "0"
    except Exception:
      error(traceback.format_exc())
  return "1"

def doClose(connId):
  """Release a client's database state when its socket closes.

  Looks up the ``conn`` row for ``connId``; if there is none, nothing is
  done. The ``mode`` decides what happens next:

  * mode '0': the row is only stamped with ``disconnected=NOW()``. A later
    reconnect can reuse it through get_conn(); otherwise doCleanup() expires
    it together with its locks.
  * any other mode: the client's lock_file rows and its conn row are
    deleted immediately.

  Statement errors are logged. The transaction is still committed, so
  statements that succeeded before the error stay applied.

  NOTE(review): this compares ``mode`` with the string '0'. If conn.mode is
  a numeric column, rec[0] would be an int and the first branch would never
  match. TODO: confirm the column type.
  """
  with DB() as cur:
    try:
      cur.execute("SELECT mode FROM conn WHERE conn_id=%s", (connId,))
      rec = cur.fetchone()
      if rec is None:
        return
      if rec[0] == '0':
        cur.execute("UPDATE conn SET disconnected=NOW() WHERE conn_id=%s", (connId,))
      else:
        # Row-lock the client's lock_file rows before deleting them (the
        # selected rows themselves are not read).
        cur.execute("SELECT * FROM lock_file WHERE conn_id=%s ORDER BY id FOR UPDATE", (connId,))
        cur.execute("DELETE FROM lock_file WHERE conn_id=%s", (connId,))
        cur.execute("DELETE FROM conn WHERE conn_id=%s", (connId,))
    except Exception:
      error(traceback.format_exc())

def doShutdown(connId):
  """Delete every lock_file row and the conn row belonging to ``connId``.

  Statement errors are logged. The transaction is still committed, so a
  DELETE that succeeded before the error stays applied.
  """
  with DB() as cur:
    try:
      cur.execute("DELETE FROM lock_file WHERE conn_id=%s", (connId,))
      cur.execute("DELETE FROM conn WHERE conn_id=%s", (connId,))
    except Exception:
      error(traceback.format_exc())

def _recv_exact(conn, size):
  """Read exactly ``size`` bytes from socket ``conn``.

  ``socket.recv`` on a stream socket may return fewer bytes than requested,
  so keep reading until ``size`` bytes have arrived. Each read is capped so
  a huge length does not allocate a huge buffer up front. A shorter result
  is returned only if the peer closes the connection first.
  """
  chunks = []
  remaining = size
  while remaining > 0:
    chunk = conn.recv(min(remaining, 65536))
    if not chunk:
      break
    chunks.append(chunk)
    remaining -= len(chunk)
  return b''.join(chunks)


def accepted(conn):
  """Serve one client connection until it disconnects.

  Every request is a 10-byte header holding the decimal payload length
  (whitespace padded), followed by that many bytes of UTF-8 text. The text
  is split on ':'; the first field selects the handler:

    c  get_conn()   -> the client's conn_id
    d  doShutdown() -> "0"
    +  lock()       -> "<status> <lock_id>"
    -  unlock()     -> "0" / "1"
    !  update()     -> "0" / "1"
    other           -> "0"

  The reply uses the same framing: a 9-wide left-justified byte length, a
  space, then the payload.

  When the loop ends (EOF, protocol error, or any other exception, which is
  logged), doClose() runs for the current conn id. The socket is then
  closed, even if doClose() fails.
  """
  conId = 0
  try:
    while True:
      head = _recv_exact(conn, 10)
      if not head:
        break  # client closed the connection between messages
      if len(head) < 10:
        raise ConnectionError('peer closed connection inside a message header')

      debug('head: ', head)
      rlen = int(head.decode('utf-8').strip())
      if rlen < 0:
        raise ValueError('negative payload length: {}'.format(rlen))
      payload = _recv_exact(conn, rlen)
      if len(payload) < rlen:
        raise ConnectionError('peer closed connection inside a message body')
      params = payload.decode('utf-8')
      debug('params: ', params)
      cols = params.split(':')
      debug('rlen: ', rlen, ' params: ', cols)

      cmd = cols[0]
      if cmd == 'c':
        conId = get_conn(int(cols[1]), cols[2])
        ret = "{}".format(conId)
      elif cmd == 'd':
        doShutdown(conId)
        conId = 0
        ret = "0"
      elif cmd == '+':
        ret = lock(conId, cols)
      elif cmd == '-':
        ret = unlock(cols)
      elif cmd == '!':
        ret = update(cols)
      else:
        ret = "0"

      # The length header counts encoded bytes, not characters.
      buff = "{:<9d} ".format(len(ret.encode('utf-8'))) + ret
      debug('write: ', buff)
      # send() may write only part of the buffer; sendall() retries until done.
      conn.sendall(buff.encode('utf-8'))
  except Exception:
    error(traceback.format_exc())
  finally:
    try:
      doClose(conId)
    except Exception:
      error(traceback.format_exc())
    finally:
      conn.close()

def doCleanup():
  """Purge long-disconnected clients and then any orphaned lock rows.

  Steps:

  1. Delete ``conn`` rows whose ``disconnected`` time is earlier than
     ADDTIME(NOW(), DISCONNECT_TIMER). Presumably DISCONNECT_TIMER is a
     negative TIME string such as '-00:10:00'; confirm in GRMConst. Rows
     with ``disconnected`` NULL, i.e. connected clients, never match.
  2. Delete every ``lock_file`` row whose conn_id no longer has a ``conn``
     row.

  Statement errors are logged and swallowed. Both DELETEs can simply be
  re-run by a later call. Errors obtaining the pooled connection or
  committing propagate to the caller.
  """
  with DB() as cur:
    try:
      cur.execute("DELETE FROM conn WHERE disconnected < ADDTIME(NOW(),%s)", (DISCONNECT_TIMER,))
      cur.execute("DELETE FROM lock_file WHERE NOT EXISTS(SELECT * FROM conn WHERE conn_id=lock_file.conn_id)")
    except Exception:
      error(traceback.format_exc())

def startCleanup(interval=60):
  """Start a daemon thread that calls doCleanup() every ``interval`` seconds.

  doCleanup() only catches errors from its SQL statements. Failing to get a
  pooled connection, or failing to commit, raises out of it. Those errors
  are logged here and the loop continues; otherwise a single database
  hiccup would silently end the thread, and stale connections and locks
  would never be purged again.
  """
  def cleanup():
    while True:
      time.sleep(interval)
      try:
        doCleanup()
      except Exception:
        error(traceback.format_exc())
  t = threading.Thread(target=cleanup, name='locker-cleanup', daemon=True)
  t.start()

def main(host="127.0.0.1", port=12345, backlog=5):
  """Run the lock server forever.

  Listens on ``host``:``port`` (127.0.0.1:12345 by default) and starts the
  periodic cleanup thread. Each accepted client is served by accepted() on
  its own daemon thread.

  Per-connection failures are logged and the loop continues. Only Exception
  subclasses are caught, so Ctrl-C (KeyboardInterrupt) and SystemExit still
  stop the server. A socket whose handler thread cannot be started is
  closed instead of leaked.
  """
  s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  s.bind((host, port))
  s.listen(backlog)
  startCleanup()
  while True:
    try:
      conn, _addr = s.accept()
    except Exception:
      error(traceback.format_exc())
      continue
    try:
      threading.Thread(target=accepted, args=(conn,), daemon=True).start()
    except Exception:
      error(traceback.format_exc())
      conn.close()

# Start the server only when run as a script, not when imported.
if __name__ == '__main__':
  main()

