const-fixing functions
[mldemos:baraks-mldemos.git] / Core / datasetManager.h
1 /*********************************************************************\r
2 MLDemos: A User-Friendly visualization toolkit for machine learning\r
3 Copyright (C) 2010  Basilio Noris\r
4 Contact: mldemos@b4silio.com\r
5 \r
6 This library is free software; you can redistribute it and/or\r
7 modify it under the terms of the GNU Lesser General Public License,\r
8 version 3 as published by the Free Software Foundation.\r
9 \r
10 This library is distributed in the hope that it will be useful, but\r
11 WITHOUT ANY WARRANTY; without even the implied warranty of\r
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU\r
13 Lesser General Public License for more details.\r
14 \r
15 You should have received a copy of the GNU Lesser General Public\r
16 License along with this library; if not, write to the Free\r
17 Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.\r
18 *********************************************************************/\r
19 #ifndef _DATASET_MANAGER_H_\r
20 #define _DATASET_MANAGER_H_\r
21 \r
22 #include <vector>\r
23 #include "public.h"\r
24 #include <string.h>\r
25 \r
26 enum DatasetManagerFlags\r
27 {\r
28         _UNUSED = 0x0000,\r
29         _TRAIN  = 0x0001,\r
30         _VALID  = 0x0010,\r
31         _TEST   = 0x0100,\r
32         _TRAJ   = 0x1000,\r
33         _OBST  = 0x10000,\r
34         _TIME = 0x100000\r
35 };\r
36 typedef DatasetManagerFlags dsmFlags;\r
37 \r
38 struct Obstacle\r
39 {\r
40         fvec axes;                      //the obstacle major axes\r
41         fvec center;            //the center of the obstacle\r
42         float angle;            //the orientation matrix\r
43         fvec power;                     //Gamma is \sum( (x/a)^m )\r
44         fvec repulsion; //safety factor\r
45         Obstacle() :angle(0) {\r
46                 axes.resize(2,1.f);\r
47                 center.resize(2,0.f);\r
48                 power.resize(2,1.f);\r
49                 repulsion.resize(2,1.f);\r
50         };\r
51         bool operator==(const Obstacle& o) const {\r
52                 return center == o.center && axes == o.axes && angle == o.angle && power == o.power && repulsion == o.repulsion;\r
53         }\r
54         bool operator!=(const Obstacle& o) const {\r
55                 return center != o.center || axes != o.axes || angle != o.angle || power != o.power || repulsion != o.repulsion;\r
56         }\r
57 };\r
58 \r
59 struct RewardMap\r
60 {\r
61         int dim;\r
62         ivec size; // size of reward array in each dimension\r
63         int length; // size[0]*size[1]*...*size[dim]\r
64     double *rewards;\r
65         fvec lowerBoundary;\r
66         fvec higherBoundary;\r
67         RewardMap():rewards(0), dim(0), length(0){}\r
68     ~RewardMap(){if(rewards) delete [] rewards; rewards=0;}\r
69         RewardMap& operator= (const RewardMap& r);\r
70 \r
71     bool Empty() const {return length==0;}\r
72 \r
73     void SetReward(const double *rewards, const ivec size, const fvec lowerBoundary, const fvec higherBoundary);\r
74 \r
75     void SetReward(const float *rewards, const ivec size, const fvec lowerBoundary, const fvec higherBoundary);\r
76 \r
77         void Clear();\r
78 \r
79         void Zero();\r
80 \r
81         // return the value of the reward function at the coordinates provided\r
82     float ValueAt(fvec sample) const ;\r
83 \r
84     float *GetRewardFloat() const ;\r
85 \r
86     void SetValueAt(const fvec sample, const double value);\r
87 \r
88     void ShiftValueAt(const fvec sample, const double shift);\r
89 \r
90     void ShiftValueAt(const fvec sample, const double radius, const double shift);\r
91 };\r
92 \r
93 struct TimeSerie\r
94 {\r
95         std::string name; // name of the current graph line\r
96         std::vector<long int> timestamps; // time stamps for each frame\r
97         std::vector<fvec> data; // each vector element is a frame\r
98     TimeSerie(std::string name="", std::vector<long int> timestamps=std::vector<long int>(), std::vector<fvec> data=std::vector<fvec>()) : name(name), timestamps(timestamps), data(data){}\r
99         bool operator==(const TimeSerie& t) const {\r
100                 if(name != t.name || timestamps.size() != t.timestamps.size() || data.size() != t.data.size()) return false;\r
101                 for(int i=0; i<timestamps.size(); i++) if(timestamps[i] != t.timestamps[i]) return false;\r
102                 for(int i=0; i<data.size(); i++) if(data[i] != t.data[i]) return false;\r
103                 return true;\r
104         }\r
105         TimeSerie& operator= (const TimeSerie& t)\r
106         {\r
107                 if (this != &t) {\r
108                         name = t.name;\r
109                         timestamps = t.timestamps;\r
110                         data = t.data;\r
111                 }\r
112                 return *this;\r
113         }\r
114     fvec& operator[] (const unsigned int i) {return data[i];}\r
115     fvec& operator() (const unsigned int i) {return data[i];}\r
116     const fvec& at(const unsigned int i) const {return data.at(i);}\r
117     void clear(){data.clear();timestamps.clear();}\r
118     size_t size() const {return data.size();}\r
119     std::vector<fvec>::iterator begin() {return data.begin();}\r
120     std::vector<fvec>::iterator end() {return data.end();}\r
121 \r
122         TimeSerie& operator+=(const TimeSerie& t) {\r
123                 data.insert(data.end(), t.data.begin(), t.data.end());\r
124                 int count = timestamps.size();\r
125                 int lastTimestamp = timestamps.back();\r
126                 timestamps.insert(timestamps.end(), t.timestamps.begin(), t.timestamps.end());\r
127                 for(int i=count; i < timestamps.size(); i++) timestamps[i] += lastTimestamp;\r
128                 return *this;}\r
129         TimeSerie operator+(const TimeSerie& t) const {TimeSerie a = *this; a+=t; return a;}\r
130     TimeSerie& operator<< (const TimeSerie& t) {return *this += t;}\r
131 \r
132         TimeSerie& operator+=(const fvec& v) {data.push_back(v); timestamps.push_back(timestamps.back()+1); return *this;}\r
133         TimeSerie operator+(const fvec& v) const {TimeSerie a = *this; a+=v; return a;}\r
134     TimeSerie& operator<< (const fvec& v) {return *this += v;}\r
135 };\r
136 \r
137 class DatasetManager\r
138 {\r
139 protected:\r
140         static u32 IDCount;\r
141 \r
142         u32 ID;\r
143 \r
144         int size; // the samples size (dimension)\r
145 \r
146         std::vector< fvec > samples;\r
147 \r
148         std::vector< ipair > sequences;\r
149 \r
150         std::vector<dsmFlags> flags;\r
151 \r
152         std::vector<Obstacle> obstacles;\r
153 \r
154         std::vector<TimeSerie> series;\r
155 \r
156         RewardMap rewards;\r
157 \r
158         ivec labels;\r
159 \r
160         u32 *perm;\r
161 \r
162 public:\r
163     bool bProjected;\r
164     std::map<int, std::vector<std::string> > categorical;\r
165 \r
166 public:\r
167     DatasetManager(const int dimension = 2);\r
168         ~DatasetManager();\r
169 \r
170     void Randomize(const int seed=-1);\r
171         void Clear();\r
172     double Compare(const fvec sample) const;\r
173 \r
174     int GetSize() const {return size;}\r
175     int GetCount() const {return samples.size();}\r
176     int GetDimCount() const;\r
177     std::pair<fvec, fvec> GetBounds() const;\r
178     static u32 GetClassCount(const ivec classes);\r
179 \r
180         // functions to manage samples\r
181     void AddSample(const fvec sample, const int label = 0, const dsmFlags flag = _UNUSED);\r
182     void AddSamples(const std::vector< fvec > samples, const ivec newLabels=ivec(), const std::vector<dsmFlags> newFlags=std::vector<dsmFlags>());\r
183     void AddSamples(const DatasetManager &newSamples);\r
184     void RemoveSample(const unsigned int index);\r
185     void RemoveSamples(ivec indices);\r
186 \r
187     fvec GetSample(const int index=0) const { return (index < samples.size()) ? samples[index] : fvec(); }\r
188     fvec GetSampleDim(const int index, const ivec inputDims, const int outputDim=-1) const;\r
189     std::vector< fvec > GetSamples() const {return samples;}\r
190     std::vector< fvec > GetSamples(const u32 count, const dsmFlags flag=_UNUSED, const dsmFlags replaceWith=_TRAIN);\r
191     std::vector< fvec > GetSampleDims(const ivec inputDims, const int outputDim=-1) const ;\r
192     void SetSample(const int index, const fvec sample);\r
193     void SetSamples(const std::vector<fvec> samples){this->samples = samples;}\r
194 \r
195     int GetLabel(const int index) const {return index < labels.size() ? labels[index] : 0;}\r
196     ivec GetLabels() const {return labels;}\r
197         void SetLabel(int index, int label){if(index<labels.size())labels[index] = label;}\r
198     void SetLabels(ivec labels){this->labels = labels;}\r
199 \r
200     std::string GetCategorical(const int dimension,const  int value) const ;\r
201     bool IsCategorical(const int dimension) const ;\r
202 \r
203         // functions to manage sequences\r
204     void AddSequence(const int start, const int stop);\r
205     void AddSequence(const ipair newSequence);\r
206     void AddSequences(const std::vector< ipair > newSequences);\r
207     void RemoveSequence(const unsigned int index);\r
208 \r
209     ipair const GetSequence(const unsigned int index) const {return index < sequences.size() ? sequences[index] : ipair(-1,-1);}\r
210     std::vector< ipair > GetSequences() const {return sequences;}\r
211     std::vector< std::vector<fvec> > GetTrajectories(const int resampleType, const int resampleCount, const int centerType, const float dT, const int zeroEnding) const ;\r
212 \r
213         // functions to manage obstacles\r
214     void AddObstacle(const Obstacle o){obstacles.push_back(o);}\r
215     void AddObstacle(const fvec center, const fvec axes, const float angle, const fvec power, const fvec repulsion);\r
216     void AddObstacles(const std::vector<Obstacle> newObstacles);\r
217     void RemoveObstacle(const unsigned int index);\r
218     std::vector< Obstacle > GetObstacles() const {return obstacles;}\r
219     Obstacle GetObstacle(const unsigned int index) const {return index < obstacles.size() ? obstacles[index] : Obstacle();}\r
220 \r
221         // functions to manage rewards\r
222     void AddReward(const float *values, const ivec size, const fvec lowerBoundary, const fvec higherBoundary);\r
223     RewardMap *GetReward() {return &rewards;}\r
224 \r
225         // functions to manage time series\r
226     void AddTimeSerie(const std::string name, const std::vector<fvec> data, const std::vector<long int> timestamps=std::vector<long int>());\r
227     void AddTimeSerie(const TimeSerie serie);\r
228     void AddTimeSeries(const std::vector< TimeSerie > newTimeSeries);\r
229     void RemoveTimeSerie(const unsigned int index);\r
230     std::vector<TimeSerie>& GetTimeSeries() {return series;}\r
231 \r
232         // functions to manage flags\r
233     dsmFlags GetFlag(const int index) const {return index < flags.size() ? flags[index] : _UNUSED;}\r
234     void SetFlag(const int index, const dsmFlags flag){if(index < flags.size()) flags[index] = flag;}\r
235     std::vector<dsmFlags> GetFlags() const {return flags;}\r
236     std::vector<bool> GetFreeFlags() const ;\r
237         void ResetFlags();\r
238 \r
239     void Save(const char *filename);\r
240         bool Load(const char *filename);\r
241 };\r
242 \r
243 #endif // _DATASET_MANAGER_H_\r