CHANGED: the base zoom factors to match 2D and 3D zooms
[mldemos:allopens-mldemos.git] / Core / datasetManager.h
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public License,\r
8 version 3 as published by the Free Software Foundation.\r
9 \r
10 This library is distributed in the hope that it will be useful, but\r
11 WITHOUT ANY WARRANTY; without even the implied warranty of\r
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
13 Lesser General Public License for more details.\r
14 \r
15 You should have received a copy of the GNU Lesser General Public\r
16 License along with this library; if not, write to the Free\r
17 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
18 *********************************************************************/\r
19 #ifndef _DATASET_MANAGER_H_\r
20 #define _DATASET_MANAGER_H_\r
21 \r
22 #include <vector>\r
23 #include "public.h"\r
24 #include <string.h>\r
25 \r
26 enum DatasetManagerFlags\r
27 {\r
28         _UNUSED = 0x0000,\r
29         _TRAIN  = 0x0001,\r
30         _VALID  = 0x0010,\r
31         _TEST   = 0x0100,\r
32         _TRAJ   = 0x1000,\r
33         _OBST  = 0x10000,\r
34         _TIME = 0x100000\r
35 };\r
36 typedef DatasetManagerFlags dsmFlags;\r
37 \r
38 struct Obstacle\r
39 {\r
40         fvec axes;                      //the obstacle major axes\r
41         fvec center;            //the center of the obstacle\r
42         float angle;            //the orientation matrix\r
43         fvec power;                     //Gamma is \sum( (x/a)^m )\r
44         fvec repulsion; //safety factor\r
45         Obstacle() :angle(0) {\r
46                 axes.resize(2,1.f);\r
47                 center.resize(2,0.f);\r
48                 power.resize(2,1.f);\r
49                 repulsion.resize(2,1.f);\r
50         };\r
51         bool operator==(const Obstacle& o) const {\r
52                 return center == o.center && axes == o.axes && angle == o.angle && power == o.power && repulsion == o.repulsion;\r
53         }\r
54         bool operator!=(const Obstacle& o) const {\r
55                 return center != o.center || axes != o.axes || angle != o.angle || power != o.power || repulsion != o.repulsion;\r
56         }\r
57 };\r
58 \r
59 struct RewardMap\r
60 {\r
61         int dim;\r
62         ivec size; // size of reward array in each dimension\r
63         int length; // size[0]*size[1]*...*size[dim]\r
64     double *rewards;\r
65         fvec lowerBoundary;\r
66         fvec higherBoundary;\r
67         RewardMap():rewards(0), dim(0), length(0){}\r
68     ~RewardMap(){if(rewards) delete [] rewards; rewards=0;}\r
69         RewardMap& operator= (const RewardMap& r);\r
70 \r
71     bool Empty() const {return length==0;}\r
72 \r
73     void SetReward(const double *rewards, const ivec size, const fvec lowerBoundary, const fvec higherBoundary);\r
74 \r
75     void SetReward(const float *rewards, const ivec size, const fvec lowerBoundary, const fvec higherBoundary);\r
76 \r
77         void Clear();\r
78 \r
79         void Zero();\r
80 \r
81         // return the value of the reward function at the coordinates provided\r
82     float ValueAt(fvec sample) const ;\r
83 \r
84     float *GetRewardFloat() const ;\r
85 \r
86     void SetValueAt(const fvec sample, const double value);\r
87 \r
88     void ShiftValueAt(const fvec sample, const double shift);\r
89 \r
90     void ShiftValueAt(const fvec sample, const double radius, const double shift);\r
91 };\r
92 \r
93 struct TimeSerie\r
94 {\r
95         std::string name; // name of the current graph line\r
96         std::vector<long int> timestamps; // time stamps for each frame\r
97         std::vector<fvec> data; // each vector element is a frame\r
98     TimeSerie(std::string name="", std::vector<long int> timestamps=std::vector<long int>(), std::vector<fvec> data=std::vector<fvec>()) : name(name), timestamps(timestamps), data(data){}\r
99         bool operator==(const TimeSerie& t) const {\r
100                 if(name != t.name || timestamps.size() != t.timestamps.size() || data.size() != t.data.size()) return false;\r
101                 for(int i=0; i<timestamps.size(); i++) if(timestamps[i] != t.timestamps[i]) return false;\r
102                 for(int i=0; i<data.size(); i++) if(data[i] != t.data[i]) return false;\r
103                 return true;\r
104         }\r
105         TimeSerie& operator= (const TimeSerie& t)\r
106         {\r
107                 if (this != &t) {\r
108                         name = t.name;\r
109                         timestamps = t.timestamps;\r
110                         data = t.data;\r
111                 }\r
112                 return *this;\r
113         }\r
114     fvec& operator[] (const unsigned int i) const {return data[i];}\r
115     fvec& operator() (const unsigned int i) const {return data[i];}\r
116         void clear(){data.clear();timestamps.clear();}\r
117     size_t size() const {return data.size();}\r
118     std::vector<fvec>::iterator begin() const {return data.begin();}\r
119     std::vector<fvec>::iterator end() const {return data.end();}\r
120 \r
121         TimeSerie& operator+=(const TimeSerie& t) {\r
122                 data.insert(data.end(), t.data.begin(), t.data.end());\r
123                 int count = timestamps.size();\r
124                 int lastTimestamp = timestamps.back();\r
125                 timestamps.insert(timestamps.end(), t.timestamps.begin(), t.timestamps.end());\r
126                 for(int i=count; i < timestamps.size(); i++) timestamps[i] += lastTimestamp;\r
127                 return *this;}\r
128         TimeSerie operator+(const TimeSerie& t) const {TimeSerie a = *this; a+=t; return a;}\r
129     TimeSerie& operator<< (const TimeSerie& t) {return *this += t;}\r
130 \r
131         TimeSerie& operator+=(const fvec& v) {data.push_back(v); timestamps.push_back(timestamps.back()+1); return *this;}\r
132         TimeSerie operator+(const fvec& v) const {TimeSerie a = *this; a+=v; return a;}\r
133     TimeSerie& operator<< (const fvec& v) {return *this += v;}\r
134 };\r
135 \r
136 class DatasetManager\r
137 {\r
138 protected:\r
139         static u32 IDCount;\r
140 \r
141         u32 ID;\r
142 \r
143         int size; // the samples size (dimension)\r
144 \r
145         std::vector< fvec > samples;\r
146 \r
147         std::vector< ipair > sequences;\r
148 \r
149         std::vector<dsmFlags> flags;\r
150 \r
151         std::vector<Obstacle> obstacles;\r
152 \r
153         std::vector<TimeSerie> series;\r
154 \r
155         RewardMap rewards;\r
156 \r
157         ivec labels;\r
158 \r
159         u32 *perm;\r
160 \r
161 public:\r
162     bool bProjected;\r
163     std::map<int, std::vector<std::string> > categorical;\r
164 \r
165 public:\r
166     DatasetManager(const int dimension = 2);\r
167         ~DatasetManager();\r
168 \r
169     void Randomize(const int seed=-1);\r
170         void Clear();\r
171     double Compare(const fvec sample) const;\r
172 \r
173     int GetSize() const {return size;}\r
174     int GetCount() const {return samples.size();}\r
175     int GetDimCount() const;\r
176     std::pair<fvec, fvec> GetBounds() const;\r
177     static u32 GetClassCount(const ivec classes);\r
178 \r
179         // functions to manage samples\r
180     void AddSample(const fvec sample, const int label = 0, const dsmFlags flag = _UNUSED);\r
181     void AddSamples(const std::vector< fvec > samples, const ivec newLabels=ivec(), const std::vector<dsmFlags> newFlags=std::vector<dsmFlags>());\r
182     void AddSamples(const DatasetManager &newSamples);\r
183     void RemoveSample(const unsigned int index);\r
184     void RemoveSamples(const ivec indices);\r
185 \r
186     fvec GetSample(const int index=0) const { return (index < samples.size()) ? samples[index] : fvec(); }\r
187     fvec GetSampleDim(const int index, const ivec inputDims, const int outputDim=-1) const;\r
188     std::vector< fvec > GetSamples() const {return samples;}\r
189     std::vector< fvec > GetSamples(const u32 count, const dsmFlags flag=_UNUSED, const dsmFlags replaceWith=_TRAIN) const;\r
190     std::vector< fvec > GetSampleDims(const ivec inputDims, const int outputDim=-1) const ;\r
191     void SetSample(const int index, const fvec sample);\r
192     void SetSamples(const std::vector<fvec> samples){this->samples = samples;}\r
193 \r
194     int GetLabel(const int index) const {return index < labels.size() ? labels[index] : 0;}\r
195     ivec GetLabels() const {return labels;}\r
196         void SetLabel(int index, int label){if(index<labels.size())labels[index] = label;}\r
197     void SetLabels(ivec labels){this->labels = labels;}\r
198 \r
199     std::string GetCategorical(const int dimension,const  int value) const ;\r
200     bool IsCategorical(const int dimension) const ;\r
201 \r
202         // functions to manage sequences\r
203     void AddSequence(const int start, const int stop);\r
204     void AddSequence(const ipair newSequence);\r
205     void AddSequences(const std::vector< ipair > newSequences);\r
206     void RemoveSequence(const unsigned int index);\r
207 \r
208     ipair const GetSequence(const unsigned int index) const {return index < sequences.size() ? sequences[index] : ipair(-1,-1);}\r
209     std::vector< ipair > GetSequences() const {return sequences;}\r
210     std::vector< std::vector<fvec> > GetTrajectories(const int resampleType, const int resampleCount, const int centerType, const float dT, const int zeroEnding) const ;\r
211 \r
212         // functions to manage obstacles\r
213     void AddObstacle(const Obstacle o){obstacles.push_back(o);}\r
214     void AddObstacle(const fvec center, const fvec axes, const float angle, const fvec power, const fvec repulsion);\r
215     void AddObstacles(const std::vector<Obstacle> newObstacles);\r
216     void RemoveObstacle(const unsigned int index);\r
217     std::vector< Obstacle > GetObstacles() const {return obstacles;}\r
218     Obstacle GetObstacle(const unsigned int index) const {return index < obstacles.size() ? obstacles[index] : Obstacle();}\r
219 \r
220         // functions to manage rewards\r
221     void AddReward(const float *values, const ivec size, const fvec lowerBoundary, const fvec higherBoundary);\r
222     RewardMap *GetReward() const {return &rewards;}\r
223 \r
224         // functions to manage time series\r
225     void AddTimeSerie(const std::string name, const std::vector<fvec> data, const std::vector<long int> timestamps=std::vector<long int>());\r
226     void AddTimeSerie(const TimeSerie serie);\r
227     void AddTimeSeries(const std::vector< TimeSerie > newTimeSeries);\r
228     void RemoveTimeSerie(const unsigned int index);\r
229     std::vector<TimeSerie>& GetTimeSeries() const {return series;}\r
230 \r
231         // functions to manage flags\r
232     dsmFlags GetFlag(const int index) const {return index < flags.size() ? flags[index] : _UNUSED;}\r
233     void SetFlag(const int index, const dsmFlags flag){if(index < flags.size()) flags[index] = flag;}\r
234     std::vector<dsmFlags> GetFlags() const {return flags;}\r
235     std::vector<bool> GetFreeFlags() const ;\r
236         void ResetFlags();\r
237 \r
238     void Save(const char *filename) const ;\r
239         bool Load(const char *filename);\r
240 };\r
241 \r
242 #endif // _DATASET_MANAGER_H_\r