StEmcUtil/neuralNet/NeuNet.h
00001
00002
00003
00004
00005
00006
00007
00008
00009
00011
00013 #ifndef ROOT_TNamed
00014 #include "TNamed.h"
00015 #endif
00016 #ifndef ROOT_TROOT
00017 #include "TROOT.h"
00018 #endif
00019 #ifndef ROOT_TTree
00020 #include "TTree.h"
00021 #endif
00022 #ifndef ROOT_TString
00023 #include "TString.h"
00024 #endif
00025 #ifndef ROOT_TFormula
00026 #include "TFormula.h"
00027 #endif
00028 #ifndef ROOT_TTreeFormula
00029 #include "TTreeFormula.h"
00030 #endif
00031 #ifndef ROOT_TCanvas
00032 #include "TCanvas.h"
00033 #endif
00034 #include "TFrame.h"
00035 #include "TStringLong.h"
00036 #include "TH1.h"
00037 #include "TGraph.h"
00038 #include "TAxis.h"
00039 #include "TFile.h"
00040 #include "TText.h"
00041 #include "TDatime.h"
00042 #include "TRandom.h"
00043 #include "TPad.h"
00044 #include "math.h"
00045 #include "stdlib.h"
00046
00047
00049
00050 class TNNFormula : public TNamed
00051 {
00052
00053 private:
00054 Int_t fNValues;
00055 TTree *fTree;
00056 TTreeFormula *fTTCut;
00057 TTreeFormula **fTTFormula;
00058 Bool_t fClip;
00059 TStringLong* fFormula;
00060 TStringLong* fCut;
00061 Bool_t fRefresh;
00062 Int_t RMBlanks(Text_t *str);
00063 Float_t Clip(Float_t x);
00064
00065 public:
00066 TNNFormula(){fNValues=0;fClip=1;fTree=0;fTTCut=0;fTTFormula=0;fFormula=0;fCut=0;fRefresh=0;};
00067 TNNFormula(Text_t *name, Text_t *formula, Text_t *cut, TTree *tree);
00068 virtual ~TNNFormula();
00069 virtual Bool_t Find(Int_t iEvent, Float_t *values);
00070 virtual void Find(Int_t iEvent=0);
00071 virtual Int_t GetNValues(){return fNValues;};
00072 virtual void SetTree(TTree *tree);
00073 virtual void SetFormula(Text_t *formula);
00074 virtual void Refresh()
00075 {
00076 fRefresh=1;
00077 if(fFormula)SetFormula((Text_t*)(fFormula->Data()));
00078 if(fCut)SetCut((Text_t*)(fCut->Data()));
00079 fRefresh=0;
00080 };
00081 virtual Int_t Length(){if(fFormula)return fFormula->Length();else return 0;};
00082 virtual void SetClip(Bool_t trueForClip=1);
00083 virtual void SetCut(Text_t *cutarg="");
00084
00085
00086 ClassDef(TNNFormula,0)
00087
00088 };
00089
00091
00092 class TNNTree : public TNamed
00093 {
00094
00095 private:
00096
00097 TTree *fTree;
00098 Int_t fNTrees;
00099 TFile *fFile;
00100 TNNFormula fFormula;
00101 TNNFormula fOutFormula;
00102 Text_t **fInfos;
00103 Text_t *fFName;
00104 Float_t *fInput;
00105 Float_t *fOutput;
00106 Int_t fNInput;
00107 Int_t fNOutput;
00108
00109 void CheckRange(Int_t *begin, Int_t *end, Int_t indexMax);
00110 Int_t RMBlanks(Text_t *str);
00111 void CreateTree();
00112 Int_t NumberOut(Text_t *ttext);
00113 void Decode(Text_t *ttext);
00114
00115 virtual void RefreshInFormula(){fFormula.Refresh();};
00116 virtual void RefreshOutFormula(){fOutFormula.Refresh();};
00117
00118 public:
00119 TNNTree()
00120 {fTree=0;fNTrees=0;fFile=0;fInfos=0;fFName=0;fInput=0;fOutput=0;fNInput=0;fNOutput=0;};
00121 TNNTree(Text_t *name);
00122 virtual ~TNNTree();
00123 virtual void AddTree(TTree *tree, Int_t begin=0, Int_t end=1000000);
00124 virtual void AddTree(TTree *tree, Text_t *out, Int_t begin=0, Int_t end=1000000);
00125 virtual void Infos();
00126 virtual TTree* GetTree(){return fTree;};
00127 virtual void SetFormulaTree(TTree *tree)
00128 {
00129 fFormula.SetTree(tree);
00130 fOutFormula.SetTree(tree);
00131 };
00132 virtual void SetInFormula(Text_t *formula){fFormula.SetFormula(formula);};
00133 virtual void SetOutFormula(Text_t *formula){fOutFormula.SetFormula(formula);};
00134 virtual void SetCut(Text_t *cut=""){fFormula.SetCut(cut);};
00135 virtual void SetClip(Bool_t trueForClip=1){fFormula.SetClip(trueForClip);};
00136 virtual void SetOutClip(Bool_t trueForClip=1){fOutFormula.SetClip(trueForClip);};
00137 virtual void GetEvent(Float_t *input, Float_t *output, Int_t iEvent=0);
00138 virtual Int_t GetNInput(){return fNInput;};
00139 virtual Int_t GetNOutput(){return fNOutput;};
00140 virtual void SetFile(Text_t *namearg);
00141 virtual void DeleteTree();
00142
00143 ClassDef(TNNTree,0)
00144
00145 };
00146
00148
00149 class TNNControlE : public TCanvas
00150 {
00151
00152 private:
00153 Float_t * fXT;
00154 Float_t * fYT;
00155 Float_t * fXV;
00156 Float_t * fYV;
00157 Int_t fNT;
00158 Int_t fNV;
00159 TGraph *fGraphV;
00160 TGraph *fGraphT;
00161
00162
00163 public:
00164 TNNControlE();
00165 virtual ~TNNControlE();
00166 virtual void AddTP(Int_t n, Float_t e);
00167 virtual void AddVP(Int_t n, Float_t e);
00168 virtual void UpdateG();
00169 virtual void ResetT()
00170 {delete [] fXT;delete [] fYT;fXT=new Float_t[50];fYT=new Float_t[50];fNT=0;};
00171 virtual void ResetV()
00172 {delete [] fXV;delete [] fYV;fXV=new Float_t[50];fYV=new Float_t[50];fNV=0;};
00173 virtual Float_t* GetXT(){return fXT;};
00174 virtual Float_t* GetYT(){return fYT;};
00175 virtual Float_t* GetXV(){return fXV;};
00176 virtual Float_t* GetYV(){return fYV;};
00177 virtual Int_t GetNT(){return fNT;};
00178 virtual Int_t GetNV(){return fNV;};
00179 virtual void DrawT(Text_t *text, Float_t x, Float_t y, Float_t angle=0., Int_t color=1)
00180 {
00181 TText *tText= new TText(x,y,text);
00182 tText->SetNDC(kTRUE);
00183 tText->SetTextColor(color);
00184 tText->SetTextAngle(angle);
00185 tText->Draw();
00186 }
00187
00188 ClassDef(TNNControlE,0)
00189
00190 };
00191
00193
00194 class TNNKernel : public TNamed
00195 {
00196
00197 private:
00198 Int_t fNHiddL;
00199 Float_t **fValues;
00200 Double_t **fErrors;
00201 Double_t **fBiases;
00202 Int_t *fNUnits;
00203 Double_t ***fW;
00204
00205 Int_t fNTrainEvents;
00206 Int_t fNValidEvents;
00207 TNNTree *fValidTree;
00208 Double_t fLearnParam;
00209 Float_t fLowerInitWeight;
00210 Float_t fUpperInitWeight;
00211 Float_t **fArrayOut;
00212 Float_t *fTeach;
00213 Float_t **fArrayIn;
00214 Int_t *fEventsList;
00215 Int_t fNTrainCycles;
00216 Double_t fUseBiases;
00217 TRandom fRandom;
00218 Int_t fNWeights;
00219 Double_t fMu;
00220 Double_t fFlatSE;
00221 Double_t ***fDW;
00222 Double_t **fDB;
00223
00224 void GetArrayEvt(Int_t iEvent)
00225 {
00226 Int_t l;
00227 for(l=0;l<fNUnits[0];l++)fValues[0][l]=fArrayIn[iEvent][l];
00228 for(l=0;l<fNUnits[fNHiddL+1];l++)fTeach[l]=fArrayOut[iEvent][l];
00229 };
00230 void LearnBackward();
00231 void Forward();
00232 Double_t Error();
00233 void Error(const char*,const char*,...) const {;}
00234 Double_t ErrorO();
00235 void FreeVW();
00236 void AllocateVW(Int_t nInput, Text_t *hidden, Int_t nOutput);
00237 void SetHidden(Text_t *ttext);
00238 Float_t Alea();
00239 void DeleteArray();
00240
00241 protected:
00242 virtual Double_t Sigmoide(Double_t x)
00243 {
00244 if(x> 10.) return 0.99999;
00245 if(x<-10.) return 0.;
00246 return (1./(1.+exp(-x)));
00247 };
00248 virtual Double_t SigPrim(Double_t x){return (x*(1.-x));};
00249
00250 public:
00251 TNNKernel();
00252 TNNKernel(Text_t *name, Int_t nInput=5, Text_t *hidden="6:7:8", Int_t nOutput=4);
00253 virtual ~TNNKernel();
00254 virtual void SetKernel(Int_t nInput, Text_t *hidden, Int_t nOutput);
00255 virtual void SetLearnParam(Double_t learnParam=0.2,Double_t fse=0.,Double_t mu=0.);
00256 virtual void SetInitParam(Float_t lowerInitWeight=-1., Float_t upperInitWeight=1.);
00257 virtual void Init();
00258 virtual void PrintS();
00259 virtual void Mix();
00260 virtual Double_t TrainOneCycle();
00261 virtual void ResetCycles(){fNTrainCycles=0;};
00262 virtual void Export(Text_t *fileName="exportNN.dat");
00263 virtual void Import(Text_t *fileName="exportNN.dat");
00264 virtual void SetUseBiases(Bool_t trueForUse=1){fUseBiases=(Double_t)trueForUse;};
00265 virtual void SetRandomSeed(UInt_t seed=0){fRandom.SetSeed(seed);};
00266 virtual UInt_t GetRandomSeed(){return fRandom.GetSeed();};
00267 virtual Bool_t IsTrained(){return fNTrainCycles;};
00268 virtual Int_t GetNTrainCycles(){return fNTrainCycles;};
00269 virtual Int_t GetNTrainEvents(){return fNTrainEvents;};
00270 virtual Int_t GetNValidEvents(){return fNValidEvents;};
00271 virtual void SetArraySize(Int_t s=0);
00272 virtual void Fill(Int_t iev=0)
00273 {
00274 Int_t i;
00275 for(i=0;i<fNUnits[0];i++)fArrayIn[iev][i]=fValues[0][i];
00276 for(i=0;i<fNUnits[fNHiddL+1];i++)fArrayOut[iev][i]=fTeach[i];
00277 }
00278 virtual Float_t* GetInputAdr(){return fValues[0];};
00279 virtual void SetInput(Float_t v,Int_t i){fValues[0][i]=v;};
00280 virtual Int_t GetNInput(){return fNUnits[0];};
00281 virtual Int_t GetNOutput(){return fNUnits[fNHiddL+1];};
00282 virtual Float_t GetOutput(Int_t unit=0){return fValues[fNHiddL+1][unit];};
00283 virtual Float_t* GetOutputAdr(){return fValues[fNHiddL+1];};
00284 virtual Float_t* GetTeachAdr(){return fTeach;};
00285 virtual void SetTeach(Float_t v,Int_t i){fTeach[i]=v;};
00286 virtual Double_t GoThrough(){Forward();return ErrorO();};
00287 virtual Float_t GetSumO()
00288 {
00289 Int_t i; Float_t s=0.;
00290 for(i=0;i<fNUnits[fNHiddL+1];i++)s+=fValues[fNHiddL+1][i];
00291 return s;
00292 };
00293 virtual void SetTrainTree(TNNTree *t);
00294 virtual void SetValidTree(TNNTree *t);
00295 virtual Double_t Valid();
00296 virtual void TrainNCycles(TNNControlE *conte, Int_t period=5, Int_t nCycles=10);
00297 virtual Int_t GetNWeights()
00298 {
00299 if(!fNUnits)return 0;
00300 Int_t n=0;
00301 for(Int_t i=0;i<fNHiddL+1;i++)
00302 {
00303 n+=fNUnits[i]*fNUnits[i+1];
00304 }
00305 return n;
00306 };
00307
00308 ClassDef(TNNKernel,0)
00309
00310 };
00311
00313
00314
00315 class TNNUtils : public TNamed
00316 {
00317
00318 private:
00319 TTree *fT;
00320 TNNKernel *fK;
00321 Text_t fFName[400];
00322 TFile fTF;
00323 TBranch *fB;
00324 UInt_t fOAdr;
00325 Int_t fNOut;
00326 TNNFormula fForm;
00327
00328 Int_t UpdateNewBranch();
00329
00330 public:
00331 TNNUtils() {fK=0;fNOut=0;fT=0;fB=0;};
00332 virtual ~TNNUtils();
00333 virtual void SetTree(TTree *t){fT=t;};
00334 virtual void SetKernel(TNNKernel *k){fK=k;};
00335 virtual void SetNewBranchFile(Text_t *fname){strcpy(fFName,fname);};
00336 virtual void SetFormula(Text_t *form, Bool_t clip=1)
00337 {
00338 if(!fT){printf("no tree associated!\n");return;}
00339 fForm.SetTree(fT);fForm.SetFormula(form);fForm.SetClip(clip);
00340 };
00341 virtual Int_t FillNB();
00342 virtual TH1F* HIntegral(TH1F *hOrig, Int_t efficiency=1, Text_t *name="Integral", Text_t *title="Integral");
00343 virtual TGraph* XY(TH1F *hX, TH1F *hY, Int_t color=1);
00344 ClassDef(TNNUtils,0)
00345
00346 };
00347
00349
Generated on Sun Mar 15 04:54:21 2009 for StRoot by doxygen 1.3.7