# Ensemble refinement of GB3 (Xplor-NIH script): computes totStructs
# structures, each an ensemble of ensembleSize members, refined against the
# RDC, J-coupling, B-factor, order-parameter, shape and geometric terms
# defined below. Written for the Python 2 interpreter embedded in Xplor-NIH.

# check the running Xplor-NIH version against 2.14.3
xplor.requireVersion("2.14.3")

totStructs = 2          # number of structures to calculate
ensembleSize=2          # number of ensemble members
seed=5                  # passed to simWorld.setRandomSeed below
# Output filename template handed to StructureLoop (pdbTemplate). Only the
# ensemble size is substituted here; SCRIPT/STRUCTURE/MEMBER are presumably
# placeholders filled in by StructureLoop -- confirm in simulationTools docs.
outPDB_template = "SCRIPT_%d_STRUCTURE_MEMBER.sa" % ensembleSize

xplor.parseArguments() # check for typos on the command-line

# shorthand for issuing XPLOR command-language strings (used below)
command = xplor.command

# atom selection (backbone C, N, CA) passed as centerSel to the B-factor
# terms created with create_BFactorPot further down
transCenterSel = "name C or name N or name CA"



simWorld.setRandomSeed(seed)
import sys
sys.stderr = sys.stdout #stderr is redirected to stdout



#----------------------------------------------------------------------
# read in the PSF (structure/topology) file, the force-field parameters and
# the initial coordinates
import protocol
protocol.initStruct("gb3.psf")

# NOTE(review): weak_omega=1 presumably weakens the peptide-bond omega
# restraint -- confirm in the protocol.initParams documentation.
protocol.initParams("protein.par",
                    weak_omega=1)

protocol.initCoords("gb3_refined_IV.pdb")

#
# create Simulation with ensembleSize ensemble members
#
#  do this before defining potential terms, or IVM objects
#
from ensembleSimulation import EnsembleSimulation
esim = EnsembleSimulation("ensemble",ensembleSize)
# report the ensemble size and how many threads the ensemble runs on
print "ensemble size", esim.size(),"running on",esim.numThreads(),"processors."

# Energy collections:
#   potList    -- terms the minimizers/dynamics below are run against
#   crossTerms -- terms kept out of potList; they are only handed to
#                 writeStructure / StructureLoop for reporting
potList = PotList()
crossTerms = PotList("cross terms")

# parameters ramped during annealing (see AnnealIVM / InitialParams below)
from simulationTools import MultRamp, StaticRamp, InitialParams
rampedParams = []


#
# one variable alignment tensor per alignment medium, keyed by medium name;
# all RDC experiments measured in a medium will share its tensor
#
from varTensorTools import create_VarTensor, calcTensor
media={}
mediumNames = ("peg", "bic", "ngel", "pgel", "pf1")
for medium in mediumNames:
    media[medium] = create_VarTensor(medium)

def create_and_setup_rdc(name,file):
    axis = addAxisAtoms()
    print "axis:", axis
    oTensor=create_VarTensor(name,axis=axis)
    name = "nh_" + name
    write("creating rdc[%8s]..." %name)
    rdc = create_RDCPot(name,oTensor=oTensor,file=file)
    calcTensor(oTensor)
    #if esim.member().memberIndex()==0:
    #    AtomSel("resname ANI and name X",esim.members(0))[0].setPos(3,2,1)
    #    pass
    #AtomSel("resname ANI and name PA1",esim).apply(PrintPos())
    print " Da = %6.2f   Rh = %6.3f" % (oTensor.Da(), oTensor.Rh())
    #oTensor.setEnsembleDa(1)
    #oTensor.setEnsembleRh(1)
    # Da(0): want potential parameters to be *exactly* the same
    # for all simulation members
    #oTensor.setSpreadDa(1e-5*abs(oTensor.Da(0))) 
    #oTensor.setSpreadRh(1e-5)
    #oTensor.setScaleDa(1e5)
    #oTensor.setScaleRh(1e5)
    return rdc

