Skip to content

Commit 72882b4

Browse files
committed
Refactored projections
1 parent d266efe commit 72882b4

3 files changed

Lines changed: 581 additions & 31 deletions

File tree

backend/kangas/datatypes/embedding.py

Lines changed: 71 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,34 @@
1212
######################################################
1313

1414
import json
15+
import random
16+
import time
1517

16-
from ..server.utils import pickle_dumps
18+
from ..server.utils import Cache, pickle_dumps
1719
from .base import Asset
18-
from .utils import flatten, get_color, get_file_extension, is_valid_file_path
20+
from .utils import get_color, get_file_extension, is_valid_file_path
21+
22+
PROJECTION_DIMENSIONS = 50
23+
24+
SAMPLE_CACHE = Cache(100)
25+
26+
27+
def prepare_embedding(embedding, dimensions, seed):
    """
    Reduce an embedding vector to at most `dimensions` values by
    keeping a seed-determined subset of its positions.

    Args:
        embedding: (list) the vector to reduce
        dimensions: (int or None) maximum number of values to keep;
            None means no limit
        seed: value used to seed the position sampling; the same
            seed, dimensions, and vector length always select the
            same positions

    Returns:
        The embedding itself (not a copy) when it already fits or
        when no limit is given; otherwise a new list holding the
        values at the sampled positions, in their original order.
    """
    length = len(embedding)
    if dimensions is None or length <= dimensions:
        return embedding

    # The sampled positions are drawn from range(length), so the length
    # must be part of the cache key; otherwise positions computed for one
    # vector size would be reused for vectors of a different size.
    key = (seed, dimensions, length)
    if SAMPLE_CACHE.contains(key):
        indices = SAMPLE_CACHE.get(key)
    else:
        positions = list(range(length))
        # A private generator yields the same shuffle as
        # random.seed(seed) followed by random.shuffle(), without
        # resetting the process-wide random state.
        random.Random(seed).shuffle(positions)
        indices = set(positions[:dimensions])
        SAMPLE_CACHE.put(key, indices)

    return [value for index, value in enumerate(embedding) if index in indices]
1943

2044

