1212######################################################
1313
1414import json
15+ import random
16+ import time
1517
16- from ..server .utils import pickle_dumps
18+ from ..server .utils import Cache , pickle_dumps
1719from .base import Asset
18- from .utils import flatten , get_color , get_file_extension , is_valid_file_path
20+ from .utils import get_color , get_file_extension , is_valid_file_path
21+
22+ PROJECTION_DIMENSIONS = 50
23+
24+ SAMPLE_CACHE = Cache (100 )
25+
26+
27+ def prepare_embedding (embedding , dimensions , seed ):
28+ if len (embedding ) <= dimensions :
29+ return embedding
30+
31+ key = (seed , dimensions )
32+ if not SAMPLE_CACHE .contains (key ):
33+ random .seed (seed )
34+ indices = list (range (len (embedding )))
35+ random .shuffle (indices )
36+ SAMPLE_CACHE .put (key , set (indices [:dimensions ]))
37+
38+ indices = SAMPLE_CACHE .get (key )
39+
40+ return [v for i , v in enumerate (embedding ) if i in indices ]
41+
42+ SAMPLE_CACHE .get (key )
1943
2044
2145class Embedding (Asset ):
@@ -37,6 +61,7 @@ def __init__(
3761 metadata = None ,
3862 source = None ,
3963 unserialize = False ,
64+ dimensions = PROJECTION_DIMENSIONS ,
4065 ):
4166 """
4267 Create an embedding vector.
@@ -53,6 +78,7 @@ def __init__(
5378 include: (bool) whether to include this vector when determining the
5479 projection. Useful if you want to see one part of the datagrid in
5580 the project of another.
81+ dimensions: (int) maximum number of dimensions
5682
5783 Example:
5884
@@ -88,6 +114,7 @@ def __init__(
88114 self .metadata ["color" ] = color
89115 self .metadata ["projection" ] = projection
90116 self .metadata ["include" ] = include
117+ self .metadata ["dimensions" ] = dimensions
91118
92119 if file_name :
93120 if is_valid_file_path (file_name ):
@@ -98,6 +125,7 @@ def __init__(
98125 "name" : name ,
99126 "color" : color ,
100127 "text" : text ,
128+ "dimensions" : dimensions ,
101129 }
102130 )
103131 self .metadata ["extension" ] = get_file_extension (file_name )
@@ -106,7 +134,14 @@ def __init__(
106134 raise ValueError ("file not found: %r" % file_name )
107135 else :
108136 self .asset_data = json .dumps (
109- {"vector" : embedding , "name" : name , "color" : color , "text" : text }
137+ {
138+ "vector" : embedding ,
139+ "name" : name ,
140+ "color" : color ,
141+ "text" : text ,
142+ "dimensions" : dimensions ,
143+ "include" : include ,
144+ }
110145 )
111146 if metadata :
112147 self .metadata .update (metadata )
@@ -135,39 +170,48 @@ def get_statistics(cls, datagrid, col_name, field_name):
135170 stddev = None
136171 other = None
137172 name = col_name
173+ seed = time .time () # set the same for all embeddings
138174
139175 projection = None
140176 batch = []
141177 for row in datagrid .conn .execute (
142- """SELECT {field_name} as assetId, asset_data, json_extract( asset_metadata, '$.projection'), json_extract(asset_metadata, '$.include') from datagrid JOIN assets ON assetId = assets.asset_id;""" .format (
178+ """SELECT {field_name} as assetId, asset_data, asset_metadata from datagrid JOIN assets ON assetId = assets.asset_id;""" .format (
143179 field_name = field_name
144180 )
145181 ):
182+ asset_id , asset_data_json , asset_metadata_json = row
183+ if not asset_metadata_json :
184+ continue
185+
186+ asset_metdata = json .loads (asset_metadata_json )
187+ projection = asset_metdata ["projection" ]
188+ include = asset_metdata ["include" ]
189+ dimensions = asset_metdata ["dimensions" ]
190+
146191 # Skip if explicitly False
147- if row [ 3 ] is False :
192+ if not include :
148193 continue
149194
150- embedding = json .loads (row [1 ])
151- vectors = embedding ["vector" ]
152- vector = flatten (vectors )
195+ asset_data = json .loads (asset_data_json )
196+ vector = prepare_embedding (asset_data ["vector" ], dimensions , seed )
153197
154198 batch .append (vector )
155- if row [ 2 ] is None or row [ 2 ] == "pca" :
199+ if projection is None or projection == "pca" :
156200 projection_name = "pca"
157- elif row [ 2 ] == "t-sne" :
201+ elif projection == "t-sne" :
158202 projection_name = "t-sne"
159- elif row [ 2 ] == "umap" :
203+ elif projection == "umap" :
160204 projection_name = "umap"
161205
162206 if projection_name == "pca" :
163207 from sklearn .decomposition import PCA
164208
165- projection = PCA ()
166- embedding = projection .fit_transform (np .array (batch ))
167- x_max = float (embedding [:, 0 ].max ())
168- x_min = float (embedding [:, 0 ].min ())
169- y_max = float (embedding [:, 1 ].max ())
170- y_min = float (embedding [:, 1 ].min ())
209+ projection = PCA (n_components = 2 )
210+ transformed = projection .fit_transform (np .array (batch ))
211+ x_max = float (transformed [:, 0 ].max ())
212+ x_min = float (transformed [:, 0 ].min ())
213+ y_max = float (transformed [:, 1 ].max ())
214+ y_min = float (transformed [:, 1 ].min ())
171215 x_span = abs (x_max - x_min )
172216 x_max += x_span * 0.1
173217 x_min -= x_span * 0.1
@@ -181,17 +225,19 @@ def get_statistics(cls, datagrid, col_name, field_name):
181225 "projection" : projection_name ,
182226 "x_range" : [x_min , x_max ],
183227 "y_range" : [y_min , y_max ],
228+ "dimensions" : dimensions ,
229+ "seed" : seed ,
184230 }
185231 )
186232 elif projection_name == "t-sne" :
187233 from openTSNE import TSNE
188234
189235 projection = TSNE ()
190- embedding = projection .fit (np .array (batch ))
191- x_max = float (embedding [:, 0 ].max ())
192- x_min = float (embedding [:, 0 ].min ())
193- y_max = float (embedding [:, 1 ].max ())
194- y_min = float (embedding [:, 1 ].min ())
236+ transformed = projection .fit (np .array (batch ))
237+ x_max = float (transformed [:, 0 ].max ())
238+ x_min = float (transformed [:, 0 ].min ())
239+ y_max = float (transformed [:, 1 ].max ())
240+ y_min = float (transformed [:, 1 ].min ())
195241 x_span = abs (x_max - x_min )
196242 x_max += x_span * 0.1
197243 x_min -= x_span * 0.1
@@ -201,9 +247,11 @@ def get_statistics(cls, datagrid, col_name, field_name):
201247 other = json .dumps (
202248 {
203249 "projection" : projection_name ,
204- "embedding " : pickle_dumps (embedding ),
250+ "pickled_projection " : pickle_dumps (transformed ),
205251 "x_range" : [x_min , x_max ],
206252 "y_range" : [y_min , y_max ],
253+ "dimensions" : dimensions ,
254+ "seed" : seed ,
207255 }
208256 )
209257 elif projection_name == "umap" :
0 commit comments