def create_dependent_rdc(name,file,type,rdc_nh):
    """create an rdc pot term for a non-NH expt. carried out in the
    same medium as the given nh expt.
    """
    write("creating rdc[%8s]..." %name)
    oTensor = rdc_nh.oTensor
    rdc = create_RDCPot(name,file,oTensor=oTensor)
    scale_toNH(rdc,type)
    print " Da = %6.2f   Rh = %6.3f" % (oTensor.Da(), oTensor.Rh())
    rdc.setAveType("sum")
    #oTensor.setEnsembleDa(1)
    #oTensor.setEnsembleRh(1)
    #oTensor.setSpreadDa(1e-5*abs(rdc_nh.oTensor.Da(0)))
    #oTensor.setSpreadRh(1e-5)
    #oTensor.setScaleDa(1e5)
    #oTensor.setScaleRh(1e5)
    return rdc
    

refineRDC = PotList("refineRDC")
#
# rdc terms are later rescaled relative to rdc_scaleRef: each term's scale
# is multiplied by the square of the ratio of Das (reference Da / term Da)
#
rdcs={}
rdc_scaleRef="nh_peg"
from rdcPotTools import create_RDCPot, scale_toNH
# One RDC term per (experiment type, medium, restraint table, weight). Each
# term uses its medium's tensor from `media`, so all experiment types
# measured in one medium share a single alignment tensor.
# The table variable is called `table` rather than `file`: as a module-level
# loop variable, `file` would shadow the Python 2 builtin `file` for the rest
# of the script.
for (expt,medium,table,weight) in [
    ("nh"   ,"peg" ,"peg_nh.tbl"      ,15.0),
    ("nh"   ,"bic" ,"bicelle_nh.tbl"  ,15.0),
    ("nh"   ,"ngel","ngel_nh.tbl"     ,15.0),
    ("nh"   ,"pgel","pgel_nh.tbl"     ,15.0),
    ("nh"   ,"pf1" ,"pf1_nh.tbl"      ,15.0),
    ("caco" ,"bic" ,"bicelle_caco.tbl",100.0),
    ("caco" ,"ngel","ngel_caco.tbl"   ,100.0),
    ("caco" ,"pgel","pgel_caco.tbl"   ,100.0),
    ("caco" ,"pf1" ,"pf1_caco.tbl"    ,100.0),
    ("caco" ,"peg" ,"peg_caco.tbl"    ,100.0),
    ("ch"   ,"bic" ,"bicelle_ch.tbl"  ,3.0),
    ("ch"   ,"ngel","ngel_ch.tbl"     ,3.0),
    ("ch"   ,"pgel","pgel_ch.tbl"     ,3.0),
    ("ch"   ,"pf1" ,"pf1_ch.tbl"      ,3.0),
    ("ch"   ,"peg" ,"peg_ch.tbl"      ,3.0),
    ("nco"  ,"bic" ,"bicelle_nco.tbl" ,100.0),
    ("nco"  ,"ngel","ngel_nco.tbl"    ,100.0),
    ("nco"  ,"pgel","pgel_nco.tbl"    ,100.0),
    ("nco"  ,"pf1" ,"pf1_nco.tbl"     ,100.0),
    ("nco"  ,"peg" ,"peg_nco.tbl"     ,100.0)    ]:
    name = expt + '_' + medium           # e.g. "nh_peg"; key into rdcs
    oTensor = media[medium]
    term = create_RDCPot(name,file=table,oTensor=oTensor)
    term.setScale(weight)
    # presumably normalizes this experiment type to the NH scale --
    # see rdcPotTools.scale_toNH
    scale_toNH(term)
    term.setAveType("sum")
    refineRDC.append( term )
    rdcs[name] = term
    pass

#
# Fit each medium's initial alignment tensor to that medium's NH couplings
# only; the other experiment types reuse the same tensor.
for (medium, tensor) in media.items():
    calcTensor( tensor, (rdcs["nh_%s" % medium],) )

#
# Reweight every RDC term by (Da_ref / Da_term)**2, where Da_ref is the
# Da(0) of the tensor used by the reference term rdc_scaleRef.
for rdc in rdcs.values():
    daRatio = rdcs[rdc_scaleRef].oTensor.Da(0) / rdc.oTensor.Da(0)
    weight = rdc.scale() * daRatio**2
    rdc.setScale( weight )
potList.append( refineRDC )
# overall refineRDC weight ramps from 0.01 to 1 during annealing
rampedParams.append( MultRamp(0.01,1.,"refineRDC.setScale( VALUE )") )

