feat: implement shallow globbing; removed dep glob
This commit is contained in:
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -1050,12 +1050,6 @@ version = "0.31.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.10"
|
||||
@@ -2497,7 +2491,7 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "silos"
|
||||
version = "5.0.0"
|
||||
version = "5.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"candle-core",
|
||||
@@ -2505,11 +2499,9 @@ dependencies = [
|
||||
"candle-transformers",
|
||||
"clap",
|
||||
"derive_more",
|
||||
"glob",
|
||||
"hf-hub",
|
||||
"hora",
|
||||
"kdl",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "silos"
|
||||
version = "5.1.0"
|
||||
version = "5.2.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
@@ -10,11 +10,9 @@ candle-nn = "0.9.1"
|
||||
candle-transformers = "0.9.1"
|
||||
clap = { version = "4.5.39", features = ["derive"] }
|
||||
derive_more = { version = "2.0.1", features = ["display", "error"] }
|
||||
glob = "0.3.2"
|
||||
hf-hub = "0.4.2"
|
||||
hora = "0.1.1"
|
||||
kdl = "6.3.4"
|
||||
serde = "1.0.219"
|
||||
serde_json = "1.0.140"
|
||||
tokenizers = "0.21.1"
|
||||
tracing = "0.1.41"
|
||||
|
||||
@@ -17,7 +17,7 @@ pub(crate) struct Args {
|
||||
|
||||
/// Path to the directory containing `generate` and `refactor` snippets.
|
||||
#[arg(long, default_value = "./snippets")]
|
||||
pub(crate) snippets: String,
|
||||
pub(crate) snippets: std::path::PathBuf,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
|
||||
87
src/main.rs
87
src/main.rs
@@ -1,4 +1,4 @@
|
||||
use anyhow::{Context, Error as E, Result, bail};
|
||||
use anyhow::{Context, Error as E, Result};
|
||||
use clap::Parser;
|
||||
use hora::core::{ann_index::ANNIndex, metrics::Metric::Euclidean};
|
||||
use hora::index::hnsw_idx::HNSWIndex;
|
||||
@@ -14,18 +14,7 @@ mod embed;
|
||||
mod lsp;
|
||||
mod mutation;
|
||||
mod state;
|
||||
|
||||
fn path_to_parent_base(p: &std::path::Path) -> Result<String> {
|
||||
let Some(parent) = p
|
||||
.parent()
|
||||
.and_then(|v| v.file_name())
|
||||
.and_then(|v| v.to_str())
|
||||
.map(|v| v.to_string())
|
||||
else {
|
||||
bail!("failed to parse snippets path, make sure the directory structure is valid");
|
||||
};
|
||||
Ok(parent)
|
||||
}
|
||||
mod sources;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
@@ -36,29 +25,27 @@ async fn main() -> Result<()> {
|
||||
let mut dict = HashMap::default();
|
||||
let dimensions = embed.hidden_size;
|
||||
|
||||
let paths = glob::glob(&format!("{}/generate/*/*.kdl", args.snippets))?;
|
||||
for path in paths {
|
||||
let path = path?;
|
||||
let parent = path_to_parent_base(&path)?;
|
||||
for (language, paths) in sources::rule_files(args.snippets.join("generate"))? {
|
||||
for path in paths {
|
||||
let current_lang_index = dict
|
||||
.entry(language.clone())
|
||||
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||
|
||||
let current_lang_index = dict
|
||||
.entry(parent)
|
||||
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||
let doc_str = std::fs::read_to_string(&path)?;
|
||||
let doc: KdlDocument = doc_str
|
||||
.parse()
|
||||
.context(format!("failed to parse KDL: {}", path.display()))?;
|
||||
|
||||
let doc_str = std::fs::read_to_string(&path)?;
|
||||
let doc: KdlDocument = doc_str
|
||||
.parse()
|
||||
.context(format!("failed to parse KDL: {}", path.display()))?;
|
||||
|
||||
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
|
||||
continue;
|
||||
};
|
||||
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
|
||||
continue;
|
||||
};
|
||||
current_lang_index
|
||||
.add(&embed.embed(desc)?, body.to_string())
|
||||
.map_err(E::msg)?;
|
||||
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
|
||||
continue;
|
||||
};
|
||||
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
|
||||
continue;
|
||||
};
|
||||
current_lang_index
|
||||
.add(&embed.embed(desc)?, body.to_string())
|
||||
.map_err(E::msg)?;
|
||||
}
|
||||
}
|
||||
|
||||
for index in dict.values_mut() {
|
||||
@@ -67,36 +54,34 @@ async fn main() -> Result<()> {
|
||||
.map_err(E::msg)?;
|
||||
}
|
||||
|
||||
let paths = glob::glob(&format!("{}/refactor/*/*.kdl", args.snippets))?;
|
||||
let mut refactor_dict = HashMap::new();
|
||||
let mut mutations_collection = vec![];
|
||||
for (i, path) in paths.enumerate() {
|
||||
let path = path?;
|
||||
let parent = path_to_parent_base(&path)?;
|
||||
for (language, paths) in sources::rule_files(args.snippets.join("refactor"))? {
|
||||
for path in paths {
|
||||
let mutations = mutation::from_path(path)?;
|
||||
let current_lang_index = refactor_dict
|
||||
.entry(language.clone())
|
||||
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||
|
||||
let mutations = mutation::from_path(path)?;
|
||||
let current_lang_index = refactor_dict
|
||||
.entry(parent)
|
||||
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
|
||||
|
||||
current_lang_index
|
||||
.add(&embed.embed(&mutations.description)?, i)
|
||||
.map_err(E::msg)?;
|
||||
mutations_collection.push(mutations);
|
||||
current_lang_index
|
||||
.add(&embed.embed(&mutations.description)?, mutations_collection.len())
|
||||
.map_err(E::msg)?;
|
||||
mutations_collection.push(mutations);
|
||||
}
|
||||
}
|
||||
|
||||
for index in refactor_dict.values_mut() {
|
||||
index.build(Euclidean).map_err(E::msg)?;
|
||||
}
|
||||
|
||||
let appstate = State {
|
||||
let appstate = State::new(
|
||||
embed,
|
||||
generate: state::Generate { dict },
|
||||
refactor: state::Refactor {
|
||||
state::Generate { dict },
|
||||
state::Refactor {
|
||||
dict: refactor_dict,
|
||||
mutations_collection,
|
||||
},
|
||||
};
|
||||
);
|
||||
|
||||
let stdin = tokio::io::stdin();
|
||||
let stdout = tokio::io::stdout();
|
||||
|
||||
25
src/sources.rs
Normal file
25
src/sources.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use std::{fs, io, path::{Path, PathBuf}, collections::HashMap};
|
||||
|
||||
/// Collects the `.kdl` rule files found one level below `path`, grouped by the
/// name of the subdirectory that contains them.
///
/// The walk is shallow: only the immediate subdirectories of `path` are visited,
/// and only regular files directly inside each of them whose extension is exactly
/// `kdl` are returned. Subdirectories without any matching file still get an
/// entry (with an empty list).
///
/// The map key is the subdirectory's full file name. Subdirectories whose name is
/// not valid UTF-8 are skipped. Each list of paths is sorted so the result does
/// not depend on the order in which the operating system lists directory entries.
///
/// # Errors
///
/// Returns the underlying [`io::Error`] when `path` or one of its subdirectories
/// cannot be opened for listing. Individual directory entries that fail while
/// being iterated are skipped (best effort).
pub fn rule_files<P: AsRef<Path>>(path: P) -> io::Result<HashMap<String, Vec<PathBuf>>> {
    let mut rules_by_language = HashMap::new();

    let language_dirs = fs::read_dir(path)?
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|candidate| candidate.is_dir());

    for language_dir in language_dirs {
        // Use `file_name`, not `file_stem`: a stem drops everything after the last
        // dot, which would turn a directory such as `node.js` into `node` and make
        // `rust` and `rust.old` overwrite each other in the map.
        let Some(language) = language_dir.file_name().and_then(|name| name.to_str()) else {
            continue;
        };

        let mut rule_file_paths: Vec<PathBuf> = fs::read_dir(&language_dir)?
            .filter_map(Result::ok)
            .map(|entry| entry.path())
            .filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl"))
            .collect();
        rule_file_paths.sort();

        rules_by_language.insert(language.to_string(), rule_file_paths);
    }

    Ok(rules_by_language)
}
|
||||
// fn prebuilt_index();
|
||||
14
src/state.rs
14
src/state.rs
@@ -93,13 +93,19 @@ impl Generate {
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
// TODO: create new constructor and private these fields
|
||||
pub embed: crate::embed::Embed,
|
||||
pub generate: Generate,
|
||||
pub refactor: Refactor,
|
||||
embed: crate::embed::Embed,
|
||||
generate: Generate,
|
||||
refactor: Refactor,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(embed: crate::embed::Embed, generate: Generate, refactor: Refactor) -> Self {
|
||||
Self {
|
||||
embed,
|
||||
generate,
|
||||
refactor,
|
||||
}
|
||||
}
|
||||
pub fn generate(&self, lang: &str, prompt: &str, top_k: usize) -> Result<Vec<String>, Error> {
|
||||
let Ok(target) = self.embed.embed(prompt) else {
|
||||
return Err(Error::EmbedFailed);
|
||||
|
||||
Reference in New Issue
Block a user