diff --git a/Cargo.lock b/Cargo.lock index c8a7df7..96995c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1050,12 +1050,6 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" -[[package]] -name = "glob" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" - [[package]] name = "h2" version = "0.4.10" @@ -2497,7 +2491,7 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "silos" -version = "5.0.0" +version = "5.1.0" dependencies = [ "anyhow", "candle-core", @@ -2505,11 +2499,9 @@ dependencies = [ "candle-transformers", "clap", "derive_more", - "glob", "hf-hub", "hora", "kdl", - "serde", "serde_json", "tokenizers", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 3e343b1..e9d1aaa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "silos" -version = "5.1.0" +version = "5.2.0" edition = "2024" [dependencies] @@ -10,11 +10,9 @@ candle-nn = "0.9.1" candle-transformers = "0.9.1" clap = { version = "4.5.39", features = ["derive"] } derive_more = { version = "2.0.1", features = ["display", "error"] } -glob = "0.3.2" hf-hub = "0.4.2" hora = "0.1.1" kdl = "6.3.4" -serde = "1.0.219" serde_json = "1.0.140" tokenizers = "0.21.1" tracing = "0.1.41" diff --git a/src/args.rs b/src/args.rs index a9b9291..17bc829 100644 --- a/src/args.rs +++ b/src/args.rs @@ -17,7 +17,7 @@ pub(crate) struct Args { /// Path to the directory containing `generate` and `refactor` snippets. 
#[arg(long, default_value = "./snippets")] - pub(crate) snippets: String, + pub(crate) snippets: std::path::PathBuf, } impl Args { diff --git a/src/main.rs b/src/main.rs index 3e31dad..794c8df 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use anyhow::{Context, Error as E, Result, bail}; +use anyhow::{Context, Error as E, Result}; use clap::Parser; use hora::core::{ann_index::ANNIndex, metrics::Metric::Euclidean}; use hora::index::hnsw_idx::HNSWIndex; @@ -14,18 +14,7 @@ mod embed; mod lsp; mod mutation; mod state; - -fn path_to_parent_base(p: &std::path::Path) -> Result<String> { - let Some(parent) = p - .parent() - .and_then(|v| v.file_name()) - .and_then(|v| v.to_str()) - .map(|v| v.to_string()) - else { - bail!("failed to parse snippets path, make sure the directory structure is valid"); - }; - Ok(parent) -} +mod sources; #[tokio::main] async fn main() -> Result<()> { @@ -36,29 +25,27 @@ async fn main() -> Result<()> { let mut dict = HashMap::default(); let dimensions = embed.hidden_size; - let paths = glob::glob(&format!("{}/generate/*/*.kdl", args.snippets))?; - for path in paths { - let path = path?; - let parent = path_to_parent_base(&path)?; + for (language, paths) in sources::rule_files(args.snippets.join("generate"))?
{ + for path in paths { + let current_lang_index = dict + .entry(language.clone()) + .or_insert_with(|| HNSWIndex::new(dimensions, &Default::default())); - let current_lang_index = dict - .entry(parent) - .or_insert_with(|| HNSWIndex::new(dimensions, &Default::default())); + let doc_str = std::fs::read_to_string(&path)?; + let doc: KdlDocument = doc_str + .parse() + .context(format!("failed to parse KDL: {}", path.display()))?; - let doc_str = std::fs::read_to_string(&path)?; - let doc: KdlDocument = doc_str - .parse() - .context(format!("failed to parse KDL: {}", path.display()))?; - - let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else { - continue; - }; - let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else { - continue; - }; - current_lang_index - .add(&embed.embed(desc)?, body.to_string()) - .map_err(E::msg)?; + let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else { + continue; + }; + let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else { + continue; + }; + current_lang_index + .add(&embed.embed(desc)?, body.to_string()) + .map_err(E::msg)?; + } } for index in dict.values_mut() { @@ -67,36 +54,34 @@ async fn main() -> Result<()> { .map_err(E::msg)?; } - let paths = glob::glob(&format!("{}/refactor/*/*.kdl", args.snippets))?; let mut refactor_dict = HashMap::new(); let mut mutations_collection = vec![]; - for (i, path) in paths.enumerate() { - let path = path?; - let parent = path_to_parent_base(&path)?; + for (language, paths) in sources::rule_files(args.snippets.join("refactor"))? 
{ + for path in paths { + let mutations = mutation::from_path(path)?; + let current_lang_index = refactor_dict + .entry(language.clone()) + .or_insert_with(|| HNSWIndex::new(dimensions, &Default::default())); - let mutations = mutation::from_path(path)?; - let current_lang_index = refactor_dict - .entry(parent) - .or_insert_with(|| HNSWIndex::new(dimensions, &Default::default())); - - current_lang_index - .add(&embed.embed(&mutations.description)?, i) - .map_err(E::msg)?; - mutations_collection.push(mutations); + current_lang_index + .add(&embed.embed(&mutations.description)?, mutations_collection.len()) + .map_err(E::msg)?; + mutations_collection.push(mutations); + } } for index in refactor_dict.values_mut() { index.build(Euclidean).map_err(E::msg)?; } - let appstate = State { + let appstate = State::new( embed, - generate: state::Generate { dict }, - refactor: state::Refactor { + state::Generate { dict }, + state::Refactor { dict: refactor_dict, mutations_collection, }, - }; + ); let stdin = tokio::io::stdin(); let stdout = tokio::io::stdout(); diff --git a/src/sources.rs b/src/sources.rs new file mode 100644 index 0000000..e6f782e --- /dev/null +++ b/src/sources.rs @@ -0,0 +1,25 @@ +use std::{fs, io, path::{Path, PathBuf}, collections::HashMap}; + +pub fn rule_files<P: AsRef<Path>>(path: P) -> io::Result<HashMap<String, Vec<PathBuf>>> { + let per_language_dirs: Vec<_> = fs::read_dir(path)? + .filter_map(|res| res.ok()) + .map(|direntry| direntry.path()) + .filter(|dir| dir.is_dir()).collect(); + + let mut basename_to_paths = HashMap::new(); + + for language_dir in per_language_dirs { + let Some(dirname) = language_dir.file_stem().and_then(|v|v.to_str()).map(|v| v.to_string()) else { + continue; + }; + let rule_file_paths: Vec<_> = fs::read_dir(&language_dir)?
+ .filter_map(|res| res.ok()) + .map(|entry| entry.path()) + .filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl")) + .map(|path| path.to_path_buf()) + .collect(); + basename_to_paths.insert(dirname, rule_file_paths); + } + Ok(basename_to_paths) +} +// fn prebuilt_index(); diff --git a/src/state.rs b/src/state.rs index 6aac8cc..fb7672f 100644 --- a/src/state.rs +++ b/src/state.rs @@ -93,13 +93,19 @@ impl Generate { } pub struct State { - // TODO: create new constructor and private these fields - pub embed: crate::embed::Embed, - pub generate: Generate, - pub refactor: Refactor, + embed: crate::embed::Embed, + generate: Generate, + refactor: Refactor, } impl State { + pub fn new(embed: crate::embed::Embed, generate: Generate, refactor: Refactor) -> Self { + Self { + embed, + generate, + refactor, + } + } pub fn generate(&self, lang: &str, prompt: &str, top_k: usize) -> Result, Error> { let Ok(target) = self.embed.embed(prompt) else { return Err(Error::EmbedFailed);