#
#add in tensor terms (to prevent large spread in Da, Rh)
#  --> used if ensembleDa or ensembleRh are enabled
#tensorPots  = PotList("tensors")
#for medium in media:
#    tensorPots.add( medium )
#    pass
#potList.append( tensorPots )

# Cross-validation: hold out one experiment in each of four media. The
# held-out terms are removed from refineRDC (which is part of potList) and
# collected in crossRDC, which is added to crossTerms instead.
crossRDC  = PotList("crossRDC")
heldOutRDCs = ("nh_bic", "ch_ngel", "caco_pgel", "nco_pf1")
for term in heldOutRDCs:
    refineRDC.remove(term)
    crossRDC.append(rdcs[term])
crossTerms.append( crossRDC )

# AvePot(XplorPot, NAME) wraps the XPLOR-native energy term NAME so it can
# live in a PotList (see the avePot docs for the ensemble semantics).
from avePot import AvePot
from xplorPot import XplorPot

# XPLOR COLL term over residues 1-56 with target R of 10.6 (presumably a
# radius-of-gyration collapse restraint -- see protocol.initCollapse)
protocol.initCollapse(sel="resid 1:56",
                      Rtarget=10.6)
potList.append( AvePot(XplorPot,"COLL") )


# XPLOR HBDA term, read from hbda.tbl (class "back", force 500)
command("""
hbda
  nres 800
  class back
  @hbda.tbl
  force 500.0
end
""")
potList.append( AvePot(XplorPot,"HBDA") )


# dihedral-angle restraints from dihed_g_all.tbl only (useDefaults=0)
protocol.initDihedrals(scale=200.,
                       filenames="dihed_g_all.tbl",
                       useDefaults=0)

# print the CDIH restraints with a violation threshold of 0
command("print threshold=0.0 cdih")
potList.append( AvePot(XplorPot,"CDIH") )
# constant dihedral scale of 200 (start == end of the ramp)
rampedParams.append(MultRamp(200.,200.,
                             "command('restraints dihed scale VALUE end')"))

from jCoupPotTools import create_JCoupPot
# HN-HA coupling term: placed in crossTerms only, never in potList
crossTerms.append( create_JCoupPot('hnha','jna_coup.tbl') )

# J-coupling terms grouped in 'jside', which is refined (added to potList).
# Judging by the table names these are Thr and Val/Ile couplings -- confirm.
jside = PotList('jside')
jsideTables = [('tccg' ,'J_thr_ccg.tbl'    ),
               ('tncg' ,'J_thr_ncg.tbl'    ),
               ('ivccg','J_val_ile_ccg.tbl'),
               ('ivncg','J_val_ile_ncg.tbl')]
for (name,table) in jsideTables:
    jside.append( create_JCoupPot(name,table) )
potList.append(jside)


from noePotTools import create_NOEPot
# do not use r^{-6} averaging for this term: instead apply the restraint to
# each ensemble member
#
# to get 1/r^6 averaging, use this instead.
#   enoe = create_NOEPot('enoe','file.tbl')
#
# Distance restraints from hbond.tbl, built on each member's sub-simulation
# (esim.member().subSim()) and wrapped in AvePot.
# The global name `enoe` must stay: the ramp action string below refers to it.
enoe = AvePot(create_NOEPot('enoe','hbond.tbl',
                            esim=esim.member().subSim()))

# 30 is the weight until the ramp below is initialized (presumably by
# InitialParams in coolAndMinimize, which would start it at 2.0)
enoe.setScale( 30. ) #initial value
enoe.setDOffset(0.)
enoe.setHardExp(2 )
enoe.setThreshold(0.1)
enoe.setPotType( "hard" )
enoe.setAveType( "sum" )

potList.append( enoe ) 

# Ramp enoe's scale from 2 to 30 during annealing.
# NOTE(review): the action also rescales the XPLOR "noe" term, which is not
# set up anywhere in this script -- confirm that part is still needed.
rampedParams.append(MultRamp(2.0,30.0,
                             "enoe.setScale( VALUE );" + 
                             "command('noe scale all VALUE end')"))
#rampedParams.append( MultRamp(2.0,6.0, "enoe.setAveExp( VALUE )") )




# find geometric center of protein: unweighted mean position of every atom
# that is not resname ANI
from vec3 import Vec3
centerSel = AtomSel("not resname ANI") # selection used to find center
centerPos = Vec3(0,0,0)
for atom in centerSel: centerPos += atom.pos()
centerPos /= len( centerSel )

