- Integrated the file import into the main interface (not as a plugin anymore)
[mldemos:mldemos.git] / Core / datasetManager.h
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public License,\r
8 version 3 as published by the Free Software Foundation.\r
9 \r
10 This library is distributed in the hope that it will be useful, but\r
11 WITHOUT ANY WARRANTY; without even the implied warranty of\r
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
13 Lesser General Public License for more details.\r
14 \r
15 You should have received a copy of the GNU Lesser General Public\r
16 License along with this library; if not, write to the Free\r
17 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
18 *********************************************************************/\r
19 #ifndef _DATASET_MANAGER_H_\r
20 #define _DATASET_MANAGER_H_\r
21 \r
22 #include <vector>\r
23 #include "public.h"\r
24 \r
25 enum DatasetManagerFlags\r
26 {\r
27         _UNUSED = 0x0000,\r
28         _TRAIN  = 0x0001,\r
29         _VALID  = 0x0010,\r
30         _TEST   = 0x0100,\r
31         _TRAJ   = 0x1000,\r
32         _OBST  = 0x10000,\r
33         _TIME = 0x100000\r
34 };\r
35 typedef DatasetManagerFlags dsmFlags;\r
36 \r
37 struct Obstacle\r
38 {\r
39         fvec axes;                      //the obstacle major axes\r
40         fvec center;            //the center of the obstacle\r
41         float angle;            //the orientation matrix\r
42         fvec power;                     //Gamma is \sum( (x/a)^m )\r
43         fvec repulsion; //safety factor\r
44         Obstacle() :angle(0) {\r
45                 axes.resize(2,1.f);\r
46                 center.resize(2,0.f);\r
47                 power.resize(2,1.f);\r
48                 repulsion.resize(2,1.f);\r
49         };\r
50         bool operator==(const Obstacle& o) const {\r
51                 return center == o.center && axes == o.axes && angle == o.angle && power == o.power && repulsion == o.repulsion;\r
52         }\r
53         bool operator!=(const Obstacle& o) const {\r
54                 return center != o.center || axes != o.axes || angle != o.angle || power != o.power || repulsion != o.repulsion;\r
55         }\r
56 };\r
57 \r
58 struct RewardMap\r
59 {\r
60         int dim;\r
61         ivec size; // size of reward array in each dimension\r
62         int length; // size[0]*size[1]*...*size[dim]\r
63         float *rewards;\r
64         fvec lowerBoundary;\r
65         fvec higherBoundary;\r
66         RewardMap():rewards(0), dim(0), length(0){}\r
67         ~RewardMap(){if(rewards) delete [] rewards;}\r
68         RewardMap& operator= (const RewardMap& r);\r
69 \r
70         void SetReward(float *rewards, ivec size, fvec lowerBoundary, fvec higherBoundary);\r
71 \r
72         void Clear();\r
73 \r
74         void Zero();\r
75 \r
76         // return the value of the reward function at the coordinates provided\r
77         float ValueAt(fvec sample);\r
78 \r
79         void SetValueAt(fvec sample, float value);\r
80 \r
81         void ShiftValueAt(fvec sample, float shift);\r
82 \r
83         void ShiftValueAt(fvec sample, float radius, float shift);\r
84 \r
85 };\r
86 \r
87 struct TimeSerie\r
88 {\r
89         std::string name; // name of the current graph line\r
90         std::vector<long int> timestamps; // time stamps for each frame\r
91         std::vector<fvec> data; // each vector element is a frame\r
92         TimeSerie(std::string name="", std::vector<long int> timestamps=std::vector<long int>(), std::vector<fvec> data=std::vector<fvec>()) : name(name), timestamps(timestamps), data(data){};\r
93         bool operator==(const TimeSerie& t) const {\r
94                 if(name != t.name || timestamps.size() != t.timestamps.size() || data.size() != t.data.size()) return false;\r
95                 for(int i=0; i<timestamps.size(); i++) if(timestamps[i] != t.timestamps[i]) return false;\r
96                 for(int i=0; i<data.size(); i++) if(data[i] != t.data[i]) return false;\r
97                 return true;\r
98         }\r
99         TimeSerie& operator= (const TimeSerie& t)\r
100         {\r
101                 if (this != &t) {\r
102                         name = t.name;\r
103                         timestamps = t.timestamps;\r
104                         data = t.data;\r
105                 }\r
106                 return *this;\r
107         }\r
108         fvec& operator[] (unsigned int i){return data[i];}\r
109         fvec& operator() (unsigned int i){return data[i];}\r
110         void clear(){data.clear();timestamps.clear();}\r
111         size_t size(){return data.size();}\r
112         std::vector<fvec>::iterator begin(){return data.begin();}\r
113         std::vector<fvec>::iterator end(){return data.end();}\r
114 \r
115         TimeSerie& operator+=(const TimeSerie& t) {\r
116                 data.insert(data.end(), t.data.begin(), t.data.end());\r
117                 int count = timestamps.size();\r
118                 int lastTimestamp = timestamps.back();\r
119                 timestamps.insert(timestamps.end(), t.timestamps.begin(), t.timestamps.end());\r
120                 for(int i=count; i < timestamps.size(); i++) timestamps[i] += lastTimestamp;\r
121                 return *this;}\r
122         TimeSerie operator+(const TimeSerie& t) const {TimeSerie a = *this; a+=t; return a;}\r
123         TimeSerie& operator<< (const TimeSerie& t) {return *this += t;};\r
124 \r
125         TimeSerie& operator+=(const fvec& v) {data.push_back(v); timestamps.push_back(timestamps.back()+1); return *this;}\r
126         TimeSerie operator+(const fvec& v) const {TimeSerie a = *this; a+=v; return a;}\r
127         TimeSerie& operator<< (const fvec& v) {return *this += v;};\r
128 };\r
129 \r
130 class DatasetManager\r
131 {\r
132 protected:\r
133         static u32 IDCount;\r
134 \r
135         u32 ID;\r
136 \r
137         int size; // the samples size (dimension)\r
138 \r
139         std::vector< fvec > samples;\r
140 \r
141         std::vector< ipair > sequences;\r
142 \r
143         std::vector<dsmFlags> flags;\r
144 \r
145         std::vector<Obstacle> obstacles;\r
146 \r
147         std::vector<TimeSerie> series;\r
148 \r
149         RewardMap rewards;\r
150 \r
151         ivec labels;\r
152 \r
153         u32 *perm;\r
154 \r
155 public:\r
156     bool bProjected;\r
157 \r
158 public:\r
159         DatasetManager(int dimension = 2);\r
160         ~DatasetManager();\r
161 \r
162         void Randomize(int seed=-1);\r
163         void Clear();\r
164         double Compare(fvec sample);\r
165 \r
166         int GetSize(){return size;}\r
167         int GetCount(){return samples.size();}\r
168         int GetDimCount();\r
169     std::pair<fvec, fvec> GetBounds();\r
170         static u32 GetClassCount(ivec classes);\r
171 \r
172         // functions to manage samples\r
173         void AddSample(fvec sample, int label = 0, dsmFlags flag = _UNUSED);\r
174         void AddSamples(std::vector< fvec > samples, ivec newLabels=ivec(), std::vector<dsmFlags> newFlags=std::vector<dsmFlags>());\r
175         void AddSamples(DatasetManager &newSamples);    \r
176         void RemoveSample(unsigned int index);\r
177     void RemoveSamples(ivec indices);\r
178 \r
179     fvec GetSample(int index=0){ return (index < samples.size()) ? samples[index] : fvec(); }\r
180     fvec GetSampleDim(int index, ivec inputDims, int outputDim=-1);\r
181     std::vector< fvec > GetSamples(){return samples;}\r
182         std::vector< fvec > GetSamples(u32 count, dsmFlags flag=_UNUSED, dsmFlags replaceWith=_TRAIN);\r
183     std::vector< fvec > GetSampleDims(ivec inputDims, int outputDim=-1);\r
184         void SetSample(int index, fvec sample);\r
185     void SetSamples(std::vector<fvec> samples){this->samples = samples;}\r
186 \r
187         int GetLabel(int index){return index < labels.size() ? labels[index] : 0;}\r
188         ivec GetLabels(){return labels;}\r
189         void SetLabel(int index, int label){if(index<labels.size())labels[index] = label;}\r
190     void SetLabels(ivec labels){this->labels = labels;}\r
191 \r
192         // functions to manage sequences\r
193         void AddSequence(int start, int stop);\r
194         void AddSequence(ipair newSequence);\r
195         void AddSequences(std::vector< ipair > newSequences);\r
196         void RemoveSequence(unsigned int index);\r
197 \r
198         ipair GetSequence(unsigned int index){return index < sequences.size() ? sequences[index] : ipair(-1,-1);}\r
199         std::vector< ipair > GetSequences(){return sequences;}\r
200         std::vector< std::vector<fvec> > GetTrajectories(int resampleType, int resampleCount, int centerType, float dT, int zeroEnding);\r
201 \r
202         // functions to manage obstacles\r
203         void AddObstacle(Obstacle o){obstacles.push_back(o);}\r
204         void AddObstacle(fvec center, fvec axes, float angle, fvec power, fvec repulsion);\r
205         void AddObstacles(std::vector<Obstacle> newObstacles);\r
206         void RemoveObstacle(unsigned int index);\r
207         std::vector< Obstacle > GetObstacles(){return obstacles;}\r
208         Obstacle GetObstacle(unsigned int index){return index < obstacles.size() ? obstacles[index] : Obstacle();}\r
209 \r
210         // functions to manage rewards\r
211         void AddReward(float *values, ivec size, fvec lowerBoundary, fvec higherBoundary);\r
212         RewardMap *GetReward(){return &rewards;}\r
213 \r
214         // functions to manage time series\r
215         void AddTimeSerie(std::string name, std::vector<fvec> data, std::vector<long int> timestamps=std::vector<long int>());\r
216         void AddTimeSerie(TimeSerie serie);\r
217         void AddTimeSeries(std::vector< TimeSerie > newTimeSeries);\r
218         void RemoveTimeSerie(unsigned int index);\r
219     std::vector<TimeSerie>& GetTimeSeries(){return series;}\r
220 \r
221         // functions to manage flags\r
222         dsmFlags GetFlag(int index){return index < flags.size() ? flags[index] : _UNUSED;}\r
223         void SetFlag(int index, dsmFlags flag){if(index < flags.size()) flags[index] = flag;}\r
224         std::vector<dsmFlags> GetFlags(){return flags;}\r
225         std::vector<bool> GetFreeFlags();\r
226         void ResetFlags();\r
227 \r
228         void Save(const char *filename);\r
229         bool Load(const char *filename);\r
230 \r
231 };\r
232 \r
233 #endif // _DATASET_MANAGER_H_\r