2145
class Embedding(Asset):
@@ -37,6 +61,7 @@ def __init__(
3761
metadata=None,
3862
source=None,
3963
unserialize=False,
64+
dimensions=PROJECTION_DIMENSIONS,
4065
):
4166
"""
4267
Create an embedding vector.
@@ -53,6 +78,7 @@ def __init__(
5378
include: (bool) whether to include this vector when determining the
5479
projection. Useful if you want to see one part of the datagrid in
5580
the projection of another.
81+
dimensions: (int) maximum number of dimensions
5682
5783
Example:
5884
@@ -88,6 +114,7 @@ def __init__(
88114
self.metadata["color"] = color
89115
self.metadata["projection"] = projection
90116
self.metadata["include"] = include
117+
self.metadata["dimensions"] = dimensions
91118

92119
if file_name:
93120
if is_valid_file_path(file_name):
@@ -98,6 +125,7 @@ def __init__(
98125
"name": name,
99126
"color": color,
100127
"text": text,
128+
"dimensions": dimensions,
101129
}
102130
)
103131
self.metadata["extension"] = get_file_extension(file_name)
@@ -106,7 +134,14 @@ def __init__(
106134
raise ValueError("file not found: %r" % file_name)
107135
else:
108136
self.asset_data = json.dumps(
109-
{"vector": embedding, "name": name, "color": color, "text": text}
137+
{
138+
"vector": embedding,
139+
"name": name,
140+
"color": color,
141+
"text": text,
142+
"dimensions": dimensions,
143+
"include": include,
144+
}
110145
)
111146
if metadata:
112147
self.metadata.update(metadata)
@@ -135,39 +170,48 @@ def get_statistics(cls, datagrid, col_name, field_name):
135170
stddev = None
136171
other = None
137172
name = col_name
173+
seed = time.time() # set the same for all embeddings
138174

139175
projection = None
140176
batch = []
141177
for row in datagrid.conn.execute(
142-
"""SELECT {field_name} as assetId, asset_data, json_extract(asset_metadata, '$.projection'), json_extract(asset_metadata, '$.include') from datagrid JOIN assets ON assetId = assets.asset_id;""".format(
178+
"""SELECT {field_name} as assetId, asset_data, asset_metadata from datagrid JOIN assets ON assetId = assets.asset_id;""".format(
143179
field_name=field_name
144180
)
145181
):
182+
asset_id, asset_data_json, asset_metadata_json = row
183+
if not asset_metadata_json:
184+
continue
185+
186+
asset_metdata = json.loads(asset_metadata_json)
187+
projection = asset_metdata["projection"]
188+
include = asset_metdata["include"]
189+
dimensions = asset_metdata["dimensions"]
190+
146191
# Skip if explicitly False
147-
if row[3] is False:
192+
if not include:
148193
continue
149194

150-
embedding = json.loads(row[1])
151-
vectors = embedding["vector"]
152-
vector = flatten(vectors)
195+
asset_data = json.loads(asset_data_json)
196+
vector = prepare_embedding(asset_data["vector"], dimensions, seed)
153197

154198
batch.append(vector)
155-
if row[2] is None or row[2] == "pca":
199+
if projection is None or projection == "pca":
156200
projection_name = "pca"
157-
elif row[2] == "t-sne":
201+
elif projection == "t-sne":
158202
projection_name = "t-sne"
159-
elif row[2] == "umap":
203+
elif projection == "umap":
160204
projection_name = "umap"
161205

162206
if projection_name == "pca":
163207
from sklearn.decomposition import PCA
164208

165-
projection = PCA()
166-
embedding = projection.fit_transform(np.array(batch))
167-
x_max = float(embedding[:, 0].max())
168-
x_min = float(embedding[:, 0].min())
169-
y_max = float(embedding[:, 1].max())
170-
y_min = float(embedding[:, 1].min())
209+
projection = PCA(n_components=2)
210+
transformed = projection.fit_transform(np.array(batch))
211+
x_max = float(transformed[:, 0].max())
212+
x_min = float(transformed[:, 0].min())
213+
y_max = float(transformed[:, 1].max())
214+
y_min = float(transformed[:, 1].min())
171215
x_span = abs(x_max - x_min)
172216
x_max += x_span * 0.1
173217
x_min -= x_span * 0.1
@@ -181,17 +225,19 @@ def get_statistics(cls, datagrid, col_name, field_name):
181225
"projection": projection_name,
182226
"x_range": [x_min, x_max],
183227
"y_range": [y_min, y_max],
228+
"dimensions": dimensions,
229+
"seed": seed,
184230
}
185231
)
186232
elif projection_name == "t-sne":
187233
from openTSNE import TSNE
188234

189235
projection = TSNE()
190-
embedding = projection.fit(np.array(batch))
191-
x_max = float(embedding[:, 0].max())
192-
x_min = float(embedding[:, 0].min())
193-
y_max = float(embedding[:, 1].max())
194-
y_min = float(embedding[:, 1].min())
236+
transformed = projection.fit(np.array(batch))
237+
x_max = float(transformed[:, 0].max())
238+
x_min = float(transformed[:, 0].min())
239+
y_max = float(transformed[:, 1].max())
240+
y_min = float(transformed[:, 1].min())
195241
x_span = abs(x_max - x_min)
196242
x_max += x_span * 0.1
197243
x_min -= x_span * 0.1
@@ -201,9 +247,11 @@ def get_statistics(cls, datagrid, col_name, field_name):
201247
other = json.dumps(
202248
{
203249
"projection": projection_name,
204-
"embedding": pickle_dumps(embedding),
250+
"pickled_projection": pickle_dumps(transformed),
205251
"x_range": [x_min, x_max],
206252
"y_range": [y_min, y_max],
253+
"dimensions": dimensions,
254+
"seed": seed,
207255
}
208256
)
209257
elif projection_name == "umap":

backend/kangas/server/queries.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,10 +2239,13 @@ def process_projection_asset_ids(
22392239
size,
22402240
default_color,
22412241
color_override=None,
2242+
projection_dimensions=None,
2243+
projection_seed=None,
22422244
):
2245+
from ..datatypes.embedding import prepare_embedding
2246+
22432247
# asset_ids is a list of str
22442248
# side-effect: adds to traces
2245-
22462249
# Turn to string:
22472250
values = "(" + (",".join(["'%s'" % asset_id for asset_id in asset_ids])) + ")"
22482251
if values == "()":
@@ -2253,10 +2256,14 @@ def process_projection_asset_ids(
22532256
)
22542257

22552258
trace_data = {}
2259+
print("using seed", projection_seed)
2260+
22562261
for asset_data_row in cur.execute(sql):
22572262
asset_data_raw = asset_data_row[0]
22582263
asset_data = json.loads(asset_data_raw)
2259-
vector = asset_data["vector"]
2264+
vector_reduced = prepare_embedding(
2265+
asset_data["vector"], projection_dimensions, projection_seed
2266+
)
22602267
if color_override:
22612268
color = color_override
22622269
elif asset_data["color"]:
@@ -2283,7 +2290,7 @@ def process_projection_asset_ids(
22832290
}
22842291

22852292
trace_data[trace_name]["texts"].append(asset_data.get("text"))
2286-
trace_data[trace_name]["vectors"].append(vector)
2293+
trace_data[trace_name]["vectors"].append(vector_reduced)
22872294
trace_data[trace_name]["colors"].append(color)
22882295
trace_data[trace_name]["customdata"].append(row_id)
22892296

@@ -2333,6 +2340,8 @@ def select_projection_data(
23332340
where_expr,
23342341
computed_columns,
23352342
):
2343+
from ..datatypes.embedding import prepare_embedding
2344+
23362345
conn = get_database_connection(dgid)
23372346
cur = conn.cursor()
23382347
unify_computed_columns(computed_columns)
@@ -2350,14 +2359,14 @@ def select_projection_data(
23502359

23512360
pca_eigen_vectors = metadata[column_name]["other"]["pca_eigen_vectors"]
23522361
pca_mean = metadata[column_name]["other"]["pca_mean"]
2353-
projection = PCA()
2362+
projection = PCA(n_components=2)
23542363
projection.components_ = np.array(pca_eigen_vectors)
23552364
projection.mean_ = np.array(pca_mean)
23562365
elif projection_name == "t-sne":
23572366
# FIXME: Trying to prevent an error on first load; race condition?
23582367
from openTSNE import TSNE # noqa
23592368

2360-
ascii_string = metadata[column_name]["other"]["embedding"]
2369+
ascii_string = metadata[column_name]["other"]["pickled_projection"]
23612370
if not PROJECTION_EMBEDDING_CACHE.contains(ascii_string):
23622371
PROJECTION_EMBEDDING_CACHE.put(
23632372
ascii_string, pickle_loads_embedding_unsafe(ascii_string)
@@ -2370,6 +2379,8 @@ def select_projection_data(
23702379
return
23712380

23722381
default_color = get_color(column_name)
2382+
projection_dimensions = metadata[column_name]["other"]["dimensions"]
2383+
projection_seed = metadata[column_name]["other"]["seed"]
23732384

23742385
traces = []
23752386
if asset_id:
@@ -2407,6 +2418,8 @@ def select_projection_data(
24072418
3,
24082419
default_color,
24092420
"lightgray",
2421+
projection_dimensions,
2422+
projection_seed,
24102423
)
24112424
PROJECTION_TRACE_CACHE.put(key, traces)
24122425
# Traces contains projection data; make copy:
@@ -2415,7 +2428,10 @@ def select_projection_data(
24152428
# Next, add the selected asset:
24162429
asset_data_raw = select_asset(dgid, asset_id)
24172430
asset_data = json.loads(asset_data_raw)
2418-
vector = projection.transform(np.array([asset_data["vector"]]))
2431+
vector_reduced = prepare_embedding(
2432+
asset_data["vector"], projection_dimensions, projection_seed
2433+
)
2434+
transformed = projection.transform(np.array([vector_reduced]))
24192435
if asset_data["color"]:
24202436
color = asset_data["color"]
24212437
else:
@@ -2427,8 +2443,8 @@ def select_projection_data(
24272443
text = asset_data.get("text", column_name)
24282444

24292445
trace = {
2430-
"x": [round(vector[0][0], 3)],
2431-
"y": [round(vector[0][1], 3)],
2446+
"x": [transformed[0][0]],
2447+
"y": [transformed[0][1]],
24322448
"text": text,
24332449
"name": text,
24342450
"type": "scatter",
@@ -2473,6 +2489,9 @@ def select_projection_data(
24732489
traces,
24742490
3,
24752491
default_color,
2492+
None,
2493+
projection_dimensions,
2494+
projection_seed,
24762495
)
24772496
PROJECTION_TRACE_CACHE.put(key, traces)
24782497
# Traces contains projection data; make copy:

0 commit comments

Comments
 (0)