print "center: ", centerPos

# exterior atoms defining protein shape: non-ANI atoms excluded by the XPLOR
# "point (x,y,z) cut 5." clause (presumably atoms within 5 A of centerPos)
# NOTE(review): shapeSel is not used anywhere else in this script.
shapeSel = AtomSel("(not resname ANI) and not (point (%f, %f, %f) cut 5.)" %
                   centerPos.tuple())


import shapePotTools
from shapePot import ShapePot
# orient's overall scale is ramped per repetition in calcOneStructure
# (action string "orient.setScale( VALUE )"), so the name must stay.
orient = PotList("orient")
shape = PotList("shape")

# resids of regions of secondary structure (inclusive)
orientRegions = [("a",[(24,35)]),                       #helix
                 ("b",[(2,7),(15,18),(42,46),(51,55)]), #sheet (4 strands)
                 ("all",[(2,55)])]

# For each region, two ShapePot terms are built on the CA atoms of its
# residue ranges:
#   shape_<name>  -- size restraint only (orient scale 0)
#   orient_<name> -- orientation restraint only (size scale 0)
for (name,rangePairs) in orientRegions:
    selStr = " or ".join(["(name CA and resid %d:%d)" % rangePair
                          for rangePair in rangePairs])
    pot = ShapePot("shape_"+name,selStr)
    pot.setSizeScale( 10)
    pot.setOrientScale( 0 )
    pot.setSizePotType("square") # square potential, size tolerance of 1 unit
    pot.setSizeTol(1)
    pot.setTargetType("pairwise")
    shape.append(pot)
    #
    pot = ShapePot("orient_"+name,selStr)
    pot.setSizeScale( 0 )
    pot.setOrientScale( 50 )
    # square potential with an orientation tolerance of 0.
    # NOTE(review): this line was previously commented "allow 1 degree
    # orientation error", which contradicts setOrientTol(0.) -- confirm which
    # value is intended.
    pot.setOrientPotType("square")
    pot.setOrientTol(0.)
    pot.setTargetType("pairwise")
    orient.append(pot)
    pass

potList.add( orient )
potList.add( shape )


from posRMSDPotTools import create_BFactorPot
# B-factor terms: the backbone term (b_back) goes into the bFactor list,
# which is refined (in potList); the side-chain term (b_side) is only placed
# in crossTerms.
bFactor = PotList("bFactor")
# 0.1 is in effect until the ramp below is applied (presumably by
# InitialParams in coolAndMinimize); the initial symmetry-breaking dynamics
# runs before that with bFactor included.
bFactor.setScale(0.1)
potList.append(bFactor)
#crossTerms.append(bFactor)
# constant scale of 0.5 during annealing (start == end of the ramp)
rampedParams.append( MultRamp(0.5,0.5,"bFactor.setScale( VALUE)") )

# Created by direct assignment instead of exec() on a formatted string: same
# global names (b_side, b_back), same arguments, same creation order.
b_side = create_BFactorPot('b_side',file='bfactor_side.tbl',
                           centerSel=transCenterSel)
b_back = create_BFactorPot('b_back',file='bfactor_back.tbl',
                           centerSel=transCenterSel)

bFactor.append(b_back)
crossTerms.append(b_side)

#rap term to keep members of ensemble in similar orientations
# (RAPPot on the CA atoms, square potential with tolerance 1.5, scale 100).
# It is NOT in potList: it goes to crossTerms here and is added explicitly to
# the potential list of the initial symmetry-breaking dynamics further down.
from posRMSDPotTools import RAPPot
rap = RAPPot("rap",AtomSel("name CA"))

rap.setScale( 100.0 )
rap.setPotType( "square" )
rap.setTol( 1.5 )
#potList.add( rap )
crossTerms.append(rap)

# order-parameter term "s2_nh" from nh_s2.tbl; refined (in potList), with its
# scale ramped from 0.01 to 0.3 during annealing. The name "s2_nh" is what the
# initial-dynamics potential list filters on to leave this term out.
from orderPotTools import create_OrderPot
orderPot = create_OrderPot("s2_nh","nh_s2.tbl")
orderPot.setScale(0.01)
potList.add( orderPot )
#crossTerms.append(orderPot)
rampedParams.append( MultRamp(0.01,0.3,"orderPot.setScale( VALUE)") )

