feat: implement shallow globbing; removed dep glob

This commit is contained in:
Himadri Bhattacharjee
2025-07-19 13:36:39 +05:30
parent faea784d8f
commit c734c81a04
6 changed files with 74 additions and 68 deletions

10
Cargo.lock generated
View File

@@ -1050,12 +1050,6 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "h2"
version = "0.4.10"
@@ -2497,7 +2491,7 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "silos"
version = "5.0.0"
version = "5.1.0"
dependencies = [
"anyhow",
"candle-core",
@@ -2505,11 +2499,9 @@ dependencies = [
"candle-transformers",
"clap",
"derive_more",
"glob",
"hf-hub",
"hora",
"kdl",
"serde",
"serde_json",
"tokenizers",
"tokio",

View File

@@ -1,6 +1,6 @@
[package]
name = "silos"
version = "5.1.0"
version = "5.2.0"
edition = "2024"
[dependencies]
@@ -10,11 +10,9 @@ candle-nn = "0.9.1"
candle-transformers = "0.9.1"
clap = { version = "4.5.39", features = ["derive"] }
derive_more = { version = "2.0.1", features = ["display", "error"] }
glob = "0.3.2"
hf-hub = "0.4.2"
hora = "0.1.1"
kdl = "6.3.4"
serde = "1.0.219"
serde_json = "1.0.140"
tokenizers = "0.21.1"
tracing = "0.1.41"

View File

@@ -17,7 +17,7 @@ pub(crate) struct Args {
/// Path to the directory containing `generate` and `refactor` snippets.
#[arg(long, default_value = "./snippets")]
pub(crate) snippets: String,
pub(crate) snippets: std::path::PathBuf,
}
impl Args {

View File

@@ -1,4 +1,4 @@
use anyhow::{Context, Error as E, Result, bail};
use anyhow::{Context, Error as E, Result};
use clap::Parser;
use hora::core::{ann_index::ANNIndex, metrics::Metric::Euclidean};
use hora::index::hnsw_idx::HNSWIndex;
@@ -14,18 +14,7 @@ mod embed;
mod lsp;
mod mutation;
mod state;
fn path_to_parent_base(p: &std::path::Path) -> Result<String> {
let Some(parent) = p
.parent()
.and_then(|v| v.file_name())
.and_then(|v| v.to_str())
.map(|v| v.to_string())
else {
bail!("failed to parse snippets path, make sure the directory structure is valid");
};
Ok(parent)
}
mod sources;
#[tokio::main]
async fn main() -> Result<()> {
@@ -36,29 +25,27 @@ async fn main() -> Result<()> {
let mut dict = HashMap::default();
let dimensions = embed.hidden_size;
let paths = glob::glob(&format!("{}/generate/*/*.kdl", args.snippets))?;
for path in paths {
let path = path?;
let parent = path_to_parent_base(&path)?;
for (language, paths) in sources::rule_files(args.snippets.join("generate"))? {
for path in paths {
let current_lang_index = dict
.entry(language.clone())
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
let current_lang_index = dict
.entry(parent)
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
let doc_str = std::fs::read_to_string(&path)?;
let doc: KdlDocument = doc_str
.parse()
.context(format!("failed to parse KDL: {}", path.display()))?;
let doc_str = std::fs::read_to_string(&path)?;
let doc: KdlDocument = doc_str
.parse()
.context(format!("failed to parse KDL: {}", path.display()))?;
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
continue;
};
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
continue;
};
current_lang_index
.add(&embed.embed(desc)?, body.to_string())
.map_err(E::msg)?;
let Some(desc) = doc.get_arg("desc").and_then(|v| v.as_string()) else {
continue;
};
let Some(body) = doc.get_arg("body").and_then(|v| v.as_string()) else {
continue;
};
current_lang_index
.add(&embed.embed(desc)?, body.to_string())
.map_err(E::msg)?;
}
}
for index in dict.values_mut() {
@@ -67,36 +54,34 @@ async fn main() -> Result<()> {
.map_err(E::msg)?;
}
let paths = glob::glob(&format!("{}/refactor/*/*.kdl", args.snippets))?;
let mut refactor_dict = HashMap::new();
let mut mutations_collection = vec![];
for (i, path) in paths.enumerate() {
let path = path?;
let parent = path_to_parent_base(&path)?;
for (language, paths) in sources::rule_files(args.snippets.join("refactor"))? {
for path in paths {
let mutations = mutation::from_path(path)?;
let current_lang_index = refactor_dict
.entry(language.clone())
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
let mutations = mutation::from_path(path)?;
let current_lang_index = refactor_dict
.entry(parent)
.or_insert_with(|| HNSWIndex::new(dimensions, &Default::default()));
current_lang_index
.add(&embed.embed(&mutations.description)?, i)
.map_err(E::msg)?;
mutations_collection.push(mutations);
current_lang_index
.add(&embed.embed(&mutations.description)?, mutations_collection.len())
.map_err(E::msg)?;
mutations_collection.push(mutations);
}
}
for index in refactor_dict.values_mut() {
index.build(Euclidean).map_err(E::msg)?;
}
let appstate = State {
let appstate = State::new(
embed,
generate: state::Generate { dict },
refactor: state::Refactor {
state::Generate { dict },
state::Refactor {
dict: refactor_dict,
mutations_collection,
},
};
);
let stdin = tokio::io::stdin();
let stdout = tokio::io::stdout();

25
src/sources.rs Normal file
View File

@@ -0,0 +1,25 @@
use std::{fs, io, path::{Path, PathBuf}, collections::HashMap};
/// Collects the `.kdl` rule files found one level below `path`, grouped by
/// the name of the sub-directory that contains them.
///
/// `path` is scanned for sub-directories; each sub-directory is then scanned
/// (non-recursively) for regular files whose extension is `kdl`. The returned
/// map is keyed by the sub-directory's full name and holds that directory's
/// rule files in sorted order. Sub-directories without any `.kdl` file are
/// still present, mapped to an empty list.
///
/// Sub-directories whose name is not valid UTF-8 are skipped, as are
/// individual directory entries that cannot be read.
///
/// # Errors
/// Returns the underlying [`io::Error`] when `path` or one of its
/// sub-directories cannot be listed.
pub fn rule_files<P: AsRef<Path>>(path: P) -> io::Result<HashMap<String, Vec<PathBuf>>> {
    let mut basename_to_paths = HashMap::new();

    let language_dirs = fs::read_dir(path)?
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|candidate| candidate.is_dir());

    for language_dir in language_dirs {
        // Key on the complete directory name: `file_stem` would cut everything
        // after the last dot (`objective.c` -> `objective`) and could make two
        // different directories collide on the same key.
        let Some(dirname) = language_dir.file_name().and_then(|name| name.to_str()) else {
            continue;
        };

        let mut rule_file_paths: Vec<PathBuf> = fs::read_dir(&language_dir)?
            .filter_map(Result::ok)
            .map(|entry| entry.path())
            .filter(|file| file.is_file() && file.extension().is_some_and(|ext| ext == "kdl"))
            .collect();
        // `read_dir` yields entries in a platform-dependent order; sorting keeps
        // the result stable between runs.
        rule_file_paths.sort_unstable();

        basename_to_paths.insert(dirname.to_owned(), rule_file_paths);
    }

    Ok(basename_to_paths)
}
// fn prebuilt_index();

View File

@@ -93,13 +93,19 @@ impl Generate {
}
pub struct State {
// TODO: create new constructor and private these fields
pub embed: crate::embed::Embed,
pub generate: Generate,
pub refactor: Refactor,
embed: crate::embed::Embed,
generate: Generate,
refactor: Refactor,
}
impl State {
/// Builds a `State` from the embedder and the `generate` / `refactor`
/// components, storing each argument in the field of the same name.
pub fn new(embed: crate::embed::Embed, generate: Generate, refactor: Refactor) -> Self {
    Self { embed, generate, refactor }
}
pub fn generate(&self, lang: &str, prompt: &str, top_k: usize) -> Result<Vec<String>, Error> {
let Ok(target) = self.embed.embed(prompt) else {
return Err(Error::EmbedFailed);