StEmcUtil/neuralNet/NeuNet.h

00001 00002 // 00003 // Neural Network classes : 00004 // TNNFormula 00005 // TNNTree 00006 // TNNKernel 00007 // TNNControlE 00008 // TNNUtils 00009 // J.P. Ernenwein (rnenwein@in2p3.fr) 00011 00013 #ifndef ROOT_TNamed 00014 #include "TNamed.h" 00015 #endif 00016 #ifndef ROOT_TROOT 00017 #include "TROOT.h" 00018 #endif 00019 #ifndef ROOT_TTree 00020 #include "TTree.h" 00021 #endif 00022 #ifndef ROOT_TString 00023 #include "TString.h" 00024 #endif 00025 #ifndef ROOT_TFormula 00026 #include "TFormula.h" 00027 #endif 00028 #ifndef ROOT_TTreeFormula 00029 #include "TTreeFormula.h" 00030 #endif 00031 #ifndef ROOT_TCanvas 00032 #include "TCanvas.h" 00033 #endif 00034 #include "TFrame.h" 00035 #include "TStringLong.h" 00036 #include "TH1.h" 00037 #include "TGraph.h" 00038 #include "TAxis.h" 00039 #include "TFile.h" 00040 #include "TText.h" 00041 #include "TDatime.h" 00042 #include "TRandom.h" 00043 #include "TPad.h" 00044 #include "math.h" 00045 #include "stdlib.h" 00046 00047 00049 00050 class TNNFormula : public TNamed 00051 { 00052 00053 private: 00054 Int_t fNValues; // number of values 00055 TTree *fTree; // current Tree on which values are computed 00056 TTreeFormula *fTTCut; // Tree formula containing the cut expression 00057 TTreeFormula **fTTFormula; 00058 Bool_t fClip; // flag to clip or not the values (1 = clip, 0 = no clip) 00059 TStringLong* fFormula; 00060 TStringLong* fCut; 00061 Bool_t fRefresh; // refresh flag 00062 Int_t RMBlanks(Text_t *str); 00063 Float_t Clip(Float_t x); 00064 00065 public: 00066 TNNFormula(){fNValues=0;fClip=1;fTree=0;fTTCut=0;fTTFormula=0;fFormula=0;fCut=0;fRefresh=0;}; 00067 TNNFormula(Text_t *name, Text_t *formula, Text_t *cut, TTree *tree); 00068 virtual ~TNNFormula(); 00069 virtual Bool_t Find(Int_t iEvent, Float_t *values); 00070 virtual void Find(Int_t iEvent=0); 00071 virtual Int_t GetNValues(){return fNValues;}; 00072 virtual void SetTree(TTree *tree); 00073 virtual void SetFormula(Text_t *formula); 00074 virtual void Refresh() 00075 { 00076 fRefresh=1; 00077 if(fFormula)SetFormula((Text_t*)(fFormula->Data())); 00078 if(fCut)SetCut((Text_t*)(fCut->Data())); 00079 fRefresh=0; 00080 }; 00081 virtual Int_t Length(){if(fFormula)return fFormula->Length();else return 0;}; 00082 virtual void SetClip(Bool_t trueForClip=1); 00083 virtual void SetCut(Text_t *cutarg=""); 00084 00085 00086 ClassDef(TNNFormula,0) 00087 00088 }; 00089 00091 00092 class TNNTree : public TNamed 00093 { 00094 00095 private: 00096 00097 TTree *fTree; // Tree 00098 Int_t fNTrees; // number of TTrees in fTree 00099 TFile *fFile; // fTree File 00100 TNNFormula fFormula; // Input Formula 00101 TNNFormula fOutFormula; // Output Formula 00102 Text_t **fInfos; 00103 Text_t *fFName; 00104 Float_t *fInput; 00105 Float_t *fOutput; 00106 Int_t fNInput; // number of input values 00107 Int_t fNOutput; // number of output values 00108 00109 void CheckRange(Int_t *begin, Int_t *end, Int_t indexMax); 00110 Int_t RMBlanks(Text_t *str); 00111 void CreateTree(); 00112 Int_t NumberOut(Text_t *ttext); 00113 void Decode(Text_t *ttext); 00114 00115 virtual void RefreshInFormula(){fFormula.Refresh();}; 00116 virtual void RefreshOutFormula(){fOutFormula.Refresh();}; 00117 00118 public: 00119 TNNTree() 00120 {fTree=0;fNTrees=0;fFile=0;fInfos=0;fFName=0;fInput=0;fOutput=0;fNInput=0;fNOutput=0;}; 00121 TNNTree(Text_t *name); 00122 virtual ~TNNTree(); 00123 virtual void AddTree(TTree *tree, Int_t begin=0, Int_t end=1000000); 00124 virtual void AddTree(TTree *tree, Text_t *out, Int_t begin=0, Int_t end=1000000); 00125 virtual void Infos(); 00126 virtual TTree* GetTree(){return fTree;}; 00127 virtual void SetFormulaTree(TTree *tree) 00128 { 00129 fFormula.SetTree(tree); 00130 fOutFormula.SetTree(tree); 00131 }; 00132 virtual void SetInFormula(Text_t *formula){fFormula.SetFormula(formula);}; 00133 virtual void SetOutFormula(Text_t *formula){fOutFormula.SetFormula(formula);}; 00134 virtual void SetCut(Text_t *cut=""){fFormula.SetCut(cut);}; 00135 virtual void SetClip(Bool_t trueForClip=1){fFormula.SetClip(trueForClip);}; 00136 virtual void SetOutClip(Bool_t trueForClip=1){fOutFormula.SetClip(trueForClip);}; 00137 virtual void GetEvent(Float_t *input, Float_t *output, Int_t iEvent=0); 00138 virtual Int_t GetNInput(){return fNInput;}; 00139 virtual Int_t GetNOutput(){return fNOutput;}; 00140 virtual void SetFile(Text_t *namearg); 00141 virtual void DeleteTree(); 00142 00143 ClassDef(TNNTree,0) 00144 00145 }; 00146 00148 00149 class TNNControlE : public TCanvas 00150 { 00151 00152 private: 00153 Float_t * fXT; 00154 Float_t * fYT; 00155 Float_t * fXV; 00156 Float_t * fYV; 00157 Int_t fNT; // number of components of train array 00158 Int_t fNV; // number of components of valid array 00159 TGraph *fGraphV; // graph for train 00160 TGraph *fGraphT; // graph for valid 00161 00162 00163 public: 00164 TNNControlE(); 00165 virtual ~TNNControlE(); // destructor 00166 virtual void AddTP(Int_t n, Float_t e); 00167 virtual void AddVP(Int_t n, Float_t e); 00168 virtual void UpdateG(); 00169 virtual void ResetT() 00170 {delete [] fXT;delete [] fYT;fXT=new Float_t[50];fYT=new Float_t[50];fNT=0;}; 00171 virtual void ResetV() 00172 {delete [] fXV;delete [] fYV;fXV=new Float_t[50];fYV=new Float_t[50];fNV=0;}; 00173 virtual Float_t* GetXT(){return fXT;}; 00174 virtual Float_t* GetYT(){return fYT;}; 00175 virtual Float_t* GetXV(){return fXV;}; 00176 virtual Float_t* GetYV(){return fYV;}; 00177 virtual Int_t GetNT(){return fNT;}; 00178 virtual Int_t GetNV(){return fNV;}; 00179 virtual void DrawT(Text_t *text, Float_t x, Float_t y, Float_t angle=0., Int_t color=1) 00180 { 00181 TText *tText= new TText(x,y,text); 00182 tText->SetNDC(kTRUE); 00183 tText->SetTextColor(color); 00184 tText->SetTextAngle(angle); 00185 tText->Draw(); 00186 } 00187 00188 ClassDef(TNNControlE,0) 00189 00190 }; 00191 00193 00194 class TNNKernel : public TNamed 00195 { 00196 00197 private: 00198 Int_t fNHiddL; // number of hidden layers 00199 Float_t **fValues; 00200 Double_t **fErrors; 00201 Double_t **fBiases; 00202 Int_t *fNUnits; 00203 Double_t ***fW; 00204 00205 Int_t fNTrainEvents; // number of events for training 00206 Int_t fNValidEvents; // number of events for validation 00207 TNNTree *fValidTree; // validation tree 00208 Double_t fLearnParam; // learning parameter 00209 Float_t fLowerInitWeight; // minimum weight for initialisation 00210 Float_t fUpperInitWeight; // maximum weight for initialisation 00211 Float_t **fArrayOut; 00212 Float_t *fTeach; 00213 Float_t **fArrayIn; 00214 Int_t *fEventsList; 00215 Int_t fNTrainCycles; // Number of training cycles done 00216 Double_t fUseBiases; // flag for use of biases or not (1=use, 0=no use) 00217 TRandom fRandom; // Random object used in initialisation and mixing 00218 Int_t fNWeights; // number of weights in neural network 00219 Double_t fMu; // backpropagation momentum parameter 00220 Double_t fFlatSE; // Flat Spot elimination paramater 00221 Double_t ***fDW; 00222 Double_t **fDB; 00223 00224 void GetArrayEvt(Int_t iEvent) 00225 { 00226 Int_t l; 00227 for(l=0;l<fNUnits[0];l++)fValues[0][l]=fArrayIn[iEvent][l]; 00228 for(l=0;l<fNUnits[fNHiddL+1];l++)fTeach[l]=fArrayOut[iEvent][l]; 00229 }; 00230 void LearnBackward(); // gradient retropropagation (updates of biases and weights) 00231 void Forward(); // do a simple forward propagation 00232 Double_t Error();// compute the error between forward propagation and teaching 00233 void Error(const char*,const char*,...) const {;}//WarnOff 00234 Double_t ErrorO();// compute the error between forward propagation and teaching 00235 void FreeVW(); 00236 void AllocateVW(Int_t nInput, Text_t *hidden, Int_t nOutput); 00237 void SetHidden(Text_t *ttext); 00238 Float_t Alea(); 00239 void DeleteArray(); 00240 00241 protected: 00242 virtual Double_t Sigmoide(Double_t x) 00243 { 00244 if(x> 10.) return 0.99999; // probability MUST be < 1 00245 if(x<-10.) return 0.; 00246 return (1./(1.+exp(-x))); 00247 }; 00248 virtual Double_t SigPrim(Double_t x){return (x*(1.-x));}; 00249 00250 public: 00251 TNNKernel(); 00252 TNNKernel(Text_t *name, Int_t nInput=5, Text_t *hidden="6:7:8", Int_t nOutput=4); 00253 virtual ~TNNKernel(); // destructor 00254 virtual void SetKernel(Int_t nInput, Text_t *hidden, Int_t nOutput); 00255 virtual void SetLearnParam(Double_t learnParam=0.2,Double_t fse=0.,Double_t mu=0.); 00256 virtual void SetInitParam(Float_t lowerInitWeight=-1., Float_t upperInitWeight=1.); 00257 virtual void Init(); // init biases and weights 00258 virtual void PrintS(); // print structure of network 00259 virtual void Mix(); // mix the events before learning 00260 virtual Double_t TrainOneCycle(); // one loop on internal events = one cycle 00261 virtual void ResetCycles(){fNTrainCycles=0;}; 00262 virtual void Export(Text_t *fileName="exportNN.dat"); 00263 virtual void Import(Text_t *fileName="exportNN.dat"); 00264 virtual void SetUseBiases(Bool_t trueForUse=1){fUseBiases=(Double_t)trueForUse;}; 00265 virtual void SetRandomSeed(UInt_t seed=0){fRandom.SetSeed(seed);}; 00266 virtual UInt_t GetRandomSeed(){return fRandom.GetSeed();}; 00267 virtual Bool_t IsTrained(){return fNTrainCycles;}; 00268 virtual Int_t GetNTrainCycles(){return fNTrainCycles;}; 00269 virtual Int_t GetNTrainEvents(){return fNTrainEvents;}; 00270 virtual Int_t GetNValidEvents(){return fNValidEvents;}; 00271 virtual void SetArraySize(Int_t s=0); 00272 virtual void Fill(Int_t iev=0) 00273 { 00274 Int_t i; 00275 for(i=0;i<fNUnits[0];i++)fArrayIn[iev][i]=fValues[0][i]; 00276 for(i=0;i<fNUnits[fNHiddL+1];i++)fArrayOut[iev][i]=fTeach[i]; 00277 } 00278 virtual Float_t* GetInputAdr(){return fValues[0];}; 00279 virtual void SetInput(Float_t v,Int_t i){fValues[0][i]=v;}; 00280 virtual Int_t GetNInput(){return fNUnits[0];}; 00281 virtual Int_t GetNOutput(){return fNUnits[fNHiddL+1];}; 00282 virtual Float_t GetOutput(Int_t unit=0){return fValues[fNHiddL+1][unit];}; 00283 virtual Float_t* GetOutputAdr(){return fValues[fNHiddL+1];}; 00284 virtual Float_t* GetTeachAdr(){return fTeach;}; 00285 virtual void SetTeach(Float_t v,Int_t i){fTeach[i]=v;}; 00286 virtual Double_t GoThrough(){Forward();return ErrorO();}; 00287 virtual Float_t GetSumO() 00288 { 00289 Int_t i; Float_t s=0.; 00290 for(i=0;i<fNUnits[fNHiddL+1];i++)s+=fValues[fNHiddL+1][i]; 00291 return s; 00292 }; 00293 virtual void SetTrainTree(TNNTree *t); 00294 virtual void SetValidTree(TNNTree *t); 00295 virtual Double_t Valid(); 00296 virtual void TrainNCycles(TNNControlE *conte, Int_t period=5, Int_t nCycles=10); 00297 virtual Int_t GetNWeights() 00298 { 00299 if(!fNUnits)return 0; 00300 Int_t n=0; 00301 for(Int_t i=0;i<fNHiddL+1;i++) 00302 { 00303 n+=fNUnits[i]*fNUnits[i+1]; 00304 } 00305 return n; 00306 }; 00307 00308 ClassDef(TNNKernel,0) 00309 00310 }; 00311 00313 00314 00315 class TNNUtils : public TNamed 00316 { 00317 00318 private: 00319 TTree *fT; // TTree associated 00320 TNNKernel *fK; // associated kernel 00321 Text_t fFName[400]; // file name for the new branch 00322 TFile fTF; // Tfile for the new branch 00323 TBranch *fB; // new branch 00324 UInt_t fOAdr; // adress of output units of the kernel 00325 Int_t fNOut; // number of output units of the kernel 00326 TNNFormula fForm;// formula to compute output 00327 00328 Int_t UpdateNewBranch(); 00329 00330 public: 00331 TNNUtils() {fK=0;fNOut=0;fT=0;fB=0;}; 00332 virtual ~TNNUtils(); // destructor 00333 virtual void SetTree(TTree *t){fT=t;}; 00334 virtual void SetKernel(TNNKernel *k){fK=k;}; 00335 virtual void SetNewBranchFile(Text_t *fname){strcpy(fFName,fname);}; 00336 virtual void SetFormula(Text_t *form, Bool_t clip=1) 00337 { 00338 if(!fT){printf("no tree associated!\n");return;} 00339 fForm.SetTree(fT);fForm.SetFormula(form);fForm.SetClip(clip); 00340 }; 00341 virtual Int_t FillNB(); 00342 virtual TH1F* HIntegral(TH1F *hOrig, Int_t efficiency=1, Text_t *name="Integral", Text_t *title="Integral"); 00343 virtual TGraph* XY(TH1F *hX, TH1F *hY, Int_t color=1); 00344 ClassDef(TNNUtils,0) 00345 00346 }; 00347 00349

Generated on Sun Mar 15 04:54:21 2009 for StRoot by doxygen 1.3.7