#
# ensemble average over the following XPLOR energy terms
#
potList.append( AvePot(XplorPot,"BOND") )

# ANGL and IMPR start at reduced weight and are ramped to 1.0 during
# annealing; their scales are set directly here as well
potList.append( AvePot(XplorPot,"ANGL") )
rampedParams.append( MultRamp(0.4,1.0,"potList['ANGL'].setScale(VALUE)") )
potList['ANGL'].setScale(0.4) #initial value

potList.append( AvePot(XplorPot,"IMPR") )
rampedParams.append( MultRamp(0.1,1.0,"potList['IMPR'].setScale(VALUE)") )
potList['IMPR'].setScale(0.1)


# nonbonded (VDW) term; its XPLOR "repel" and "rcon" parameters are ramped
protocol.initNBond()
potList.add( AvePot(XplorPot,"VDW") )
#vdw weight -- FIX??
#rcon  = 0.003
#command("parameters nbonds atom rcon=%f end end" % rcon)
# NOTE(review): repel ramps from 0.81 to 0.8, i.e. it is almost constant --
# confirm the intended start value (see the FIX?? note above).
rampedParams.append( MultRamp(0.81,0.8,
                     'command("parameter nbonds repel=VALUE end end")') )
rampedParams.append( MultRamp(0.1,4.0,
          'command("parameter nbonds rcon=VALUE end end")') )

# XPLOR RAMA (torsion database) term, scale ramped from 0.002 to 1.0
protocol.initRamaDatabase()
potList.add( AvePot(XplorPot,"RAMA") )
rampedParams.append(MultRamp(0.002,1.0,
                             "command('rama scale VALUE end')"))

#
# minimization to optimize initial alignment orientation
# Every non-ANI atom is fixed, so only the resname ANI atoms (presumably the
# alignment-tensor axis atoms) move; Da and Rh are held fixed during this run.
import ivm
mini = ivm.IVM(esim)

import varTensorTools
for medium in media.values(): medium.setFreedom("fixDa, fixRh")
varTensorTools.topologySetup(mini,media.values())

protocol.initMinimize(mini,potList)
mini.fix( AtomSel("not resname ANI") )
mini.autoTorsion()

mini.run()

#
# let Da, Rh float in main calculation
for medium in media.values(): medium.setFreedom("varyDa, varyRh")


#set up IVM dynamics manipulators (both include the alignment tensors)

dyn  = ivm.IVM(esim) #manipulator for T-A dynamics
dynx = ivm.IVM(esim) #manipulator for Cartesian dynamics

protocol.torsionTopology(dyn,oTensors=media.values())
protocol.cartesianTopology(dynx,oTensors=media.values())


#set atom masses: 30 for every non-ANI atom, tensor atoms set up through
#massSetup (10000); friction 10 for all atoms.
# NOTE(review): `trace` here is presumably Xplor-NIH's trace module
# (suspend/resume), not the Python stdlib `trace` -- confirm.
import trace
trace.suspend()
for atom in AtomSel("not (resname ANI)"): atom.setMass(30.)
varTensorTools.massSetup(media.values(),10000)
for atom in AtomSel("all"): atom.setFric(10.)
trace.resume()

#dynamics run to break symmetry- moving the ensemble members apart.
#
# initPotList = potList without "s2_nh", "rap" and "bFactor", followed by rap
# and bFactor appended explicitly. Net effect: the order-parameter term s2_nh
# is left out, rap (which is not in potList) is added, and bFactor is kept
# but moved to the end.
# NOTE(review): the original comment said the bfactor terms are excluded
# here, but the code has always re-appended bFactor -- confirm the intent.
initPotList = PotList()
initSkipNames = ("s2_nh", "rap", "bFactor")
for term in potList:
    if term.instanceName() in initSkipNames:
        continue
    initPotList.append(term)
    pass
initPotList.append(rap)
initPotList.append(bFactor)

# reduced RDC weight for this run (refineRDC is also controlled by a ramp
# in rampedParams)
refineRDC.setScale(0.1)

# torsion-angle dynamics with fresh velocities at bath temperature 400
protocol.initDynamics(dyn,
                      bathTemp=400,
                      initVelocities=1,
                      potList=initPotList,
                      finalTime=1)

dyn.run()



class DynMin:
    """Wrap an IVM so that each run() is a short minimization followed by
    dynamics using the wrapped IVM's previous dynamics settings.

    Passed as the `ivm` argument of AnnealIVM in coolAndMinimize(). Besides
    run(), it forwards setBathTemp() and setETolerance() to the wrapped IVM,
    which is stored as the `ivm` attribute.
    """

    def __init__(self, ivm):
        self.ivm = ivm

    def setBathTemp(self, temp):
        self.ivm.setBathTemp(temp)

    def setETolerance(self, tol):
        self.ivm.setETolerance(tol)

    def run(self):
        """Minimize 200 steps, then run dynamics with the saved settings.

        The IVM's numSteps/finalTime and the ensemble's atom velocities are
        captured before minimizing and restored afterwards, so dynamics
        resumes with the settings and velocities in place beforehand.
        (Disabled option: recalculating each media tensor with calcTensor
        before minimizing.)
        """
        engine = self.ivm
        savedSteps = engine.numSteps()
        savedTime = engine.finalTime()
        savedVel = esim.atomVelArr()

        protocol.initMinimize(engine,numSteps=200)
        engine.run()

        esim.setAtomVelArr(savedVel)
        protocol.initDynamics(engine,
                              numSteps=savedSteps,
                              finalTime=savedTime)
        engine.run()



#
# definition of cooling loop and associated parameters
#

from simulationTools import AnnealIVM
def coolAndMinimize():
    """Run one simulated-annealing cycle followed by final minimizations.

    Sequence: initialize the ramped parameters, minimize 1000 steps with the
    torsion-angle IVM `dyn`, anneal from 400.01 down to 300 in steps of 25
    (AnnealIVM drives a DynMin wrapper around `dyn`), then minimize 500 more
    steps in torsion space and finally with the Cartesian IVM `dynx`.
    Uses the module-level dyn, dynx, potList and rampedParams.
    """
    startTemp = 400.01
    annealer = AnnealIVM(initTemp     = startTemp,
                         finalTemp    = 300.0,
                         tempStep     = 25.0,
                         ivm          = DynMin(dyn),
                         rampedParams = rampedParams)

    # presumably sets every ramped parameter to its start value
    InitialParams( rampedParams )

    # torsion-angle minimization before annealing
    protocol.initMinimize(dyn,
                          potList=potList,
                          numSteps=1000)
    dyn.run()

    # dynamics settings that DynMin saves and restores at each temperature
    protocol.initDynamics(dyn,
                          potList=potList,
                          bathTemp=startTemp,
                          initVelocities=1)

    # spread ~100 dynamics steps across the annealing temperature steps;
    # 0.002 is presumably the time step per dynamics step -- confirm
    totalCoolSteps = 100
    stepsPerTemp = int(totalCoolSteps/(annealer.numSteps+1) )
    dyn.setFinalTime(stepsPerTemp*0.002)

    annealer.run()

    # final torsion-angle minimization, then Cartesian minimization
    protocol.initMinimize(dyn,
                          numSteps=500)
    dyn.run()
    protocol.initMinimize(dynx,potList)
    dynx.run()



#
# main structure loop
#

def calcOneStructure(loopInfo):
    """StructureLoop action: compute one structure and write it.

    Runs coolAndMinimize() repeatedly while the overall scale of the
    `orient` PotList is stepped by a MultRamp from 0.1 to 1 (one update
    before each repetition), then writes the structure via
    loopInfo.writeStructure(potList, crossTerms).
    """
    repetitions = 4 # should be 16
    orientRamp = MultRamp(0.1,1.,
                          "orient.setScale( VALUE )")
    orientRamp.init( repetitions )
    for _ in range( repetitions ):
        orientRamp.update()
        coolAndMinimize()
    loopInfo.writeStructure(potList,crossTerms)


from simulationTools import StructureLoop
# Main loop: call calcOneStructure totStructs times, writing output through
# outPDB_template. Violation statistics are generated, and averages are
# computed over potList/crossTerms using averageTopFraction=0.3 (presumably
# the best-scoring 30% of structures -- see StructureLoop docs).
StructureLoop(structLoopAction=calcOneStructure,
              numStructures=totStructs,
              pdbTemplate=outPDB_template,
              averagePotList=potList,
              averageCrossTerms=crossTerms,
              averageTopFraction=0.3,
              genViolationStats=1).run()