├── lua
    ├── config.ld
    ├── tests
    │   ├── test_subscription.lua
    │   └── test_serialization.lua
    ├── actionlib
    │   ├── init.lua
    │   ├── GoalStatus.lua
    │   ├── ActionSpec.lua
    │   └── ServerGoalHandle.lua
    ├── tf
    │   ├── TransformBroadcaster.lua
    │   ├── Transform.lua
    │   ├── StampedTransform.lua
    │   └── TransformListener.lua
    ├── this_node.lua
    ├── std
    │   ├── String.lua
    │   ├── StringMap.lua
    │   ├── StringVector.lua
    │   ├── Task.lua
    │   ├── VariableTable.lua
    │   └── VariableVector.lua
    ├── init.lua
    ├── ServiceServer.lua
    ├── PointCloud2SerializationHandler.lua
    ├── AsyncSpinner.lua
    ├── utils.lua
    ├── SerializedMessage.lua
    ├── Rate.lua
    ├── MessageBuffer.lua
    ├── ros.lua
    ├── SrvSpec.lua
    ├── Subscriber.lua
    ├── master.lua
    ├── ServiceClient.lua
    ├── Time.lua
    ├── StorageReader.lua
    ├── CallbackQueue.lua
    ├── Publisher.lua
    ├── Duration.lua
    └── StorageWriter.lua
├── src
    ├── ros
    │   ├── service_server.cpp
    │   ├── async_spinner.cpp
    │   ├── this_node.cpp
    │   ├── subscriber.cpp
    │   ├── rate.cpp
    │   ├── init.cpp
    │   ├── publisher.cpp
    │   ├── message_buffer.cpp
    │   ├── raw_message.h
    │   ├── serialized_message.cpp
    │   ├── service_client.cpp
    │   ├── duration.cpp
    │   ├── master.cpp
    │   ├── point_cloud2.cpp
    │   ├── callback_queue.cpp
    │   ├── time.cpp
    │   ├── message_buffer.h
    │   ├── torch-ros.h
    │   └── console.cpp
    ├── std
    │   ├── torch-std.h
    │   ├── exceptions.h
    │   ├── string.cpp
    │   ├── string_map.cpp
    │   ├── string_vector.cpp
    │   ├── variable_vector.cpp
    │   ├── variable_table.cpp
    │   └── variable.cpp
    ├── tf
    │   ├── transform_broadcaster.cpp
    │   ├── transform.cpp
    │   ├── torch-tf.h
    │   ├── stamped_transform.cpp
    │   ├── quaternion.cpp
    │   └── transform_listener.cpp
    └── utils.h
├── .gitignore
├── torch-ros.workspace
├── demo
    ├── robotiq_ft_sensor.lua
    ├── call_movegroup_action.lua
    ├── broadcast_transform.lua
    ├── joystick.lua
    ├── logging.lua
    ├── pcl_interop.lua
    ├── publish.lua
    ├── getSystemState.lua
    ├── subscribe.lua
    ├── action_server.lua
    ├── simple_action_server.lua
    ├── advertiseService.lua
    ├── service_client.lua
    ├── action_client.lua
    ├── publish_lena_image.lua
    ├── publish_multi_array.lua
    ├── service_throughput_test.lua
    ├── orbit.lua
    └── robotiq_c_model.lua
├── torch-ros-scm-1.rockspec
├── README.md
├── LICENSE
└── CMakeLists.txt

/lua/config.ld:
--------------------------------------------------------------------------------
project='torch-ros'
title='Torch7/Lua Wrapper for ROS'
description='Torch7/Lua Wrapper for ROS'
format='discount'
backtick_references=false
wrap=true
--------------------------------------------------------------------------------
/src/ros/service_server.cpp:
--------------------------------------------------------------------------------
// Bindings for ros::ServiceServer handles.
#include "torch-ros.h"
#include <ros/service_server.h>  // header name lost in extraction; restored from usage

// Destroy a heap-allocated ServiceServer handle.
ROSIMP(void, ServiceServer, delete)(ros::ServiceServer *ptr) {
  delete ptr;
}

// Stop advertising the service.
ROSIMP(void, ServiceServer, shutdown)(ros::ServiceServer *self) {
  self->shutdown();
}

// Copy the advertised service name into *result.
ROSIMP(void, ServiceServer, getService)(ros::ServiceServer *self, std::string *result) {
  *result = self->getService();
}
--------------------------------------------------------------------------------
/src/std/torch-std.h:
--------------------------------------------------------------------------------
#ifndef torch_std_h
#define torch_std_h

// The angle-bracket header names below were lost in extraction and have been
// restored from usage (TH_CONCAT_4, std::map, std::string, std::vector).
// NOTE(review): <TH/TH.h> is presumed to be the original Torch header - confirm upstream.
extern "C" {
#include <TH/TH.h>
}

#include <map>
#include <string>
#include <vector>

// Declares an exported C-linkage function named std_<class_name>_<name>,
// e.g. STDIMP(void, string, delete) -> std_string_delete.
#define STDIMP(return_type, class_name, name) extern "C" return_type TH_CONCAT_4(std_, class_name, _, name)

typedef std::vector<std::string> StringVector;
typedef std::map<std::string, std::string> StringMap;

#endif // torch_std_h
--------------------------------------------------------------------------------
/lua/tests/test_subscription.lua:
--------------------------------------------------------------------------------
-- Manual test: subscribe to 'chatter' (std_msgs/String) and print every
-- buffered message until the node is shut down.
local ros = require 'ros'

ros.init('rose')

local spinner = ros.AsyncSpinner()
spinner:start()

local nodehandle = ros.NodeHandle()

local subscriber = nodehandle:subscribe("chatter", 'std_msgs/String', 100)

while ros.ok() do
  sys.sleep(0.1)
  -- drain everything that arrived since the last poll
  while subscriber:hasMessage() do
    local msg = subscriber:read()
    print(msg)
  end
end

subscriber:shutdown()
ros.shutdown()
--------------------------------------------------------------------------------
/src/tf/transform_broadcaster.cpp:
--------------------------------------------------------------------------------
// Bindings for tf::TransformBroadcaster.
#include "torch-tf.h"
#include <tf/transform_broadcaster.h>  // header name lost in extraction; restored from usage

// Allocate a new broadcaster; release it with TransformBroadcaster_delete.
TFIMP(tf::TransformBroadcaster *, TransformBroadcaster, new)() {
  return new tf::TransformBroadcaster();
}

TFIMP(void, TransformBroadcaster, delete)(tf::TransformBroadcaster *self) {
  delete self;
}

// Broadcast a single stamped transform.
TFIMP(void, TransformBroadcaster, sendTransform)(tf::TransformBroadcaster *self, tf::StampedTransform *transform) {
  self->sendTransform(*transform);
}
--------------------------------------------------------------------------------
/src/ros/async_spinner.cpp:
--------------------------------------------------------------------------------
// Bindings for ros::AsyncSpinner.
#include "torch-ros.h"

// Allocate a spinner; thread_count is forwarded unchanged to ros::AsyncSpinner.
ROSIMP(ros::AsyncSpinner*, AsyncSpinner, new)(uint32_t thread_count) {
  return new ros::AsyncSpinner(thread_count);
}

ROSIMP(void, AsyncSpinner, delete)(ros::AsyncSpinner *self) {
  delete self;
}

ROSIMP(bool, AsyncSpinner, canStart)(ros::AsyncSpinner *self) {
  return self->canStart();
}

ROSIMP(void, AsyncSpinner, start)(ros::AsyncSpinner *self) {
  self->start();
}

ROSIMP(void, AsyncSpinner, stop)(ros::AsyncSpinner *self) {
  self->stop();
}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Compiled Lua sources
luac.out
build/
doc/

# luarocks build files
*.src.rock
*.zip
*.tar.gz

# Object files
*.o
*.os
*.ko
*.obj
*.elf

# Precompiled Headers
*.gch
*.pch

# Libraries
*.lib
*.a
*.la
*.lo
*.def
*.exp

# Shared objects (inc. Windows DLLs)
*.dll
*.so
*.so.*
*.dylib

# Executables
*.exe
*.out
*.app
*.i*86
*.x86_64
*.hex

# Codelite
*.tags
.codelite/
--------------------------------------------------------------------------------
/src/ros/this_node.cpp:
--------------------------------------------------------------------------------
// Bindings for the ros::this_node namespace.
#include "torch-ros.h"
#include "../std/torch-std.h"
#include <ros/this_node.h>  // header name lost in extraction; restored from usage

// NOTE(review): handing out c_str() is only safe because roscpp's
// ros::this_node::getName()/getNamespace() return references to long-lived
// strings; revisit if that API ever returns by value.
ROSIMP(const char *, ThisNode, getName)() {
  return ros::this_node::getName().c_str();
}

ROSIMP(const char *, ThisNode, getNamespace)() {
  return ros::this_node::getNamespace().c_str();
}

// Fill *topics with the topics this node advertises.
ROSIMP(void, ThisNode, getAdvertisedTopics)(StringVector *topics) {
  ros::this_node::getAdvertisedTopics(*topics);
}

// Fill *topics with the topics this node subscribes to.
ROSIMP(void, ThisNode, getSubscribedTopics)(StringVector *topics) {
  ros::this_node::getSubscribedTopics(*topics);
}
--------------------------------------------------------------------------------
/torch-ros.workspace:
--------------------------------------------------------------------------------














--------------------------------------------------------------------------------
/src/ros/subscriber.cpp:
--------------------------------------------------------------------------------
// Bindings for ros::Subscriber handles.
#include "torch-ros.h"

// Return a new heap-allocated copy of the subscriber handle.
ROSIMP(ros::Subscriber*, Subscriber, clone)(ros::Subscriber *self) {
  return new ros::Subscriber(*self);
}

ROSIMP(void, Subscriber, delete)(ros::Subscriber *self) {
  delete self;
}

ROSIMP(void, Subscriber, shutdown)(ros::Subscriber *self) {
  self->shutdown();
}

// Copy the subscribed topic name into *output.
ROSIMP(void, Subscriber, getTopic)(ros::Subscriber *self, std::string *output) {
  *output = self->getTopic();
}

ROSIMP(int, Subscriber, getNumPublishers)(ros::Subscriber *self) {
  return static_cast<int>(self->getNumPublishers());
}
--------------------------------------------------------------------------------
/src/std/exceptions.h:
--------------------------------------------------------------------------------
#ifndef _exceptions_h
#define _exceptions_h

// header names lost in extraction; restored from usage (std::logic_error, std::string)
#include <stdexcept>
#include <string>

namespace xamla {

// Thrown by code paths that are declared but not implemented yet.
class NotImplementedException
  : public std::logic_error {
public:
  NotImplementedException()
    : logic_error("Function not implemented.") {
  }

  NotImplementedException(const std::string &msg)
    : logic_error(msg) {
  }
};

// Thrown when a value does not have the expected type.
class InvalidTypeException
  : public std::runtime_error {
public:
  InvalidTypeException(const std::string& reason)
    : runtime_error(reason) {
  }
};

} // xamla

#endif
--------------------------------------------------------------------------------
/demo/robotiq_ft_sensor.lua:
--------------------------------------------------------------------------------
local ros = require 'ros'

--[[
make sure robotiq force torque sensor node is running, e.g. run:

rosrun robotiq_force_torque_sensor rq_sensor
]]

ros.init('read_ft_sensor')

local spinner = ros.AsyncSpinner()
spinner:start()

local nodehandle = ros.NodeHandle()

local WrenchStamped_spec = ros.MsgSpec('geometry_msgs/WrenchStamped')

local wrench_input = nodehandle:subscribe("/wrench", WrenchStamped_spec, 100)

while ros.ok() do
  -- read with a 100 ms timeout; presumably yields nil when nothing arrived in time
  local msg = wrench_input:read(100)
  if msg then
    print(msg)
  end
end

wrench_input:shutdown()
ros.shutdown()
--------------------------------------------------------------------------------
/demo/call_movegroup_action.lua:
--------------------------------------------------------------------------------
local ros = require 'ros'
require 'ros.actionlib.ActionSpec'
require 'ros.actionlib.ActionClient'


ros.init('call_movegroup_demo')
ros.console.initialize()
ros.console.get_logger('ActionClient')
-- enable debug output of the ActionClient logger
ros.console.set_logger_level('ActionClient', ros.console.Level.Debug)
ros.DEBUG_NAMED('ActionClient', 'bulb')

local move_group_action_spec = ros.actionlib.ActionSpec('moveit_msgs/MoveGroup')

local client = ros.actionlib.ActionClient(move_group_action_spec, 'move_group')

print('waiting for action server to start ...')
client:waitForActionServerToStart()
print('ready.')

ros.shutdown()
--------------------------------------------------------------------------------
/torch-ros-scm-1.rockspec:
--------------------------------------------------------------------------------
package = "torch-ros"
version = "scm-1"

source = {
  -- GitHub no longer serves the unauthenticated git:// protocol
  url = "git+https://github.com/Xamla/torch-ros.git",
}

description = {
  summary = "ROS bindings for Torch7",
  detailed = [[
  ]],
  homepage = "https://github.com/Xamla/torch-ros",
  license = "BSD"
}

dependencies = {
  "torch >= 7.0",
  "md5 >= 1.2-1",
  "ldoc >= 1.4.4-1"
}

build = {
  type = "command",
  build_command = [[
cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE)
]],
  install_command = "cd build && $(MAKE) install"
}
--------------------------------------------------------------------------------
/lua/actionlib/init.lua:
--------------------------------------------------------------------------------
--- ROS actionlib
-- @module ros.actionlib
--
-- General documentation about the actionlib
-- http://wiki.ros.org/actionlib
-- http://wiki.ros.org/actionlib/DetailedDescription
--
-- C++ & Python source code:
-- https://github.com/ros/actionlib/tree/indigo-devel/include/actionlib

local ros = require 'ros.env'
require 'ros.ros'
require 'ros.Time'
require 'ros.Duration'
require 'ros.console'
require 'ros.StorageWriter'
require 'ros.StorageReader'
require 'ros.MsgSpec'
require 'ros.Message'
require 'ros.NodeHandle'
require 'ros.actionlib.ActionSpec'
local actionlib = ros.actionlib


return actionlib
--------------------------------------------------------------------------------
/src/ros/rate.cpp:
--------------------------------------------------------------------------------
// Bindings for ros::Rate.
#include "torch-ros.h"

// Allocate a rate object for the given frequency in Hz.
ROSIMP(ros::Rate*, Rate, new)(double frequency) {
  return new ros::Rate(frequency);
}

ROSIMP(void, Rate, delete)(ros::Rate *self) {
  delete self;
}

ROSIMP(ros::Rate*, Rate, clone)(ros::Rate *self) {
  return new ros::Rate(*self);
}

ROSIMP(void, Rate, reset)(ros::Rate *self) {
  self->reset();
}

ROSIMP(void, Rate, sleep)(ros::Rate *self) {
  self->sleep();
}

// Write the expected cycle time into *output.
ROSIMP(void, Rate, expectedCycleTime)(ros::Rate *self, ros::Duration* output) {
  *output = self->expectedCycleTime();
}

// Write the actual duration of the last cycle into *output.
ROSIMP(void, Rate, cycleTime)(ros::Rate *self, ros::Duration* output) {
  *output = self->cycleTime();
}
--------------------------------------------------------------------------------
/lua/tf/TransformBroadcaster.lua:
--------------------------------------------------------------------------------
local ffi = require 'ffi'
local torch = require 'torch'
local ros = require 'ros.env'
local utils = require 'ros.utils'
local tf = ros.tf

local TransformBroadcaster = torch.class('tf.TransformBroadcaster', tf)

-- Bind the native tf_TransformBroadcaster_* entry points.
-- Kept local so loading this module does not clobber a global `init`.
local function init()
  local TransformBroadcaster_method_names = {
    "new",
    "delete",
    "sendTransform"
  }

  return utils.create_method_table("tf_TransformBroadcaster_", TransformBroadcaster_method_names)
end

local f = init()

--- Create a broadcaster backed by a native tf::TransformBroadcaster.
-- The native object is released through `delete` once this wrapper is
-- garbage-collected (previously it was never freed).
function TransformBroadcaster:__init()
  self.o = ffi.gc(f.new(), f.delete)
end

--- Broadcast a stamped transform.
-- @param stampedTransform tf.StampedTransform whose cdata() is forwarded to the native broadcaster
function TransformBroadcaster:sendTransform(stampedTransform)
  f.sendTransform(self.o, stampedTransform:cdata())
end
--------------------------------------------------------------------------------
/src/std/string.cpp:
--------------------------------------------------------------------------------
// Bindings for std::string (exported as std_string_*).
#include "torch-std.h"
#include <string>  // header name lost in extraction; restored from usage

// Allocate a string holding `len` bytes of `s`; empty when s is NULL or len is 0.
STDIMP(std::string*, string, new)(const char* s = 0, size_t len = 0) {
  if (!s || !len)
    return new std::string();
  else
    return new std::string(s, len);
}

STDIMP(void, string, delete)(std::string *self) {
  delete self;
}

STDIMP(std::string*, string, clone)(std::string *self) {
  return new std::string(*self);
}

// Replace the contents with `len` bytes of `s` (embedded NULs are preserved).
STDIMP(void, string, assign)(std::string *self, const char *s, size_t len) {
  self->assign(s, len);
}

STDIMP(int, string, length)(std::string *self) {
  return self->length();
}

STDIMP(const char*, string, c_str)(std::string *self) {
  return self->c_str();
}
--------------------------------------------------------------------------------
/demo/broadcast_transform.lua:
--------------------------------------------------------------------------------
local ros = require 'ros'
local tf = ros.tf

--[[

tf_echo can be used to verify that the transform is correctly broadcasted:

rosrun tf tf_echo /world /tomato

]]

ros.init('example_TransformBroadcaster')
local sp = ros.AsyncSpinner()
sp:start()
local b = tf.TransformBroadcaster()

-- broadcast ten different world -> tomato transforms, one per second
for i=1,10 do
  local t = tf.Transform()
  t:setOrigin({ i, -i, i*2 })
  local rot = tf.Quaternion()
  rot:setRPY(i*10, i*20, i*30, true) -- last argument indicates that angle is specified in degrees
  t:setRotation(rot)

  local st = tf.StampedTransform(t, ros.Time.now(), 'world', 'tomato')

  print('Sending:')
  print(st)

  b:sendTransform(st)
  ros.Duration(1):sleep()
end

ros.shutdown()
--------------------------------------------------------------------------------
/src/ros/init.cpp:
--------------------------------------------------------------------------------
// Bindings for global roscpp lifecycle functions (exported with class name `_`).
#include "torch-ros.h"

// Initialize roscpp. `remappings` may be NULL, in which case no remappings are applied
// (previously a NULL pointer was dereferenced).
ROSIMP(void, _, init)(const std::map<std::string, std::string> *remappings, const char *name, uint32_t options) {
  if (remappings) {
    ros::init(*remappings, name, options);
  } else {
    ros::init(std::map<std::string, std::string>(), name, options);
  }
}

ROSIMP(void, _, shutdown)() {
  ros::shutdown();
}

ROSIMP(void, _, spinOnce)() {
  ros::spinOnce();
}

ROSIMP(void, _, requestShutdown)() {
  ros::requestShutdown();
}

ROSIMP(bool, _, isInitialized)() {
  return ros::isInitialized();
}

ROSIMP(bool, _, isStarted)() {
  return ros::isStarted();
}

ROSIMP(bool, _, isShuttingDown)() {
  return ros::isShuttingDown();
}

ROSIMP(bool, _, ok)() {
  return ros::ok();
}

// Block until the node has been shut down.
ROSIMP(void, _, waitForShutdown)() {
  ros::waitForShutdown();
}
--------------------------------------------------------------------------------
/demo/joystick.lua:
--------------------------------------------------------------------------------
--[[
Subscribe to joystick input and dump messages.

1. install joy package:
$ sudo apt-get install ros-indigo-joy

2. configure input device to use:
$ rosparam set joy_node/dev "/dev/input/jsX"

3. run joy node:
$ rosrun joy joy_node

further help to get joy package running:
http://wiki.ros.org/joy/Tutorials/ConfiguringALinuxJoystick
]]

local ros = require 'ros'

ros.init('joystick_demo')
local nh = ros.NodeHandle()

local joy = nh:subscribe('joy', 'sensor_msgs/Joy', 100)
print('Subscribed to \'joy\' node. Please start using your joystick.')
--print(joy.msg_spec)

local d = ros.Duration(0.0025)
while ros.ok() do
  ros.spinOnce()
  d:sleep()
  -- drain all messages received since the last iteration
  while joy:hasMessage() do
    local msg = joy:read()
    print(msg)
  end
end

joy:shutdown()
ros.shutdown()
--------------------------------------------------------------------------------
/demo/logging.lua:
--------------------------------------------------------------------------------
local ros = require 'ros'

ros.init('logging_demo')

-- very simple logging
ros.FATAL('Oh, noooo...')
ros.ERROR('Something bad happened.')
ros.WARN('This is a warning.')
ros.INFO('I would like to inform you about the current state.')
ros.DEBUG('A very verbose debug message with useless numeric output: %d', 123)

ros.DEBUG_COND(true, 'Messages can also have conditions')
ros.DEBUG_COND(false, 'This message will not be printed...')

for i = 1, 10 do
  ros.ROS_ERROR_ONCE('throttle', 'This message will be printed once')
  ros.ROS_INFO_THROTTLE('throttle', 5, 'This message will be printed every 5 seconds')
  ros.ROS_WARN_THROTTLE('throttle1', 2, 'This message will be printed every 2 seconds')
  sys.sleep(2)
end

-- get name and log level of registered loggers
local names, levels = ros.console.get_loggers()
ros.INFO('Names and levels of loggers: %s %s', names, levels)

ros.shutdown()
--------------------------------------------------------------------------------
/lua/this_node.lua:
--------------------------------------------------------------------------------
--- Lua access to roscpp's ros::this_node functions.
local ffi = require 'ffi'
local torch = require 'torch'
local ros = require 'ros.env'
local utils = require 'ros.utils'
local std = ros.std

local this_node = {}
ros.this_node = this_node

-- Bind the native ros_ThisNode_* entry points.
-- Kept local so loading this module does not clobber a global `init`.
local function init()
  local names = {
    'getName',
    'getNamespace',
    'getAdvertisedTopics',
    'getSubscribedTopics'
  }

  return utils.create_method_table("ros_ThisNode_", names)
end

local f = init()

--- Name of this node as a Lua string.
function this_node.getName()
  return ffi.string(f.getName())
end

--- Namespace of this node as a Lua string.
function this_node.getNamespace()
  return ffi.string(f.getNamespace())
end

--- Topics advertised by this node.
-- @param result optional std.StringVector to fill; a new one is created when omitted
-- @return the filled std.StringVector
function this_node.getAdvertisedTopics(result)
  result = result or std.StringVector()
  f.getAdvertisedTopics(result:cdata())
  return result
end

--- Topics this node subscribes to.
-- @param result optional std.StringVector to fill; a new one is created when omitted
-- @return the filled std.StringVector
function this_node.getSubscribedTopics(result)
  result = result or std.StringVector()
  f.getSubscribedTopics(result:cdata())
  return result
end

return this_node
--------------------------------------------------------------------------------
/demo/pcl_interop.lua:
--------------------------------------------------------------------------------
local ros = require 'ros'
local pcl = require 'pcl'
require 'ros.PointCloud2SerializationHandler'

ros.init('pcl_interop_demo')
local nh = ros.NodeHandle()

-- let sensor_msgs/PointCloud2 messages be (de)serialized from/to pcl point clouds
local handler = ros.PointCloud2SerializationHandler()
nh:addSerializationHandler(handler)

local function onMessage(msg, header)
  print('received point cloud:')
  print(msg:toPointCloud():points())
end

local publisher = nh:advertise('point_cloud_output', 'sensor_msgs/PointCloud2', 10)

-- establish intraprocess subscription
local subscriber = nh:subscribe('point_cloud_output', 'sensor_msgs/PointCloud2', 10)
subscriber:registerCallback(onMessage)

print('press ctrl+c to exit')
while ros.ok() do
  if publisher:getNumSubscribers() > 0 then
    local c = pcl.rand(10)  -- create dummy point cloud
    c:setHeaderFrameId('/map')
    print('publishing point cloud:')
    print(c:points())
    publisher:publish(c)
  end
  sys.sleep(0.5)
  ros.spinOnce()
end

nh:shutdown()
ros.shutdown()
--------------------------------------------------------------------------------
/demo/publish.lua:
--------------------------------------------------------------------------------
local ros = require 'ros'

ros.init('publish_demo')

local spinner = ros.AsyncSpinner()
spinner:start()

local nodehandle = ros.NodeHandle()

local string_spec = ros.MsgSpec('std_msgs/String')

local function connect_cb(name, topic)
  print("subscriber connected: " .. name .. " (topic: '" .. topic .. "')")
end

local function disconnect_cb(name, topic)
  print("subscriber disconnected: " .. name .. " (topic: '" .. topic .. "')")
end

local publisher = nodehandle:advertise("dummy_chat", string_spec, 100, false, connect_cb, disconnect_cb)
ros.spinOnce()

local m = ros.Message(string_spec)

-- Publish up to n numbered messages at ~10 Hz; stops early when ROS shuts down.
local function run(n)
  for i=1,n do
    if not ros.ok() then
      return
    end
    if publisher:getNumSubscribers() == 0 then
      print('waiting for subscriber')
    else
      m.data = "Hello this is a string message " .. i
      publisher:publish(m)
      print(i)
    end
    sys.sleep(0.1)
    ros.spinOnce()
  end
end

run(100)

ros.shutdown()
--------------------------------------------------------------------------------
/demo/getSystemState.lua:
--------------------------------------------------------------------------------
local ros = require 'ros'

ros.init('getSystemStateDemo')

-- execute the getSystemState call via XMLRPC.
-- Master-API reference: http://wiki.ros.org/ROS/Master_API
local status,response,payload = ros.master.execute('getSystemState', ros.this_node.getName())

-- Print a list of { topic, { node, ... } } entries.
local function printList(list)
  for i,v in ipairs(list) do
    local topic = v[1]
    local nodes = v[2]
    print(string.format("Topic: '%s'", topic))
    print(' Nodes:')
    for j,n in ipairs(nodes) do
      print(' ' .. n)
    end
  end
end

if status then
  local publishers, subscribers, services = unpack(payload)
  print(string.format("# Publishers (%d)", #publishers))
  printList(publishers)
  print("")
  print(string.format("# Subscribers (%d)", #subscribers))
  printList(subscribers)
  print("")
  print(string.format("# Services (%d)", #services))
  printList(services)
  print("")
else
  print('getSystemState request failed.')
end

ros.shutdown()
--------------------------------------------------------------------------------
/demo/subscribe.lua:
--------------------------------------------------------------------------------
--[[
This demo shows how to subscribe to a message topic and receive
incoming messages (in this case simple std_msgs/String messages).
You can run `publish.lua` as a message generating counterpart
to send string messages that are displayed by this demo.
]]

local ros = require 'ros'


ros.init('subscribe_demo')

local spinner = ros.AsyncSpinner()
spinner:start()

local nodehandle = ros.NodeHandle()

-- subscribe to dummy_chat topic with 100 messages back-log
-- transport_options (arguments 4 & 5) are optional - used here only for demonstration purposes
local subscriber = nodehandle:subscribe("dummy_chat", 'std_msgs/String', 100, { 'udp', 'tcp' }, { tcp_nodelay = true })

-- register a callback function that will be triggered from ros.spinOnce() when a message is available.
subscriber:registerCallback(function(msg, header)
  print('Header:')
  print(header)
  print('Message:')
  print(msg)
end)

while ros.ok() do
  ros.spinOnce()
  sys.sleep(0.1)
end

subscriber:shutdown()
ros.shutdown()
--------------------------------------------------------------------------------
/lua/std/String.lua:
--------------------------------------------------------------------------------
--- Lua wrapper around a native std::string (std_string_* bindings).
local ffi = require 'ffi'
local torch = require 'torch'
local ros = require 'ros.env'
local utils = require 'ros.utils'
local std = ros.std

local String = torch.class('std.String', std)

-- Bind the native std_string_* entry points.
-- Kept local so loading this module does not clobber a global `init`.
local function init()
  local String_method_names = {
    "new",
    "clone",
    "delete",
    "assign",
    "length",
    "c_str"
  }

  return utils.create_method_table("std_string_", String_method_names)
end

local f = init()

--- Create a native string.
-- @param s optional Lua string used as initial content; any other value yields an empty string
function String:__init(s)
  local native
  if type(s) == 'string' then
    native = f.new(s, #s)
  else
    native = f.new(ffi.NULL, 0)
  end
  -- free the std::string when this wrapper is collected (it was never freed before)
  self.o = ffi.gc(native, f.delete)
end

--- Raw native pointer, for passing to other bindings.
function String:cdata()
  return self.o
end

--- Replace the contents with tostring(s).
function String:assign(s)
  s = tostring(s)
  f.assign(self.o, s, #s)
end

--- Length in bytes.
function String:length()
  return f.length(self.o)
end

--- Contents as a Lua string (length-delimited, so embedded NULs are kept).
function String:get()
  return ffi.string(f.c_str(self.o), f.length(self.o))
end

function String:__len()
  return self:length()
end

function String:__tostring()
  return self:get()
end
--------------------------------------------------------------------------------
/src/ros/publisher.cpp:
--------------------------------------------------------------------------------
// Bindings for ros::Publisher handles.
#include "torch-ros.h"
#include "raw_message.h"

// Return a new heap-allocated copy of the publisher handle.
ROSIMP(ros::Publisher*, Publisher, clone)(ros::Publisher *self) {
  return new ros::Publisher(*self);
}

ROSIMP(void, Publisher, delete)(ros::Publisher *self) {
  delete self;
}

ROSIMP(void, Publisher, shutdown)(ros::Publisher *self) {
  self->shutdown();
}

// Copy the advertised topic name into *output.
ROSIMP(void, Publisher, getTopic)(ros::Publisher *self, std::string *output) {
  *output = self->getTopic();
}

ROSIMP(int, Publisher, getNumSubscribers)(ros::Publisher *self) {
  return static_cast<int>(self->getNumSubscribers());
}

ROSIMP(bool, Publisher, isLatched)(ros::Publisher *self) {
  return self->isLatched();
}

// Publish `length` bytes starting at `offset` of an already serialized message.
// Throws std::range_error if [offset, offset + length) is not inside the storage.
ROSIMP(void, Publisher, publish)(ros::Publisher *self, THByteStorage *serialized_msg, ptrdiff_t offset, size_t length) {
  RawMessage msg;
  long storage_size = THByteStorage_size(serialized_msg);
  // Check each bound on its own: a negative offset used to wrap around in the
  // unsigned `offset + length` sum (passing the check), and that sum could overflow.
  if (storage_size < 0 || offset < 0
      || static_cast<size_t>(offset) > static_cast<size_t>(storage_size)
      || length > static_cast<size_t>(storage_size) - static_cast<size_t>(offset))
    throw std::range_error("Specified array segment lies outside buffer.");
  msg.copyFrom(THByteStorage_data(serialized_msg) + offset, length);
  self->publish(msg);
}
--------------------------------------------------------------------------------
/demo/action_server.lua:
--------------------------------------------------------------------------------
local ros = require 'ros'
require 'ros.actionlib.ActionServer'
local actionlib = ros.actionlib


-- Goal callback: accept the goal and immediately report success.
local function ActionServer_Goal(goal_handle)
  ros.INFO("ActionServer_Goal")
  local g = goal_handle:getGoal()
  print(g)
  goal_handle:setAccepted('yip')

  local r = goal_handle:createResult()
  r.result = 123
  print(r)
  --goal_handle:setAborted(r, 'no')
  goal_handle:setSucceeded(r, 'done')
end


-- Cancel callback: acknowledge the cancel request.
local function ActionServer_Cancel(goal_handle)
  ros.INFO("ActionServer_Cancel")
  goal_handle:setCanceled(nil, 'blub')
end


local function testActionServer()
  ros.init('testActionServer')
  local nh = ros.NodeHandle()
  -- NOTE(review): other demos call ros.console.set_logger_level (snake_case);
  -- confirm that this camelCase variant exists in ros.console.
  ros.console.setLoggerLevel('actionlib', ros.console.Level.Debug)

  local as = actionlib.ActionServer(nh, 'test_action', 'actionlib/Test')

  as:registerGoalCallback(ActionServer_Goal)
  as:registerCancelCallback(ActionServer_Cancel)

  print('Starting action server...')
  as:start()

  while ros.ok() do
    ros.spinOnce()
    sys.sleep(0.01)
  end

  as:shutdown()
  nh:shutdown()
  ros.shutdown()
end


testActionServer()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# torch-ros

Torch7/Lua wrapper of roscpp via ffi. It offers dynamic serialization of ROS messages without requiring pre-generated message classes.

## Currently supported feature set:

- dynamic interpretation of ROS message specifications (MsgSpec)
- serialization/deserialization of ROS messages
- representation of certain array types (byte, char, short, int, long, float, double) in messages as torch.Tensors
- publishing and subscribing to ROS topics
- basic parameter server support (`NodeHandle:getParam*()`)
- basic ROS time support
- basic ROS-console/logging support (e.g. `ros.WARN()`, `ros.INFO()` etc.)
- service calls (advertiseService & serviceClient)
- ActionLib (ActionClient/ActionServer & SimpleActionClient/SimpleActionServer)
- TF library (Transform, StampedTransform, Quaternion, TransformListener, TransformBroadcaster)
- wrappers for basic `std::string`, `std::vector<std::string>`

## Todo:

- C++ exception handling/translation to Lua

## Limitations

Currently only little-endian systems are supported.

## Based on roslua

This project uses fragments of [roslua](https://github.com/timn/roslua) which was developed by Tim Niemueller @timn.
--------------------------------------------------------------------------------
/src/ros/message_buffer.cpp:
--------------------------------------------------------------------------------
// Bindings for MessageBuffer, a queue of received raw messages.
#include "torch-ros.h"
#include "../std/torch-std.h"
#include "message_buffer.h"
#include <cstring>  // memcpy

// Lua holds MessageBuffer instances through a heap-allocated shared_ptr handle.
// (Template argument lost in extraction; restored from the constructor usage below.)
#define MessageBuffer_ptr boost::shared_ptr<MessageBuffer>

ROSIMP(MessageBuffer_ptr*, MessageBuffer, new)(int max_backlog) {
  return new MessageBuffer_ptr(new MessageBuffer(max_backlog));
}

ROSIMP(void, MessageBuffer, delete)(MessageBuffer_ptr *self) {
  delete self;
}

ROSIMP(int, MessageBuffer, count)(MessageBuffer_ptr *self) {
  return static_cast<int>((*self)->count());
}

ROSIMP(void, MessageBuffer, clear)(MessageBuffer_ptr *self) {
  (*self)->clear();
}

// Read the next buffered message, presumably waiting up to timeout_milliseconds.
// Returns false if no message was obtained. Otherwise copies the payload into
// msg_output (resized to fit) and the header map into header_output; either
// output may be NULL to skip it.
ROSIMP(bool, MessageBuffer, read)(MessageBuffer_ptr *self, int timeout_milliseconds, THByteStorage *msg_output, StringMap *header_output) {
  // NOTE(review): template argument lost in extraction; RawMessage is presumed from the
  // get_buffer()/get_length()/get_header() accessors - confirm against message_buffer.h.
  boost::shared_ptr<RawMessage> buffer = (*self)->read(timeout_milliseconds);
  if (!buffer)
    return false;

  // copy message to output byte storage
  if (msg_output) {
    THByteStorage_resize(msg_output, buffer->get_length());
    uint8_t* dst = THByteStorage_data(msg_output);
    memcpy(dst, buffer->get_buffer().get(), buffer->get_length());
  }

  if (header_output) {
    *header_output = buffer->get_header();
  }

  return true;
}
--------------------------------------------------------------------------------
/lua/init.lua:
--------------------------------------------------------------------------------
--- Entry point for `require 'ros'`: loads all torch-ros submodules and returns the `ros` table.
local ros = require 'ros.env'

-- std
require 'ros.std.String'
require 'ros.std.StringVector'
require 'ros.std.StringMap'
require 'ros.std.Variable'
require 'ros.std.VariableVector'
require 'ros.std.VariableTable'
require 'ros.std.Task'

-- ros
require 'ros.ros'
require 'ros.Time'
require 'ros.Duration'
require 'ros.Rate'
require 'ros.console'
require 'ros.master'
require 'ros.this_node'
require 'ros.StorageWriter'
require 'ros.StorageReader'
require 'ros.CallbackQueue'
require 'ros.MsgSpec'
require 'ros.SrvSpec'
require 'ros.Message'
require 'ros.AsyncSpinner'
require 'ros.MessageBuffer'
require 'ros.SerializedMessage'
require 'ros.Subscriber'
require 'ros.Publisher'
require 'ros.ServiceClient'
require 'ros.ServiceServer'
require 'ros.NodeHandle'

-- tf
require 'ros.tf.Quaternion'
require 'ros.tf.Transform'
require 'ros.tf.StampedTransform'
require 'ros.tf.TransformBroadcaster'
require 'ros.tf.TransformListener'

-- actionlib
require 'ros.actionlib.ActionSpec'
require 'ros.actionlib.SimpleActionServer'
require 'ros.actionlib.SimpleActionClient'
require 'ros.actionlib.ServerGoalHandle'

return ros
-------------------------------------------------------------------------------- /demo/simple_action_server.lua: -------------------------------------------------------------------------------- 1 | local ros = require 'ros' 2 | require 'ros.actionlib.SimpleActionServer' 3 | local actionlib = ros.actionlib 4 | 5 | 6 | local function SimpleActionServer_onGoal(as) 7 | ros.INFO("SimpleActionServer_onGoal") 8 | 9
| local g = as:acceptNewGoal() 10 | print(g) 11 | 12 | assert(as:isActive()) -- ensure goal is active 13 | 14 | local r = as:createResult() 15 | r.result = 123 16 | print(r) 17 | --as:setAborted(r, 'no') 18 | as:setSucceeded(r, 'done') 19 | end 20 | 21 | 22 | local function SimpleActionServer_onPreempt(as) 23 | ros.INFO("SimpleActionServer_onPreempt") 24 | as:setPreempted(nil, 'blub') 25 | end 26 | 27 | 28 | local function testSimpleActionServer() 29 | ros.init('test_action_server') 30 | ros.console.setLoggerLevel('actionlib', ros.console.Level.Debug) 31 | nh = ros.NodeHandle() 32 | 33 | local as = actionlib.SimpleActionServer(nh, 'test_action', 'actionlib/Test') 34 | 35 | as:registerGoalCallback(SimpleActionServer_onGoal) 36 | as:registerPreemptCallback(SimpleActionServer_onPreempt) 37 | 38 | print('Starting action server...') 39 | as:start() 40 | 41 | while ros.ok() do 42 | ros.spinOnce() 43 | sys.sleep(0.01) 44 | end 45 | 46 | as:shutdown() 47 | nh:shutdown() 48 | ros.shutdown() 49 | end 50 | 51 | 52 | testSimpleActionServer() 53 | -------------------------------------------------------------------------------- /demo/advertiseService.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Advertise a service from torch-ros. 3 | The existing roscpp/GetLoggers srv file is used. 4 | ]] 5 | 6 | local ros = require 'ros' 7 | 8 | ros.init('advertiseService_demo') 9 | nh = ros.NodeHandle() 10 | 11 | service_queue = ros.CallbackQueue() 12 | 13 | srv_spec = ros.SrvSpec('roscpp/GetLoggers') 14 | print(srv_spec) 15 | 16 | function myServiceHandler(request, response, header) 17 | print('[!] handler call') 18 | print('request:') 19 | print(request) 20 | print('header:') 21 | print(header) 22 | 23 | for k,v in pairs(ros.console.Level) do 24 | local l = ros.Message('roscpp/Logger') 25 | l.name = 'dummyname' ..
v 26 | l.level = k 27 | table.insert(response.loggers, l) 28 | end 29 | 30 | print('response:') 31 | print(response) 32 | 33 | return true 34 | end 35 | 36 | server = nh:advertiseService('/demo_service', srv_spec, myServiceHandler, service_queue) 37 | print('name: ' .. server:getService()) 38 | print('service server running, call "rosservice call /demo_service" to send a request to the service.') 39 | 40 | local s = ros.Duration(0.001) 41 | while ros.ok() do 42 | s:sleep() 43 | if not service_queue:isEmpty() then 44 | print('[!] incoming service call') 45 | service_queue:callAvailable() 46 | end 47 | ros.spinOnce() 48 | end 49 | 50 | server:shutdown() 51 | ros.shutdown() 52 | -------------------------------------------------------------------------------- /lua/ServiceServer.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | local std = ros.std 6 | 7 | local ServiceServer = torch.class('ros.ServiceServer', ros) 8 | local ServiceServer_ptr_ct = ffi.typeof('ros_ServiceServer *') 9 | 10 | function init() 11 | local ServiceServer_method_names = { 12 | "delete", 13 | "shutdown", 14 | "getService" 15 | } 16 | 17 | return utils.create_method_table("ros_ServiceServer_", ServiceServer_method_names) 18 | end 19 | 20 | local f = init() 21 | 22 | function ServiceServer:__init(ptr, callback, service_handler_func) 23 | if not ffi.istype(ServiceServer_ptr_ct, ptr) then 24 | error('ros::ServiceServer* expected.') 25 | end 26 | 27 | self.o = ptr 28 | self.callback = callback 29 | self.handler = service_handler_func 30 | ffi.gc(ptr, 31 | function(p) 32 | f.delete(p) 33 | if self.callback ~= nil then 34 | self.callback:free() -- free callback 35 | self.callback = nil 36 | end 37 | end 38 | ) 39 | 40 | end 41 | 42 | function ServiceServer:cdata() 43 | return self.o 44 | end 45 | 46 | function 
ServiceServer:shutdown() 47 | f.shutdown(self.o) 48 | self.o = nil 49 | end 50 | 51 | function ServiceServer:getService() 52 | local result = std.String() 53 | f.getService(self.o, result:cdata()) 54 | return result:get() 55 | end 56 | -------------------------------------------------------------------------------- /src/tf/transform.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-tf.h" 2 | 3 | TFIMP(tf::Transform *, Transform, new)() { 4 | return new tf::Transform(); 5 | } 6 | 7 | TFIMP(tf::Transform *, Transform, clone)(tf::Transform *self) { 8 | return new tf::Transform(*self); 9 | } 10 | 11 | TFIMP(void, Transform, delete)(tf::Transform *self) { 12 | delete self; 13 | } 14 | 15 | TFIMP(void, Transform, setIdentity)(tf::Transform *self) { 16 | self->setIdentity(); 17 | } 18 | 19 | TFIMP(void, Transform, mul_Quaternion)(tf::Transform *self, tf::Quaternion *rot, tf::Quaternion *result) { 20 | *result = self->operator*(*rot); 21 | } 22 | 23 | TFIMP(void, Transform, mul_Transform)(tf::Transform *self, tf::Transform *other, tf::Transform *result) { 24 | *result = self->operator*(*other); 25 | } 26 | 27 | TFIMP(void, Transform, inverse)(tf::Transform *self, tf::Transform *result) { 28 | *result = self->inverse(); 29 | } 30 | 31 | TFIMP(void, Transform, getBasis)(tf::Transform *self, THDoubleTensor *basis) { 32 | viewMatrix3x3(self->getBasis(), basis); 33 | } 34 | 35 | TFIMP(void, Transform, getOrigin)(tf::Transform *self, THDoubleTensor *origin) { 36 | viewVector3(self->getOrigin(), origin); 37 | } 38 | 39 | TFIMP(void, Transform, setRotation)(tf::Transform *self, tf::Quaternion *rotation) { 40 | self->setRotation(*rotation); 41 | } 42 | 43 | TFIMP(void, Transform, getRotation)(tf::Transform *self, tf::Quaternion *rotation) { 44 | *rotation = self->getRotation(); 45 | } 46 | -------------------------------------------------------------------------------- /lua/PointCloud2SerializationHandler.lua: 
-------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | local pcl = require 'pcl' 6 | local std = ros.std 7 | 8 | local PointCloud2SerializationHandler = torch.class('ros.PointCloud2SerializationHandler', ros) 9 | 10 | function init() 11 | local names = { 12 | 'readPointCloud2', 13 | 'writePointCloud2' 14 | } 15 | 16 | return utils.create_method_table("ros_pcl_", names) 17 | end 18 | 19 | local f = init() 20 | 21 | function PointCloud2SerializationHandler:init() 22 | end 23 | 24 | function PointCloud2SerializationHandler:getType() 25 | return "sensor_msgs/PointCloud2" 26 | end 27 | 28 | function PointCloud2SerializationHandler:read(sr, value) 29 | -- call deserialization function 30 | value = value or pcl.PCLPointCloud2() 31 | local newOffset = f.readPointCloud2(sr.storage:cdata(), sr.offset, value:cdata()) 32 | sr:setOffset(newOffset) 33 | return value 34 | end 35 | 36 | function PointCloud2SerializationHandler:write(sw, value) 37 | if torch.isTypeOf(value, pcl.PointCloud) then 38 | value = value:toPCLPointCloud2() 39 | end 40 | 41 | if not torch.isTypeOf(value, pcl.PCLPointCloud2) then 42 | error("Invalid value type. 'pcl.PCLPointCloud2' expected.") 43 | end 44 | 45 | -- call serialization function 46 | local newOffset = f.writePointCloud2(sw.storage:cdata(), sw.offset, utils.cdata(value)) 47 | sw:storageChanged(newOffset) 48 | end 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Xamla 2 | All rights reserved. 
3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of torch-ros nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /demo/service_client.lua: -------------------------------------------------------------------------------- 1 | local ros = require 'ros' 2 | 3 | ros.init('service_client_demo') 4 | 5 | local nh = ros.NodeHandle() 6 | 7 | 8 | -- 1. 
Call a simple service with empty request message 9 | 10 | local clientA = nh:serviceClient('/rosout/get_loggers', 'roscpp/GetLoggers') 11 | 12 | -- we can check if the service exists 13 | local ok = clientA:exists() 14 | print('exists() returned: ' .. tostring(ok)) 15 | 16 | -- or wait for it to become available 17 | local timeout = ros.Duration(5) 18 | local ok = clientA:waitForExistence(timeout) 19 | print('waitForExistence() returned: ' .. tostring(ok)) 20 | 21 | 22 | print('Calling service: ' .. clientA:getService()) 23 | 24 | -- call the service 25 | local response = clientA:call() 26 | 27 | print('Response:') 28 | print(response) 29 | 30 | 31 | -- 2. Call a service with a non-empty request. 32 | 33 | local clientB = ros.ServiceClient('/rosout/set_logger_level', 'roscpp/SetLoggerLevel') 34 | 35 | print('Service spec:') 36 | print(clientB.spec) 37 | 38 | -- we can either create the request message explicitly 39 | local req_msg = clientB:createRequest() 40 | req_msg:fillFromTable({logger="my_dummy_logger", level="warn"}) 41 | 42 | print('Request:') 43 | print(req_msg) 44 | 45 | print('Calling service: ' .. clientB:getService()) 46 | 47 | -- call the service 48 | response = clientB:call(req_msg) 49 | 50 | print('Response:') 51 | print(response) 52 | 53 | -- or let `call()` internally call fillFromTable() for us...
54 | response = clientB:call{logger="my_dummy_logger", level="warn"} 55 | 56 | 57 | ros.shutdown() 58 | -------------------------------------------------------------------------------- /src/std/string_map.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-std.h" 2 | 3 | STDIMP(StringMap *, StringMap, new)() { 4 | return new StringMap(); 5 | } 6 | 7 | STDIMP(StringMap *, StringMap, clone)(StringMap *self) { 8 | return new StringMap(*self); 9 | } 10 | 11 | STDIMP(void, StringMap, delete)(StringMap *ptr) { 12 | delete ptr; 13 | } 14 | 15 | STDIMP(int, StringMap, size)(StringMap *self) { 16 | return static_cast(self->size()); 17 | } 18 | 19 | STDIMP(void, StringMap, clear)(StringMap *self) { 20 | self->clear(); 21 | } 22 | 23 | STDIMP(const char *, StringMap, getAt)(StringMap *self, const char *key) { 24 | return (*self)[key].c_str(); 25 | } 26 | 27 | STDIMP(void, StringMap, setAt)(StringMap *self, const char *key, const char *value) { 28 | (*self)[key] = value; 29 | } 30 | 31 | STDIMP(bool, StringMap, insert)(StringMap *self, const char *key, const char *value) { 32 | return self->insert(StringMap::value_type(key, value)).second; 33 | } 34 | 35 | STDIMP(bool, StringMap, erase)(StringMap *self, const char *key) { 36 | return self->erase(key) > 0; 37 | } 38 | 39 | STDIMP(bool, StringMap, exists)(StringMap *self, const char *key) { 40 | return self->count(key) > 0; 41 | } 42 | 43 | STDIMP(void, StringMap, keys)(StringMap *self, StringVector *result) { 44 | result->clear(); 45 | for (StringMap::const_iterator i = self->begin(); i != self->end(); ++i) { 46 | result->push_back(i->first); 47 | } 48 | } 49 | 50 | STDIMP(void, StringMap, values)(StringMap *self, StringVector *result) { 51 | result->clear(); 52 | for (StringMap::const_iterator i = self->begin(); i != self->end(); ++i) { 53 | result->push_back(i->second); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- 
/lua/AsyncSpinner.lua: -------------------------------------------------------------------------------- 1 | --- Wrapper for the ros::AsyncSpinner class. 2 | -- This spinner spins asynchronously when you call start(), and stops 3 | -- when either you call stop(), ros::shutdown() is called, or its 4 | -- destructor is called. AsyncSpinner is reference counted internally, 5 | -- so if you copy one it will continue spinning until all copies have 6 | -- destructed (or stop() has been called on one of them). 7 | -- @classmod AsyncSpinner 8 | 9 | local ffi = require 'ffi' 10 | local torch = require 'torch' 11 | local ros = require 'ros.env' 12 | local utils = require 'ros.utils' 13 | 14 | local AsyncSpinner = torch.class('ros.AsyncSpinner', ros) 15 | 16 | function init() 17 | local Duration_method_names = { 18 | "new", 19 | "delete", 20 | "canStart", 21 | "start", 22 | "stop" 23 | } 24 | 25 | return utils.create_method_table("ros_AsyncSpinner_", Duration_method_names) 26 | end 27 | 28 | local f = init() 29 | 30 | --- Constructor 31 | -- @tparam[opt=0] int thread_count The number of threads to use. A value of 0 means to use the number of processor cores. 32 | function AsyncSpinner:__init(thread_count) 33 | self.o = f.new(thread_count or 0) 34 | end 35 | 36 | --- Check if the spinner can be started. 37 | -- A spinner can't be started if another spinner is already running. 38 | -- @treturn bool true if the spinner could be started, false otherwise 39 | function AsyncSpinner:canStart() 40 | return f.canStart(self.o) 41 | end 42 | 43 | --- Start this spinner spinning asynchronously. 44 | function AsyncSpinner:start() 45 | f.start(self.o) 46 | end 47 | 48 | --- Stop this spinner from running.
49 | function AsyncSpinner:stop() 50 | f.stop(self.o) 51 | end 52 | -------------------------------------------------------------------------------- /demo/action_client.lua: -------------------------------------------------------------------------------- 1 | local ros = require 'ros' 2 | require 'ros.actionlib.SimpleActionClient' 3 | local actionlib = ros.actionlib 4 | local SimpleClientGoalState = actionlib.SimpleClientGoalState 5 | 6 | ros.init('test_action_client') 7 | nh = ros.NodeHandle() 8 | ros.console.setLoggerLevel('actionlib', ros.console.Level.Debug) 9 | 10 | 11 | local ac = actionlib.SimpleActionClient('actionlib/Test', 'test_action', nh) 12 | 13 | 14 | function test_sync_api() 15 | local g = ac:createGoal() 16 | g.goal = 123 17 | local state = ac:sendGoalAndWait(g, 5, 5) 18 | ros.INFO('Finished with states: %s (%d)', SimpleClientGoalState[state], state) 19 | local result = ac:getResult() 20 | ros.INFO('Result:\n%s', result) 21 | end 22 | 23 | 24 | function test_async_api() 25 | -- test async api 26 | local done = false 27 | 28 | function Action_done(state, result) 29 | ros.INFO('Action_done') 30 | ros.INFO('Finished with states: %s (%d)', SimpleClientGoalState[state], state) 31 | ros.INFO('Result:\n%s', result) 32 | done = true 33 | end 34 | 35 | function Action_active() 36 | ros.INFO('Action_active') 37 | end 38 | 39 | function Action_feedback(feedback) 40 | ros.INFO('Action_feedback') 41 | end 42 | 43 | local g2 = ac:createGoal() 44 | g2.goal = 456 45 | ac:sendGoal(g2, Action_done, Action_active, Action_feedback) 46 | 47 | while ros.ok() and not done do 48 | ros.spinOnce() 49 | end 50 | end 51 | 52 | 53 | print('waiting for server connection...') 54 | if ac:waitForServer(ros.Duration(5.0)) then 55 | print('connected.') 56 | 57 | test_sync_api() 58 | test_async_api() 59 | 60 | else 61 | print('failed.') 62 | end 63 | 64 | 65 | ac:shutdown() 66 | nh:shutdown() 67 | ros.shutdown() 68 | 
-------------------------------------------------------------------------------- /lua/utils.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local ros = require 'ros.env' 3 | 4 | local utils = {} 5 | 6 | function utils.create_method_table(prefix, names) 7 | local map = {} 8 | for i,n in ipairs(names) do 9 | local full_name = prefix .. n 10 | -- use pcall since not all types support all functions 11 | local ok,v = pcall(function() return ros.lib[full_name] end) 12 | if ok then 13 | map[n] = v 14 | end 15 | end 16 | 17 | -- check whether we have new and delete functions 18 | -- automatically register objects created by new with the gc 19 | local _new, _clone, _delete = map.new, map.clone, map.delete 20 | 21 | if _new and _delete then 22 | map.new = function(...) 23 | local obj = _new(...) 24 | ffi.gc(obj, _delete) 25 | return obj 26 | end 27 | end 28 | 29 | if _clone and _delete then 30 | map.clone = function(...) 31 | local obj = _clone(...) 
32 | ffi.gc(obj, _delete) 33 | return obj 34 | end 35 | end 36 | 37 | return map 38 | end 39 | 40 | -- safe accessor for cdata() 41 | function utils.cdata(x) 42 | return x and x:cdata() or ffi.NULL 43 | end 44 | 45 | function utils.reverse_mapping(t, r) 46 | for k,v in pairs(t) do 47 | r[v] = k 48 | end 49 | return r 50 | end 51 | 52 | function utils.cloneList(l) 53 | local c = {} 54 | for i,x in ipairs(l) do 55 | c[#c+1] = x 56 | end 57 | return c 58 | end 59 | 60 | function utils.indexOf(t, v) 61 | for i,x in ipairs(t) do 62 | if v == x then 63 | return i 64 | end 65 | end 66 | return -1 67 | end 68 | 69 | function utils.getTableKeys(t) 70 | local l = {} 71 | for k,v in pairs(t) do table.insert(l, k) end 72 | return l 73 | end 74 | 75 | return utils 76 | -------------------------------------------------------------------------------- /demo/publish_lena_image.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Publish an uncompressed image via torch-ros. 3 | This sample uses the torch image package as source for the Lena test image. 4 | To show the result simply start rviz and click 'Add' and select 'By topic'. 5 | In the treeview under topic 'lena_image' select 'Image'. A blinking Lena image 6 | will be displayed. The image is inverted at each time step to create a 7 | clearly visible change between consecutive frames. 
8 | ]] 9 | 10 | ros = require 'ros' 11 | image = require 'image' 12 | 13 | ros.init('lena_image_source') 14 | nh = ros.NodeHandle() 15 | 16 | spinner = ros.AsyncSpinner() 17 | spinner:start() 18 | 19 | -- convert into byte rgb and publish 20 | local lena = image.lena() 21 | lena = lena:mul(255):permute(2,3,1):byte() 22 | 23 | publisher = nh:advertise("lena_image", 'sensor_msgs/Image', 10) 24 | 25 | local msg = ros.Message('sensor_msgs/Image') 26 | 27 | --[[ 28 | Message Definition of 'the sensor_msgs/Image' 29 | (for details see: http://docs.ros.org/api/sensor_msgs/html/msg/Image.html) 30 | 31 | std_msgs/Header header 32 | uint32 height 33 | uint32 width 34 | string encoding 35 | uint8 is_bigendian 36 | uint32 step 37 | uint8[] data 38 | ]] 39 | 40 | msg.height = lena:size(1) 41 | msg.width = lena:size(2) 42 | msg.encoding = "rgb8" 43 | msg.is_bigendian = false 44 | msg.step = lena:stride(1) 45 | msg.data = lena:reshape(msg.height * msg.width * 3) 46 | 47 | print('press ctrl+c to exit') 48 | while ros.ok() do 49 | if publisher:getNumSubscribers() > 0 then 50 | lena = -lena -- invert image to get some blinking effect 51 | msg.data = lena:reshape(msg.height * msg.width * 3) 52 | publisher:publish(msg) 53 | end 54 | sys.sleep(0.1) 55 | ros.spinOnce() 56 | end 57 | 58 | ros.shutdown() 59 | -------------------------------------------------------------------------------- /lua/actionlib/GoalStatus.lua: -------------------------------------------------------------------------------- 1 | local ros = require 'ros.env' 2 | 3 | 4 | -- http://docs.ros.org/api/actionlib_msgs/html/msg/GoalStatus.html 5 | local GoalStatus = { 6 | PENDING = 0, -- The goal has yet to be processed by the action server 7 | ACTIVE = 1, -- The goal is currently being processed by the action server 8 | PREEMPTED = 2, -- The goal received a cancel request after it started executing 9 | -- and has since completed its execution (Terminal State) 10 | SUCCEEDED = 3, -- The goal was achieved successfully by 
the action server (Terminal State) 11 | ABORTED = 4, -- The goal was aborted during execution by the action server due 12 | -- to some failure (Terminal State) 13 | REJECTED = 5, -- The goal was rejected by the action server without being processed, 14 | -- because the goal was unattainable or invalid (Terminal State) 15 | PREEMPTING = 6, -- The goal received a cancel request after it started executing 16 | -- and has not yet completed execution 17 | RECALLING = 7, -- The goal received a cancel request before it started executing, 18 | -- but the action server has not yet confirmed that the goal is canceled 19 | RECALLED = 8, -- The goal received a cancel request before it started executing 20 | -- and was successfully cancelled (Terminal State) 21 | LOST = 9 -- An action client can determine that a goal is LOST. This should not be 22 | -- sent over the wire by an action server 23 | } 24 | ros.actionlib.GoalStatus = GoalStatus 25 | 26 | 27 | return GoalStatus 28 | -------------------------------------------------------------------------------- /src/ros/raw_message.h: -------------------------------------------------------------------------------- 1 | #ifndef raw_message_h 2 | #define raw_message_h 3 | 4 | class RawMessage { 5 | public: 6 | RawMessage() 7 | : buffer() 8 | , num_bytes(0) { 9 | } 10 | 11 | RawMessage(size_t length) 12 | : buffer(new uint8_t[length]) 13 | , num_bytes(length) { 14 | } 15 | 16 | void copyFrom(uint8_t *source, size_t length) { 17 | buffer = boost::shared_array(new uint8_t[length]); 18 | memcpy(buffer.get(), source, length); 19 | this->num_bytes = length; 20 | } 21 | 22 | size_t get_length() const { 23 | return num_bytes; 24 | } 25 | 26 | uint8_t *get() const { 27 | return buffer.get(); 28 | } 29 | 30 | ros::serialization::IStream get_IStream() const { 31 | return ros::serialization::IStream(buffer.get(), num_bytes); 32 | } 33 | 34 | ros::serialization::OStream get_OStream() const { 35 | return ros::serialization::OStream(buffer.get(), 
num_bytes); 36 | } 37 | 38 | const boost::shared_array& get_buffer() const { 39 | return buffer; 40 | } 41 | 42 | void set_header(const ros::M_string& header) { 43 | this->header = header; 44 | } 45 | 46 | const ros::M_string& get_header() const { 47 | return header; 48 | } 49 | 50 | private: 51 | ros::M_string header; 52 | boost::shared_array buffer; 53 | size_t num_bytes; 54 | }; 55 | 56 | namespace ros { 57 | namespace serialization { 58 | 59 | template<> 60 | inline SerializedMessage serializeMessage(const RawMessage &message) 61 | { 62 | SerializedMessage sm(message.get_buffer(), message.get_length()); 63 | sm.message_start = sm.buf.get() + sizeof(uint32_t); 64 | return sm; 65 | } 66 | 67 | } // namespace serialization 68 | } // namespace ros 69 | 70 | #endif // raw_message_h 71 | -------------------------------------------------------------------------------- /src/std/string_vector.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-std.h" 2 | 3 | STDIMP(StringVector *, StringVector, new)() { 4 | return new StringVector(); 5 | } 6 | 7 | STDIMP(StringVector *, StringVector, clone)(StringVector *self) { 8 | return new StringVector(*self); 9 | } 10 | 11 | STDIMP(void, StringVector, delete)(StringVector *ptr) { 12 | delete ptr; 13 | } 14 | 15 | STDIMP(int, StringVector, size)(StringVector *self) { 16 | return static_cast(self->size()); 17 | } 18 | 19 | STDIMP(const char*, StringVector, getAt)(StringVector *self, size_t pos) { 20 | StringVector& v = *self; 21 | return v[pos].c_str(); 22 | } 23 | 24 | STDIMP(void, StringVector, setAt)(StringVector *self, size_t pos, const char *value) { 25 | StringVector& v = *self; 26 | v[pos] = value; 27 | } 28 | 29 | STDIMP(void, StringVector, push_back)(StringVector *self, const char *value) { 30 | self->push_back(value); 31 | } 32 | 33 | STDIMP(void, StringVector, pop_back)(StringVector *self) { 34 | self->pop_back(); 35 | } 36 | 37 | STDIMP(void, StringVector, 
clear)(StringVector *self) { 38 | self->clear(); 39 | } 40 | 41 | STDIMP(void, StringVector, insert)(StringVector *self, size_t pos, size_t n, const char *value) { 42 | StringVector& v = *self; 43 | StringVector::iterator i = pos >= v.size() ? v.end() : v.begin() + pos; 44 | v.insert(i, n, value); 45 | } 46 | 47 | STDIMP(void, StringVector, erase)(StringVector *self, size_t begin, size_t end) { 48 | if (begin >= end) 49 | return; 50 | 51 | StringVector& v = *self; 52 | StringVector::iterator b = begin >= v.size() ? v.end() : v.begin() + begin; 53 | StringVector::iterator e = end >= v.size() ? v.end() : v.begin() + end; 54 | v.erase(b, e); 55 | } 56 | 57 | STDIMP(bool, StringVector, empty)(StringVector *self) { 58 | return self->empty(); 59 | } 60 | -------------------------------------------------------------------------------- /src/ros/serialized_message.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-ros.h" 2 | #include "../std/torch-std.h" 3 | #include "message_buffer.h" 4 | 5 | 6 | ROSIMP(ros::SerializedMessage*, SerializedMessage, new)() { 7 | return new ros::SerializedMessage(); 8 | } 9 | 10 | ROSIMP(void, SerializedMessage, delete)(ros::SerializedMessage *self) { 11 | delete self; 12 | } 13 | 14 | ROSIMP(void, SerializedMessage, view)(ros::SerializedMessage *self, THByteTensor *view) { 15 | if (self->buf.get() == NULL || self->num_bytes == 0) { 16 | THByteTensor_resize1d(view, 0); 17 | } else { 18 | // create a special storage that views into the memory of the serialized message object 19 | THByteStorage *storage = THByteStorage_newWithData(self->buf.get(), self->num_bytes); 20 | storage->flag = TH_STORAGE_REFCOUNTED; 21 | ptrdiff_t offset = self->message_start - self->buf.get(); 22 | THByteTensor_setStorage1d(view, storage, offset, self->num_bytes, 1); 23 | } 24 | } 25 | 26 | ROSIMP(int, SerializedMessage, size)(ros::SerializedMessage *self) { 27 | return static_cast(self->num_bytes); 28 | } 29 | 30 |
ROSIMP(uint8_t*, SerializedMessage, data)(ros::SerializedMessage *self) { 31 | return self->buf.get(); 32 | } 33 | 34 | ROSIMP(void, SerializedMessage, resize)(ros::SerializedMessage *self, size_t new_size) { 35 | const size_t old_size = self->num_bytes; 36 | if (old_size == new_size) 37 | return; // nothing to do 38 | 39 | if (new_size > old_size) { 40 | // reallocate & copy existing data 41 | boost::shared_array old_buf(self->buf); 42 | boost::shared_array new_buf(new uint8_t[new_size]); 43 | memcpy(new_buf.get(), old_buf.get(), std::min(new_size, old_size)); 44 | self->buf = new_buf; 45 | self->message_start = new_buf.get() + (self->message_start - old_buf.get()); 46 | } else { 47 | // do not reallocate when shrinking 48 | } 49 | 50 | self->num_bytes = new_size; 51 | } 52 | -------------------------------------------------------------------------------- /src/ros/service_client.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-ros.h" 2 | #include "../std/torch-std.h" 3 | #include "raw_message.h" 4 | 5 | ROSIMP(ros::ServiceClient *, ServiceClient, new)( 6 | const char *service_name, 7 | bool persistent, 8 | StringMap *header_values, 9 | const char *service_md5sum 10 | ) { 11 | static const StringMap empty_header; 12 | return new ros::ServiceClient(service_name, persistent, header_values ?
*header_values : empty_header, service_md5sum); 13 | } 14 | 15 | ROSIMP(ros::ServiceClient *, ServiceClient, clone)(ros::ServiceClient *self) { 16 | return new ros::ServiceClient(*self); 17 | } 18 | 19 | ROSIMP(void, ServiceClient, delete)(ros::ServiceClient *ptr) { 20 | delete ptr; 21 | } 22 | 23 | ROSIMP(bool, ServiceClient, call)(ros::ServiceClient *self, THByteStorage *request_msg, ros::SerializedMessage *response_msg, const char *service_md5sum) { 24 | // fill request message from request_msg byte storage 25 | RawMessage msg; 26 | msg.copyFrom(THByteStorage_data(request_msg), THByteStorage_size(request_msg)); 27 | 28 | ros::SerializedMessage req = ros::serialization::serializeMessage(msg); 29 | bool result = self->call(req, *response_msg, service_md5sum); 30 | 31 | return result; 32 | } 33 | 34 | ROSIMP(bool, ServiceClient, isPersistent)(ros::ServiceClient *self) { 35 | return self->isPersistent(); 36 | } 37 | 38 | ROSIMP(void, ServiceClient, getService)(ros::ServiceClient *self, std::string *output) { 39 | *output = self->getService(); 40 | } 41 | 42 | ROSIMP(bool, ServiceClient, waitForExistence)(ros::ServiceClient *self, ros::Duration *timeout) { 43 | return self->waitForExistence(timeout != NULL ? 
*timeout : ros::Duration(-1)); 44 | } 45 | 46 | ROSIMP(bool, ServiceClient, exists)(ros::ServiceClient *self) { 47 | return self->exists(); 48 | } 49 | 50 | ROSIMP(void, ServiceClient, shutdown)(ros::ServiceClient *self) { 51 | self->shutdown(); 52 | } 53 | 54 | ROSIMP(bool, ServiceClient, isValid)(ros::ServiceClient *self) { 55 | return self->isValid(); 56 | } 57 | -------------------------------------------------------------------------------- /lua/tests/test_serialization.lua: -------------------------------------------------------------------------------- 1 | ros = require 'ros' 2 | 3 | 4 | function base_test() 5 | local sw = ros.StorageWriter() 6 | sw:writeString('Hello, this is a string!') 7 | for i=1,100 do 8 | sw:writeInt16(i) 9 | end 10 | sw:writeFloat32(1.23) 11 | sw:writeFloat64(1.23) 12 | sw:writeTensor(torch.linspace(0,10,10)) 13 | 14 | local rw = ros.StorageReader(sw.storage) 15 | print(rw:readString()) 16 | for i=1,100 do 17 | print(rw:readInt16()) 18 | end 19 | 20 | print(rw:readFloat32()) 21 | print(rw:readFloat64()) 22 | 23 | print(rw:readDoubleTensor()) 24 | end 25 | 26 | 27 | local img_spec = ros.MsgSpec('sensor_msgs/Image') 28 | 29 | local msg = ros.Message(img_spec) 30 | msg.header.seq = 918273 31 | msg.width = 5 32 | msg.height = 6 33 | msg.step = 5 34 | msg.data = torch.range(2,60,2):byte() -- creates torch.ByteTensor with 30 elements 35 | 36 | v = msg:serialize() 37 | v:shrinkToFit() 38 | 39 | local msg2 = ros.Message(img_spec, true) 40 | msg2:deserialize(v.storage) 41 | 42 | print(msg2.spec) 43 | print(msg2) 44 | 45 | 46 | function testFixedSizeArray() 47 | local m = ros.Message('geometry_msgs/PoseWithCovariance') 48 | m.covariance[5] = 123 49 | local v = m:serialize() 50 | v:shrinkToFit() 51 | n = ros.Message('geometry_msgs/PoseWithCovariance') 52 | n:deserialize(v.storage) 53 | assert(n.covariance[5] == 123) 54 | 55 | local test_msg_definiton = [[Header header 56 | uint32[5] id 57 | string[4] names 58 | time[2] times 59 | float64 
confidence 60 | ]] 61 | local s = ros.MsgSpec('test', test_msg_definiton) 62 | 63 | local now = ros.Time.now() 64 | m = ros.Message(s) 65 | m.id[1] = 1 66 | m.id[4] = 4 67 | m.names[2] = 'hallo' 68 | m.times[2] = now 69 | local v = m:serialize() 70 | v:shrinkToFit() 71 | x = ros.Message(s) 72 | x:deserialize(v.storage) 73 | assert(1 == x.id[1]) 74 | assert(4 == x.id[4]) 75 | assert('hallo' == x.names[2]) 76 | assert(now == x.times[2]) 77 | assert(x.times[1] == ros.Time(0)) 78 | end 79 | 80 | testFixedSizeArray() 81 | -------------------------------------------------------------------------------- /src/ros/duration.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-ros.h" 2 | 3 | ROSIMP(ros::Duration*, Duration, new)() { 4 | return new ros::Duration(); 5 | } 6 | 7 | ROSIMP(void, Duration, delete)(ros::Duration *self) { 8 | delete self; 9 | } 10 | 11 | ROSIMP(ros::Duration*, Duration, clone)(ros::Duration *self) { 12 | return new ros::Duration(*self); 13 | } 14 | 15 | ROSIMP(void, Duration, set)(ros::Duration *self, int sec, int nsec) { 16 | self->sec = sec; 17 | self->nsec = nsec; 18 | } 19 | 20 | ROSIMP(void, Duration, assign)(ros::Duration *self, ros::Duration *other) { 21 | *self = *other; 22 | } 23 | 24 | ROSIMP(int, Duration, get_sec)(ros::Duration *self) { 25 | return self->sec; 26 | } 27 | 28 | ROSIMP(void, Duration, set_sec)(ros::Duration *self, int sec) { 29 | self->sec = sec; 30 | } 31 | 32 | ROSIMP(int, Duration, get_nsec)(ros::Duration *self) { 33 | return self->nsec; 34 | } 35 | 36 | ROSIMP(void, Duration, set_nsec)(ros::Duration *self, int nsec) { 37 | self->nsec = nsec; 38 | } 39 | 40 | ROSIMP(void, Duration, add)(ros::Duration *self, ros::Duration *other, ros::Duration *result) { 41 | *result = *self + *other; 42 | } 43 | 44 | ROSIMP(void, Duration, sub)(ros::Duration *self, ros::Duration *other, ros::Duration *result) { 45 | *result = *self - *other; 46 | } 47 | 48 | ROSIMP(void, Duration, 
mul)(ros::Duration *self, double scale, ros::Duration *result) { 49 | *result = *self * scale; 50 | } 51 | 52 | ROSIMP(bool, Duration, eq)(ros::Duration *self, ros::Duration *other) { 53 | return *self == *other; 54 | } 55 | 56 | ROSIMP(bool, Duration, lt)(ros::Duration *self, ros::Duration *other) { 57 | return *self < *other; 58 | } 59 | 60 | ROSIMP(double, Duration, toSec)(ros::Duration *self) { 61 | return self->toSec(); 62 | } 63 | 64 | ROSIMP(void, Duration, fromSec)(ros::Duration *self, double t) { 65 | self->fromSec(t); 66 | } 67 | 68 | ROSIMP(bool, Duration, isZero)(ros::Duration *self) { 69 | return self->isZero(); 70 | } 71 | 72 | ROSIMP(void, Duration, sleep)(ros::Duration *self) { 73 | self->sleep(); 74 | } 75 | -------------------------------------------------------------------------------- /lua/SerializedMessage.lua: -------------------------------------------------------------------------------- 1 | --- SerializedMessage class 2 | -- @classmod SerializedMessage 3 | local ffi = require 'ffi' 4 | local torch = require 'torch' 5 | local ros = require 'ros.env' 6 | local utils = require 'ros.utils' 7 | local std = ros.std 8 | 9 | local SerializedMessage = torch.class('ros.SerializedMessage', ros) 10 | local SerializedMessage_ptr_ct = ffi.typeof('ros_SerializedMessage *') 11 | 12 | function init() 13 | local SerializedMessage_method_names = { 14 | "new", 15 | "delete", 16 | "view", 17 | "size", 18 | "data", 19 | "resize" 20 | } 21 | 22 | return utils.create_method_table("ros_SerializedMessage_", SerializedMessage_method_names) 23 | end 24 | 25 | local f = init() 26 | 27 | --- SerializedMessage constructor. 
28 | function SerializedMessage:__init() 29 | self.o = f.new() 30 | end 31 | 32 | --- Internal function, do not use in normal client code. 33 | function SerializedMessage.fromPtr(ptr) 34 | if not ffi.istype(SerializedMessage_ptr_ct, ptr) then 35 | error('ros::SerializedMessage* expected.') 36 | end 37 | local c = torch.factory('ros.SerializedMessage')() 38 | rawset(c, 'o', ptr) 39 | return c 40 | end 41 | 42 | --- Access underlying data structure 43 | -- @return The data structure 44 | function SerializedMessage:cdata() 45 | return self.o 46 | end 47 | 48 | --- Get the number of bytes in the buffer. 49 | -- @treturn int Number of bytes 50 | function SerializedMessage:size() 51 | return f.size(self.o) 52 | end 53 | --- Creates a tensor object directly looking into the internal memory buffer of the SerializedMessage object without copying. 54 | -- @tparam[opt] torch.ByteTensor output_tensor Tensor to point at the buffer; a new torch.ByteTensor is created if omitted 55 | -- @return Tensor with storage pointing to the internal buffer. 56 | function SerializedMessage:view(output_tensor) 57 | output_tensor = output_tensor or torch.ByteTensor() 58 | f.view(self.o, output_tensor:cdata()) 59 | return output_tensor 60 | end 61 | --- Get the data of the underlying native buffer (whatever ros_SerializedMessage_data returns; presumably a raw byte pointer). 62 | function SerializedMessage:data() 63 | return f.data(self.o) 64 | end 65 | --- Resize the underlying native buffer to `new_size` (presumably in bytes, the unit size() reports). 66 | function SerializedMessage:resize(new_size) 67 | return f.resize(self.o, new_size) 68 | end 69 | -------------------------------------------------------------------------------- /src/ros/master.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-ros.h" 2 | #include 3 | 4 | ROSIMP(bool, Master, execute)( 5 | const char *method, 6 | xamla::Variable *request, 7 | xamla::Variable *response, 8 | xamla::Variable *payload, 9 | bool wait_for_master) { 10 | 11 | XmlRpc::XmlRpcValue request_; 12 | VariableToXmlRpcValue(*request, request_); // copy request 13 | 14 | XmlRpc::XmlRpcValue response_, payload_; 15 | bool result = ros::master::execute(method, request_, response_, payload_, wait_for_master); 16 | 17 | XmlRpcValueToVariable(response_,
*response); // copy response 18 | XmlRpcValueToVariable(payload_, *payload); 19 | return result; 20 | } 21 | 22 | ROSIMP(const char*, Master, getHost)() { // NOTE(review): assumes getHost() returns a reference to a long-lived string, so c_str() stays valid after return 23 | return ros::master::getHost().c_str(); 24 | } 25 | 26 | ROSIMP(int, Master, getPort)() { 27 | return static_cast(ros::master::getPort()); 28 | } 29 | 30 | ROSIMP(const char*, Master, getURI)() { // NOTE(review): same lifetime assumption as getHost 31 | return ros::master::getURI().c_str(); 32 | } 33 | 34 | ROSIMP(bool, Master, check)() { 35 | return ros::master::check(); 36 | } 37 | 38 | ROSIMP(bool, Master, getTopics)(xamla::VariableTable_ptr *output) { 39 | if (!*output) { 40 | output->reset(new xamla::VariableTable()); 41 | } 42 | 43 | xamla::VariableTable& output_table = **output; 44 | 45 | // call to master API 46 | ros::master::V_TopicInfo topics; 47 | bool result = ros::master::getTopics(topics); 48 | 49 | // store result in table, keyed by topic name 50 | ros::master::V_TopicInfo::const_iterator i = topics.begin(); 51 | for (; i != topics.end(); ++i) { 52 | const ros::master::TopicInfo& topic = *i; 53 | xamla::VariableTable_ptr t(new xamla::VariableTable()); 54 | (*t)["topic"] = topic.name; 55 | (*t)["datatype"] = topic.datatype; 56 | output_table[topic.name] = t; 57 | } 58 | return result; 59 | } 60 | 61 | ROSIMP(bool, Master, getNodes)(std::vector *output) { 62 | return ros::master::getNodes(*output); 63 | } 64 | 65 | ROSIMP(void, Master, setRetryTimeout)(int sec, int nsec) { 66 | ros::master::setRetryTimeout(ros::WallDuration(sec, nsec)); 67 | } 68 | -------------------------------------------------------------------------------- /demo/publish_multi_array.lua: -------------------------------------------------------------------------------- 1 | local ros = require 'ros' 2 | 3 | local function printf(...)
return print(string.format(...)) end 4 | 5 | ros.init('multi_array_demo') 6 | 7 | local spinner = ros.AsyncSpinner() 8 | spinner:start() 9 | 10 | local nodehandle = ros.NodeHandle() 11 | 12 | 13 | local publisher = nodehandle:advertise("float_tensor_source", 'std_msgs/Float64MultiArray', 10) 14 | local subscriber = nodehandle:subscribe("float_tensor_source", 'std_msgs/Float64MultiArray', 10) 15 | 16 | 17 | local specFloat64MultiArray = ros.MsgSpec('std_msgs/Float64MultiArray') 18 | -- Pack a tensor into a Float64MultiArray; per std_msgs/MultiArrayLayout, dim[i].stride is the product of the sizes of dims i..n of the dense data. 19 | local function tensorToMsg(tensor) 20 | local msg, stride = ros.Message(specFloat64MultiArray), tensor:nElement() 21 | msg.data = tensor:reshape(tensor:nElement()) 22 | for i=1,tensor:dim() do 23 | local dim_desc = ros.Message('std_msgs/MultiArrayDimension') 24 | dim_desc.size = tensor:size(i) 25 | dim_desc.stride, stride = stride, stride / tensor:size(i) -- store the dense stride of dim i, then advance to dim i+1 26 | table.insert(msg.layout.dim, dim_desc) 27 | end 28 | return msg 29 | end 30 | 31 | 32 | local function msgToTensor(msg) 33 | -- TODO: add support for non-contiguous layouts (stride handling); only the dim sizes are used 34 | local layout = msg.layout 35 | local dim_sizes = torch.LongStorage(#layout.dim) 36 | for i=1,#layout.dim do 37 | dim_sizes[i] = layout.dim[i].size 38 | end 39 | return msg.data:reshape(dim_sizes) 40 | end 41 | 42 | 43 | local recv_seq = 1 44 | subscriber:registerCallback(function(msg, header) 45 | --print('Header:') 46 | --print(header) 47 | --print('Message:') 48 | --print(msg) 49 | local t = msgToTensor(msg) 50 | --print("Tensor size:") 51 | --print(t:size()) 52 | printf("[recv_seq: %d] Sum of all elements (received): %f", recv_seq, t:sum()) 53 | recv_seq = recv_seq + 1 54 | end) 55 | 56 | local send_seq = 1 57 | while ros.ok() do 58 | local source_data = torch.rand(4, 27, 13) 59 | local msg = tensorToMsg(source_data) 60 | --print("Sending message:") 61 | --print(msg) 62 | printf("[send_seq: %d] Sum of all elements (sent): %f", send_seq, source_data:sum()) 63 | send_seq = send_seq + 1 64 | publisher:publish(msg) 65 | ros.spinOnce() 66 | sys.sleep(0.1) 67 | end 68 | 69 |
ros.shutdown() 70 | -------------------------------------------------------------------------------- /src/std/variable_vector.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-std.h" 2 | #include "variable.h" 3 | 4 | using namespace xamla; 5 | 6 | #define VariableVector_Ptr xamla::VariableVector_ptr 7 | 8 | STDIMP(VariableVector_Ptr *, VariableVector, new)() { 9 | return new VariableVector_Ptr(new VariableVector()); 10 | } 11 | 12 | STDIMP(void, VariableVector, delete)(VariableVector_Ptr *self) { 13 | delete self; 14 | } 15 | 16 | STDIMP(VariableVector_Ptr *, VariableVector, clone)(VariableVector_Ptr *self) { 17 | return new VariableVector_Ptr(new VariableVector(**self)); 18 | } 19 | 20 | STDIMP(int, VariableVector, size)(VariableVector_Ptr *self) { 21 | return static_cast((*self)->size()); 22 | } 23 | 24 | STDIMP(void, VariableVector, getAt)(VariableVector_Ptr *self, size_t pos, xamla::Variable *result) { 25 | VariableVector& v = **self; 26 | *result = v[pos]; 27 | } 28 | 29 | STDIMP(void, VariableVector, setAt)(VariableVector_Ptr *self, size_t pos, xamla::Variable *value) { 30 | VariableVector& v = **self; 31 | v[pos] = *value; 32 | } 33 | 34 | STDIMP(void, VariableVector, push_back)(VariableVector_Ptr *self, xamla::Variable *value) { 35 | (*self)->push_back(*value); 36 | } 37 | 38 | STDIMP(void, VariableVector, pop_back)(VariableVector_Ptr *self) { 39 | (*self)->pop_back(); 40 | } 41 | 42 | STDIMP(void, VariableVector, clear)(VariableVector_Ptr *self) { 43 | (*self)->clear(); 44 | } 45 | 46 | STDIMP(void, VariableVector, insert)(VariableVector_Ptr *self, size_t pos, size_t n, xamla::Variable *value) { 47 | VariableVector& v = **self; 48 | VariableVector::iterator i = pos >= v.size() ? 
v.end() : v.begin() + pos; 49 | v.insert(i, n, *value); 50 | } 51 | 52 | STDIMP(void, VariableVector, erase)(VariableVector_Ptr *self, size_t begin, size_t end) { 53 | if (begin >= end) 54 | return; 55 | 56 | VariableVector& v = **self; 57 | VariableVector::iterator b = begin >= v.size() ? v.end() : v.begin() + begin; 58 | VariableVector::iterator e = end >= v.size() ? v.end() : v.begin() + end; 59 | v.erase(b, e); 60 | } 61 | 62 | STDIMP(bool, VariableVector, empty)(VariableVector_Ptr *self) { 63 | return (*self)->empty(); 64 | } 65 | -------------------------------------------------------------------------------- /src/std/variable_table.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-std.h" 2 | #include "variable.h" 3 | 4 | using namespace xamla; 5 | 6 | #define VariableVector_Ptr xamla::VariableVector_ptr 7 | #define VariableTable_Ptr xamla::VariableTable_ptr 8 | 9 | STDIMP(VariableTable_Ptr *, VariableTable, new)() { 10 | return new VariableTable_Ptr(new VariableTable()); 11 | } 12 | 13 | STDIMP(void, VariableTable, delete)(VariableTable_Ptr *handle) { 14 | delete handle; 15 | } 16 | 17 | STDIMP(VariableTable_Ptr *, VariableTable, clone)(VariableTable_Ptr *self) { 18 | return new VariableTable_Ptr(new VariableTable(**self)); 19 | } 20 | 21 | STDIMP(int, VariableTable, size)(VariableTable_Ptr *self) { 22 | return (int)(*self)->size(); 23 | } 24 | 25 | STDIMP(void, VariableTable, clear)(VariableTable_Ptr *self) { 26 | (*self)->clear(); 27 | } 28 | 29 | STDIMP(bool, VariableTable, getField)(VariableTable_Ptr *self, const char *key, Variable *result) { 30 | VariableTable& t = **self; 31 | VariableTable::iterator i = t.find(key); 32 | if (i == t.end()) 33 | return false; 34 | 35 | *result = i->second; 36 | return true; 37 | } 38 | 39 | STDIMP(void, VariableTable, setField)(VariableTable_Ptr *self, const char *key, Variable *value) { 40 | VariableTable& t = **self; 41 | t[key] = *value; 42 | } 43 | 44 | 
STDIMP(bool, VariableTable, erase)(VariableTable_Ptr *self, const char *key) { 45 | return (*self)->erase(std::string(key)) == 1; 46 | } 47 | 48 | STDIMP(bool, VariableTable, exists)(VariableTable_Ptr *self, const char *key) { 49 | return (*self)->count(key) > 0; 50 | } 51 | 52 | STDIMP(void, VariableTable, keys)(VariableTable_Ptr *self, StringVector *result) { 53 | VariableTable& t = **self; 54 | result->clear(); 55 | for (VariableTable::const_iterator i = t.begin(); i != t.end(); ++i) { 56 | result->push_back(i->first); 57 | } 58 | } 59 | 60 | STDIMP(void, VariableTable, values)(VariableTable_Ptr *self, VariableVector_Ptr *result) { 61 | if (!*result) { 62 | result->reset(new VariableVector()); // dereferencing an empty handle below would crash; allocate one instead 63 | } 64 | VariableTable& t = **self; 65 | VariableVector& r = **result; 66 | r.clear(); 67 | for (VariableTable::const_iterator i = t.begin(); i != t.end(); ++i) { 68 | r.push_back(i->second); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /lua/Rate.lua: -------------------------------------------------------------------------------- 1 | --- Helper class to run a loop at a fixed rate (wraps the native ros_Rate_* functions) 2 | -- @classmod Rate 3 | 4 | local ffi = require 'ffi' 5 | local torch = require 'torch' 6 | local ros = require 'ros.env' 7 | local utils = require 'ros.utils' 8 | 9 | local Rate = torch.class('ros.Rate', ros) 10 | 11 | local function init() 12 | local Rate_method_names = { 13 | 'new', 14 | 'clone', 15 | 'delete', 16 | 'reset', 17 | 'sleep', 18 | 'expectedCycleTime', 19 | 'cycleTime' 20 | } 21 | 22 | return utils.create_method_table('ros_Rate_', Rate_method_names) 23 | end 24 | 25 | local f = init() 26 | 27 | --- Construct a rate object 28 | -- @tparam number _1 The desired rate (frequency) to run at in Hz 29 | function Rate:__init(_1) 30 | self.o = f.new(_1) 31 | end 32 | 33 | --- Get the underlying data structure 34 | -- @return The native (FFI cdata) handle wrapped by this object 35 | function Rate:cdata() 36 | return self.o 37 | end 38 | 39 | --- Creates a deep copy of the object 40 | -- @return A copy of the object 41 | function Rate:clone() 42 | local c =
torch.factory('ros.Rate')() 43 | rawset(c, 'o', f.clone(self.o)) 44 | return c 45 | end 46 | 47 | --- Sets the start time for the rate to now. 48 | function Rate:reset() 49 | f.reset(self.o) 50 | end 51 | 52 | --- Sleeps for any leftover time in a cycle. Calculated from the last time sleep, reset, or the constructor was called. 53 | function Rate:sleep() 54 | f.sleep(self.o) 55 | end 56 | --- Get the expected cycle time -- one over the frequency passed in to the constructor. 57 | -- @tparam[opt] ros.Duration output Duration to store the result in; a new one is created if omitted. It is also returned. 58 | function Rate:expectedCycleTime(output) 59 | output = output or ros.Duration() 60 | f.expectedCycleTime(self.o, output:cdata()) 61 | return output 62 | end 63 | --- Get the actual run time of a cycle from start to sleep. 64 | -- @tparam[opt] ros.Duration output Duration to store the result in; a new one is created if omitted. It is also returned. 65 | function Rate:cycleTime(output) 66 | output = output or ros.Duration() 67 | f.cycleTime(self.o, output:cdata()) 68 | return output 69 | end 70 | --- Human-readable summary of the actual and expected cycle time in seconds. 71 | function Rate:__tostring() 72 | return string.format( 73 | 'cycleTime: %fsec, expectedCycleTime: %fsec', 74 | self:cycleTime():toSec(), 75 | self:expectedCycleTime():toSec() 76 | ) 77 | end 78 | -------------------------------------------------------------------------------- /demo/service_throughput_test.lua: -------------------------------------------------------------------------------- 1 | local torch = require 'torch' 2 | local ros = require 'ros' 3 | 4 | 5 | local capture_spec = ros.SrvSpec('ximea_msgs/Capture', [[ 6 | string[] serials # serial numbers of cameras to use (empty means all cameras) 7 | --- 8 | string[] serials # serial numbers of cameras 9 | sensor_msgs/Image[] images # image data 10 | ]]) 11 | 12 | 13 | local function server() 14 | ros.init('service_throughput_test_server') 15 | 16 | local nh = ros.NodeHandle() 17 | 18 | local imageMsg = ros.Message('sensor_msgs/Image') 19 | 20 | local function captureHandler(request, response, header) 21 | ros.INFO('incoming request') 22 | imageMsg.data = torch.ByteTensor(4096 * 4096) 23 | response.images[1] = imageMsg 24 | return true 25 | end
26 | 27 | nh:advertiseService('/image_source', capture_spec, captureHandler) 28 | ros.INFO('service registered') 29 | 30 | while ros.ok() do 31 | ros.spinOnce() 32 | end 33 | 34 | ros.shutdown() 35 | end 36 | 37 | 38 | local function client() 39 | ros.init('service_throughput_test_client') 40 | local nh = ros.NodeHandle() 41 | 42 | local svc_client = nh:serviceClient('/image_source', capture_spec, true) 43 | print('waiting for service to become available...') 44 | svc_client:waitForExistence() 45 | svc_client = nh:serviceClient('/image_source', capture_spec, true) 46 | 47 | local start_time = torch.tic() 48 | local data_received = 0 49 | for i=1,100 do 50 | ros.INFO('Calling: %d', i) 51 | local response = svc_client:call() 52 | data_received = data_received + response.images[1].data:storage():size() 53 | ros.INFO('Response received') 54 | end 55 | local elapsed = torch.toc(start_time) 56 | ros.INFO('Test took: %f', elapsed) 57 | ros.INFO('%f MB/s', data_received / elapsed / (1024 * 1024)) 58 | 59 | ros.shutdown() 60 | end 61 | 62 | 63 | local function main() 64 | local mode = arg[1] 65 | 66 | if mode == 'client' then 67 | client() 68 | elseif mode == 'server' then 69 | server() 70 | else 71 | print('Please specify \'client\' or \'server\' as command line argument.') 72 | end 73 | 74 | end 75 | 76 | 77 | main() 78 | -------------------------------------------------------------------------------- /lua/MessageBuffer.lua: -------------------------------------------------------------------------------- 1 | --- Message buffer class 2 | -- @classmod MessageBuffer 3 | local ffi = require 'ffi' 4 | local torch = require 'torch' 5 | local ros = require 'ros.env' 6 | local utils = require 'ros.utils' 7 | local std = ros.std 8 | 9 | local MessageBuffer = torch.class('ros.MessageBuffer', ros) 10 | 11 | local function init() 12 | local MessageBuffer_method_names = { 13 | "new", 14 | "delete", 15 | "count", 16 | "clear", 17 | "read" 18 | } 19 | 20 | return
utils.create_method_table("ros_MessageBuffer_", MessageBuffer_method_names) 21 | end 22 | 23 | local f = init() 24 | 25 | --- Message buffer constructor. 26 | -- @tparam[opt] int max_backlog Maximum buffer size, if not provided, buffer size is unlimited 27 | -- @treturn MessageBuffer The constructed object 28 | function MessageBuffer:__init(max_backlog) 29 | self.o = f.new(max_backlog or -1) 30 | end 31 | 32 | --- Access underlying data structure 33 | -- @return The data structure 34 | function MessageBuffer:cdata() 35 | return self.o 36 | end 37 | 38 | --- Get the number of messages in the buffer. 39 | -- @treturn int Number of messages 40 | function MessageBuffer:getCount() 41 | return f.count(self.o) 42 | end 43 | 44 | --- Clear the buffer and discard all messages 45 | function MessageBuffer:clear() 46 | f.clear(self.o) 47 | end 48 | 49 | --- Get the next message from buffer 50 | -- @tparam[opt=100] int timeout_milliseconds Timeout to wait for next message if the buffer is empty 51 | -- @tparam[opt] torch.ByteStorage result_msg If present, this object is used to store the message data, otherwise the required memory is allocated by this function 52 | -- @tparam[opt] std.StringMap result_header If present, this object is used to store the message header, otherwise the required memory is allocated by this function 53 | -- @treturn torch.ByteStorage result_msg The message data, or nil if no message could be read (presumably the timeout expired) 54 | -- @treturn std.StringMap The message header data, or nil if no message could be read 55 | function MessageBuffer:read(timeout_milliseconds, result_msg, result_header) 56 | result_msg = result_msg or torch.ByteStorage() 57 | result_header = result_header or std.StringMap() 58 | if not f.read(self.o, timeout_milliseconds or 100, result_msg:cdata(), result_header:cdata()) then 59 | return nil, nil 60 | end 61 | return result_msg, result_header 62 | end 63 | -------------------------------------------------------------------------------- /src/ros/point_cloud2.cpp:
-------------------------------------------------------------------------------- 1 | #include "torch-ros.h" 2 | #include 3 | #include 4 | 5 | // torch-pcl interop methods to read and write PointCloud2 messages. 6 | 7 | ROSIMP(int32_t, pcl, readPointCloud2)(THByteStorage *serialized_message, int32_t offset, pcl::PCLPointCloud2 *cloud) { 8 | // deserialize to sensor_msgs::PointCloud2 message 9 | long buffer_length = THByteStorage_size(serialized_message); 10 | if (offset < 0 || offset > buffer_length) 11 | throw RosWrapperException("Offset out of range"); 12 | 13 | uint8_t *buffer = THByteStorage_data(serialized_message); 14 | 15 | ros::serialization::IStream stream(buffer + offset, static_cast(buffer_length - offset)); 16 | sensor_msgs::PointCloud2 cloud_msg; 17 | ros::serialization::Serializer::read(stream, cloud_msg); 18 | 19 | // convert to pcl::PointCloud2 20 | pcl_conversions::toPCL(cloud_msg, *cloud); 21 | 22 | return static_cast(stream.getData() - buffer); // return new offset 23 | } 24 | 25 | ROSIMP(int32_t, pcl, writePointCloud2)(THByteStorage *serialized_message, int32_t offset, pcl::PCLPointCloud2 *cloud) { 26 | // reject negative offsets as readPointCloud2 does: buffer + offset would otherwise point before the storage 27 | if (offset < 0) 28 | throw RosWrapperException("Offset out of range"); 29 | 30 | // convert to sensor_msgs::PointCloud2 (a NULL cloud yields a default-constructed message) 31 | sensor_msgs::PointCloud2 cloud_msg; 32 | if (cloud != NULL) { 33 | pcl_conversions::fromPCL(*cloud, cloud_msg); 34 | } 35 | 36 | // stamp with ros::Time::now() if the message has no timestamp and ROS time is valid 37 | if (cloud_msg.header.stamp.isZero() && ros::Time::isValid()) { 38 | cloud_msg.header.stamp = ros::Time::now(); 39 | } 40 | 41 | // determine serialization length & resize output buffer 42 | uint32_t length = ros::serialization::serializationLength(cloud_msg); 43 | 44 | // grow the storage if it cannot hold the message at the requested offset 45 | if (THByteStorage_size(serialized_message) < offset + length) { 46 | THByteStorage_resize(serialized_message, offset + length); 47 | } 48 | 49 | uint8_t *buffer = THByteStorage_data(serialized_message); 50 | 51 | // write message 52 | ros::serialization::OStream stream(buffer + offset, THByteStorage_size(serialized_message) - offset); 53 |
ros::serialization::Serializer::write(stream, cloud_msg); 54 | 55 | return static_cast(stream.getData() - buffer); // return new offset 56 | } 57 | -------------------------------------------------------------------------------- /src/ros/callback_queue.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-ros.h" 2 | #include 3 | 4 | namespace xamla { 5 | 6 | class WaitableCallbackQueue : public ros::CallbackQueue { 7 | public: 8 | WaitableCallbackQueue(bool enabled = true) 9 | : ros::CallbackQueue(enabled) {} 10 | 11 | // only called from the single lua thread (TLS and pending calls not taken into account) 12 | bool waitCallAvailable(ros::WallDuration timeout) { 13 | boost::mutex::scoped_lock lock(mutex_); 14 | if (!enabled_) 15 | return false; 16 | 17 | if (callbacks_.empty() && !timeout.isZero()) 18 | condition_.wait_for(lock, boost::chrono::nanoseconds(timeout.toNSec())); // single wait: an early wakeup simply reports the current queue state 19 | 20 | return !callbacks_.empty() && enabled_; 21 | } 22 | }; 23 | 24 | } // namespace xamla 25 | 26 | 27 | ROSIMP(xamla::WaitableCallbackQueue *, CallbackQueue, new)(bool enabled) { 28 | return new xamla::WaitableCallbackQueue(enabled); 29 | } 30 | 31 | ROSIMP(void, CallbackQueue, delete)(xamla::WaitableCallbackQueue *self) { 32 | delete self; 33 | } 34 | 35 | ROSIMP(int, CallbackQueue, callOne)(xamla::WaitableCallbackQueue *self, ros::Duration *timeout) { 36 | return static_cast(self->callOne(timeout == NULL ? ros::WallDuration() : ros::WallDuration(timeout->sec, timeout->nsec))); 37 | } 38 | 39 | ROSIMP(void, CallbackQueue, callAvailable)(xamla::WaitableCallbackQueue *self, ros::Duration *timeout) { 40 | self->callAvailable(timeout == NULL ?
ros::WallDuration() : ros::WallDuration(timeout->sec, timeout->nsec)); 41 | } 42 | 43 | ROSIMP(bool, CallbackQueue, waitCallAvailable)(xamla::WaitableCallbackQueue *self, ros::Duration *timeout) { // a NULL timeout becomes a zero WallDuration, i.e. check without waiting 44 | return self->waitCallAvailable(timeout == NULL ? ros::WallDuration() : ros::WallDuration(timeout->sec, timeout->nsec)); 45 | } 46 | 47 | ROSIMP(bool, CallbackQueue, isEmpty)(xamla::WaitableCallbackQueue *self) { 48 | return self->isEmpty(); 49 | } 50 | 51 | ROSIMP(void, CallbackQueue, clear)(xamla::WaitableCallbackQueue *self) { 52 | self->clear(); 53 | } 54 | 55 | ROSIMP(void, CallbackQueue, enable)(xamla::WaitableCallbackQueue *self) { 56 | self->enable(); 57 | } 58 | 59 | ROSIMP(void, CallbackQueue, disable)(xamla::WaitableCallbackQueue *self) { 60 | self->disable(); 61 | } 62 | 63 | ROSIMP(bool, CallbackQueue, isEnabled)(xamla::WaitableCallbackQueue *self) { 64 | return self->isEnabled(); 65 | } 66 | -------------------------------------------------------------------------------- /src/tf/torch-tf.h: -------------------------------------------------------------------------------- 1 | #ifndef torch_tf_h 2 | #define torch_tf_h 3 | 4 | extern "C" { 5 | #include 6 | } 7 | 8 | #include 9 | #include 10 | 11 | #define TFIMP(return_type, class_name, name) extern "C" return_type TH_CONCAT_4(tf_, class_name, _, name) 12 | 13 | typedef boost::shared_ptr > StringsPtr; 14 | 15 | inline void viewMatrix3x3(tf::Matrix3x3& m, THDoubleTensor *t) { // non-owning 3x3 view into m (FREEMEM is cleared): m must outlive the tensor 16 | THDoubleStorage* storage = THDoubleStorage_newWithData(m[0].m_floats, sizeof(m) / sizeof(double)); 17 | THDoubleStorage_clearFlag(storage, TH_STORAGE_FREEMEM | TH_STORAGE_RESIZABLE); 18 | THDoubleTensor_setStorage2d(t, storage, 0, 3, sizeof(m[0]) / sizeof(double), 3, 1); 19 | THDoubleStorage_free(storage); // tensor took ownership 20 | } 21 | 22 | inline void viewVector3(tf::Vector3& v, THDoubleTensor *t) { // non-owning 3-element view into v: v must outlive the tensor 23 | THDoubleStorage* storage = THDoubleStorage_newWithData(v.m_floats, sizeof(v.m_floats) / sizeof(double)); 24 |
THDoubleStorage_clearFlag(storage, TH_STORAGE_FREEMEM | TH_STORAGE_RESIZABLE); 25 | THDoubleTensor_setStorage1d(t, storage, 0, 3, 1); 26 | THDoubleStorage_free(storage); // tensor took ownership 27 | } 28 | 29 | inline void viewQuaternion(tf::Quaternion& q, THDoubleTensor *t) { // non-owning 4-element view into q: q must outlive the tensor 30 | THDoubleStorage* storage = THDoubleStorage_newWithData(static_cast(q), sizeof(q) / sizeof(double)); 31 | THDoubleStorage_clearFlag(storage, TH_STORAGE_FREEMEM | TH_STORAGE_RESIZABLE); 32 | THDoubleTensor_setStorage1d(t, storage, 0, 4, 1); 33 | THDoubleStorage_free(storage); // tensor took ownership 34 | } 35 | 36 | inline void copyVector3ToTensor(const tf::Vector3 &v, THDoubleTensor *tensor) { 37 | THDoubleTensor_resize1d(tensor, 3); 38 | THDoubleTensor* output_ = THDoubleTensor_newContiguous(tensor); 39 | double *data = THDoubleTensor_data(output_); 40 | data[0] = v.getX(); 41 | data[1] = v.getY(); 42 | data[2] = v.getZ(); 43 | THDoubleTensor_freeCopyTo(output_, tensor); 44 | } 45 | 46 | inline void copyTensorToVector3(THDoubleTensor *tensor, tf::Vector3 &v) { 47 | if (!tensor || THDoubleTensor_nElement(tensor) < 3) 48 | throw std::runtime_error("A Tensor with at least 3 elements was expected."); 49 | 50 | THDoubleTensor *tensor_ = THDoubleTensor_newContiguous(tensor); 51 | const double *data = THDoubleTensor_data(tensor_); 52 | v.setX(data[0]); 53 | v.setY(data[1]); 54 | v.setZ(data[2]); 55 | v.setW(0); 56 | THDoubleTensor_free(tensor_); 57 | } 58 | 59 | #endif // torch_tf_h 60 | -------------------------------------------------------------------------------- /src/std/variable.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-std.h" 2 | #include "variable.h" 3 | 4 | using namespace xamla; 5 | 6 | STDIMP(Variable *, Variable, new)() { 7 | return new Variable(); 8 | } 9 | 10 | STDIMP(Variable *, Variable, clone)(Variable *self) { 11 | return new Variable(*self); 12 | } 13 | 14 | STDIMP(void, Variable, delete)(Variable *ptr) { 15 |
delete ptr; 16 | } 17 | 18 | STDIMP(int, Variable, get_type)(Variable *self) { 19 | return self->get_type(); 20 | } 21 | 22 | STDIMP(void, Variable, clear)(Variable *self) { 23 | self->clear(); 24 | } 25 | 26 | STDIMP(void, Variable, assign)(Variable *self, Variable *src) { 27 | *self = *src; 28 | } 29 | 30 | #define declare_getter(_name, _type) \ 31 | STDIMP(_type, Variable, get_##_name)(Variable *self) { \ 32 | return self->get_##_name(); \ 33 | } 34 | 35 | declare_getter(bool, bool) 36 | declare_getter(int8, int8_t) 37 | declare_getter(int16, int16_t) 38 | declare_getter(int32, int32_t) 39 | declare_getter(int64, int64_t) 40 | declare_getter(uint8, uint8_t) 41 | declare_getter(uint16, uint16_t) 42 | declare_getter(uint32, uint32_t) 43 | declare_getter(uint64, uint64_t) 44 | declare_getter(float32, float) 45 | declare_getter(float64, double) 46 | 47 | #undef declare_getter 48 | 49 | STDIMP(const char *, Variable, get_string)(Variable *self) { // NOTE(review): assumes get_string() returns a reference; the pointer is then only valid while *self keeps this string 50 | return self->get_string().c_str(); 51 | } 52 | 53 | STDIMP(void, Variable, get_vector)(Variable *self, VariableVector_ptr *result) { 54 | *result = self->get_vector(); 55 | } 56 | 57 | STDIMP(void, Variable, get_table)(Variable *self, VariableTable_ptr *result) { 58 | *result = self->get_table(); 59 | } 60 | 61 | #define declare_setter(_name, _type) \ 62 | STDIMP(void, Variable, set_##_name)(Variable *self, const _type value) { \ 63 | self->set_##_name(value); \ 64 | } 65 | 66 | declare_setter(bool, bool) 67 | declare_setter(int8, int8_t) 68 | declare_setter(int16, int16_t) 69 | declare_setter(int32, int32_t) 70 | declare_setter(int64, int64_t) 71 | declare_setter(uint8, uint8_t) 72 | declare_setter(uint16, uint16_t) 73 | declare_setter(uint32, uint32_t) 74 | declare_setter(uint64, uint64_t) 75 | declare_setter(float32, float) 76 | declare_setter(float64, double) 77 | 78 | #undef declare_setter 79 | 80 | STDIMP(void, Variable, set_string)(Variable *self, const char *value) { 81 | self->set_string(value); 82 | } 83 | 84
| STDIMP(void, Variable, set_vector)(Variable *self, VariableVector_ptr *value) { 85 | self->set_vector(*value); 86 | } 87 | 88 | STDIMP(void, Variable, set_table)(Variable *self, VariableTable_ptr *value) { 89 | self->set_table(*value); 90 | } 91 | -------------------------------------------------------------------------------- /demo/orbit.lua: -------------------------------------------------------------------------------- 1 | local ros = require 'ros' 2 | local tf = ros.tf 3 | 4 | --[[ 5 | 6 | Compute points along an arc on a plane around a center point and normal. 7 | 8 | You can use rviz to display the moving 'eye' transform with the TF display. 9 | Please make sure 'Fixed Frame' is set to 'world'. 10 | 11 | ]] 12 | 13 | local function normalize(v) 14 | return v / torch.norm(v) 15 | end 16 | 17 | local function totensor(t) -- nil and tensors pass through unchanged; tables are converted with torch.Tensor 18 | if t == nil or torch.isTensor(t) then return t end return torch.Tensor(t) 19 | end 20 | 21 | local function pos_vector(v) -- returns a homogeneous 4-vector (w = 1) unless v already has 4 elements 22 | if v:size(1) ~= 4 then 23 | v = torch.Tensor({ v[1], v[2], v[3], 1 }) 24 | end 25 | return v 26 | end 27 | 28 | local function project_onto_plane(plane_point, plane_normal, pt) -- the formula assumes plane_normal has unit length 29 | return pt - plane_normal * torch.dot(pt - plane_point, plane_normal) 30 | end 31 | 32 | local function look_at_pose(eye, at, up) 33 | -- eye becomes origin, 'at' lies on the x-axis 34 | local xaxis = normalize(at - eye) 35 | local yaxis = -normalize(torch.cross(xaxis, up)) 36 | local zaxis = torch.cross(xaxis, yaxis) 37 | 38 | local basis = torch.Tensor(3,3) 39 | basis[{{},{1}}] = xaxis 40 | basis[{{},{2}}] = yaxis 41 | basis[{{},{3}}] = zaxis 42 | 43 | local t = tf.Transform() 44 | t:setBasis(basis) 45 | t:setOrigin(eye) 46 | return t 47 | end 48 | 49 | local function generate_arc(center, normal, start_pt, total_rotation_angle, angle_step, look_at, up) 50 | up = up or torch.Tensor({0,0,1}) 51 | look_at = look_at or center 52 | 53 | center = totensor(center) 54 | normal = totensor(normal) 55 | start_pt = totensor(start_pt) 56 | start_pt = project_onto_plane(center, normal, start_pt) 57 | look_at =
totensor(look_at) 58 | up = totensor(up) 59 | 60 | local poses = {} 61 | local steps = math.max(math.floor(total_rotation_angle / angle_step + 0.5), 1) 62 | for i=0,steps do 63 | local theta = total_rotation_angle * i / steps 64 | 65 | local t = tf.Transform() 66 | t:setRotation(tf.Quaternion(normal, theta)) 67 | t:setOrigin(center) 68 | 69 | local eye = t:toTensor() * pos_vector(start_pt-center) 70 | local pose = look_at_pose(eye[{{1,3}}], look_at, up) 71 | table.insert(poses, pose) 72 | end 73 | 74 | return poses 75 | end 76 | 77 | local x = generate_arc({1,1,0.25}, {0,0,1}, {0.5,1.5,1.25}, 2 * math.pi, 0.1, {1,1,0}) 78 | 79 | ros.init('lookat') 80 | ros.Time.init() 81 | 82 | local b = tf.TransformBroadcaster() 83 | 84 | local i = 1 85 | while ros.ok() do 86 | print(x[i]) 87 | local st = tf.StampedTransform(x[i], ros.Time.now(), 'world', 'eye') 88 | b:sendTransform(st) 89 | ros.Duration(0.1):sleep() 90 | i = (i % #x) + 1 91 | ros.spinOnce() 92 | end 93 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.11) 2 | 3 | project(rostorch) 4 | 5 | find_package(Torch REQUIRED) 6 | find_package(Boost 1.58.0 REQUIRED COMPONENTS system thread chrono) 7 | find_package(catkin REQUIRED COMPONENTS roscpp std_msgs) 8 | find_package(tf) 9 | #find_package(pcl_ros) 10 | 11 | set(SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src") 12 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fmax-errors=3 -Wall" ) 13 | 14 | set(src 15 | "${SOURCE_DIR}/std/string.cpp" 16 | "${SOURCE_DIR}/std/string_vector.cpp" 17 | "${SOURCE_DIR}/std/string_map.cpp" 18 | "${SOURCE_DIR}/std/variable.cpp" 19 | "${SOURCE_DIR}/std/variable_vector.cpp" 20 | "${SOURCE_DIR}/std/variable_table.cpp" 21 | 22 | "${SOURCE_DIR}/ros/init.cpp" 23 | "${SOURCE_DIR}/ros/callback_queue.cpp" 24 | "${SOURCE_DIR}/ros/node_handle.cpp" 25 | "${SOURCE_DIR}/ros/master.cpp" 26 |
"${SOURCE_DIR}/ros/this_node.cpp" 27 | "${SOURCE_DIR}/ros/message_buffer.cpp" 28 | "${SOURCE_DIR}/ros/serialized_message.cpp" 29 | "${SOURCE_DIR}/ros/subscriber.cpp" 30 | "${SOURCE_DIR}/ros/publisher.cpp" 31 | "${SOURCE_DIR}/ros/service_client.cpp" 32 | "${SOURCE_DIR}/ros/service_server.cpp" 33 | "${SOURCE_DIR}/ros/async_spinner.cpp" 34 | "${SOURCE_DIR}/ros/time.cpp" 35 | "${SOURCE_DIR}/ros/duration.cpp" 36 | "${SOURCE_DIR}/ros/rate.cpp" 37 | "${SOURCE_DIR}/ros/console.cpp" 38 | ) 39 | 40 | # compile TF support if available 41 | if (tf_FOUND) 42 | set(src ${src} 43 | "${SOURCE_DIR}/tf/transform.cpp" 44 | "${SOURCE_DIR}/tf/quaternion.cpp" 45 | "${SOURCE_DIR}/tf/stamped_transform.cpp" 46 | "${SOURCE_DIR}/tf/transform_broadcaster.cpp" 47 | "${SOURCE_DIR}/tf/transform_listener.cpp" 48 | ) 49 | list(APPEND catkin_INCLUDE_DIRS ${tf_INCLUDE_DIRS}) 50 | list(APPEND catkin_LIBRARIES ${tf_LIBRARIES}) 51 | endif() 52 | 53 | # compile PCL support if available (inactive while find_package(pcl_ros) above is commented out) 54 | if (pcl_ros_FOUND) 55 | set(src ${src} 56 | "${SOURCE_DIR}/ros/point_cloud2.cpp" 57 | ) 58 | list(APPEND catkin_INCLUDE_DIRS ${pcl_ros_INCLUDE_DIRS}) 59 | list(APPEND catkin_LIBRARIES ${pcl_ros_LIBRARIES}) 60 | endif() 61 | 62 | include_directories( 63 | ${Boost_INCLUDE_DIRS} 64 | ${Torch_INSTALL_INCLUDE} 65 | ${catkin_INCLUDE_DIRS} 66 | ) 67 | 68 | link_directories( 69 | ${Torch_INSTALL_LIB} 70 | ${Boost_LIBRARY_DIRS} 71 | ) 72 | 73 | add_library(rostorch MODULE ${src}) 74 | 75 | target_link_libraries(rostorch TH ${catkin_LIBRARIES} ${Boost_LIBRARIES} ) 76 | 77 | install(TARGETS rostorch LIBRARY DESTINATION ${Torch_INSTALL_LUA_CPATH_SUBDIR}) 78 | install(DIRECTORY "lua/" DESTINATION "${Torch_INSTALL_LUA_PATH_SUBDIR}/ros" FILES_MATCHING PATTERN "*.lua") 79 | -------------------------------------------------------------------------------- /src/ros/time.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-ros.h" 2 | 3 | ROSIMP(ros::Time*, Time, new)() { 4 |
return new ros::Time(); 5 | } 6 | 7 | ROSIMP(void, Time, delete)(ros::Time *self) { 8 | delete self; 9 | } 10 | 11 | ROSIMP(ros::Time*, Time, clone)(ros::Time *self) { 12 | return new ros::Time(*self); 13 | } 14 | 15 | ROSIMP(bool, Time, isZero)(ros::Time *self) { 16 | return self->isZero(); 17 | } 18 | 19 | ROSIMP(void, Time, fromSec)(ros::Time *self, double t) { 20 | self->fromSec(t); 21 | } 22 | 23 | ROSIMP(double, Time, toSec)(ros::Time *self) { 24 | return self->toSec(); 25 | } 26 | 27 | ROSIMP(void, Time, set)(ros::Time *self, unsigned int sec, unsigned int nsec) { 28 | self->sec = sec; 29 | self->nsec = nsec; 30 | } 31 | 32 | ROSIMP(void, Time, assign)(ros::Time *self, ros::Time *other) { 33 | *self = *other; 34 | } 35 | 36 | ROSIMP(int, Time, get_sec)(ros::Time *self) { 37 | return static_cast(self->sec); 38 | } 39 | 40 | ROSIMP(void, Time, set_sec)(ros::Time *self, unsigned int sec) { 41 | self->sec = sec; 42 | } 43 | 44 | ROSIMP(int, Time, get_nsec)(ros::Time *self) { 45 | return static_cast(self->nsec); 46 | } 47 | 48 | ROSIMP(void, Time, set_nsec)(ros::Time *self, unsigned int nsec) { 49 | self->nsec = nsec; 50 | } 51 | 52 | ROSIMP(bool, Time, lt)(ros::Time *self, ros::Time *other) { 53 | return self->operator<(*other); 54 | } 55 | 56 | ROSIMP(bool, Time, eq)(ros::Time *self, ros::Time *other) { 57 | return self->operator==(*other); 58 | } 59 | 60 | ROSIMP(void, Time, add_Duration)(ros::Time *self, ros::Duration *duration, ros::Time *result) { 61 | *result = self->operator+(*duration); 62 | } 63 | 64 | ROSIMP(void, Time, sub)(ros::Time *self, ros::Time *other, ros::Duration *result) { 65 | *result = self->operator-(*other); 66 | } 67 | 68 | ROSIMP(void, Time, sub_Duration)(ros::Time *self, ros::Duration *duration, ros::Time *result) { 69 | *result = self->operator-(*duration); 70 | } 71 | 72 | // static members 73 | 74 | ROSIMP(void, Time, sleepUntil)(ros::Time *end) { 75 | ros::Time::sleepUntil(*end); 76 | } 77 | 78 | ROSIMP(void, Time, 
getNow)(ros::Time *result) { 79 | *result = ros::Time::now(); 80 | } 81 | 82 | ROSIMP(void, Time, setNow)(ros::Time* now) { 83 | ros::Time::setNow(*now); 84 | } 85 | 86 | ROSIMP(void, Time, waitForValid)() { 87 | ros::Time::waitForValid(); 88 | } 89 | 90 | ROSIMP(void, Time, init)() { 91 | ros::Time::init(); 92 | } 93 | 94 | ROSIMP(void, Time, shutdown)() { 95 | ros::Time::shutdown(); 96 | } 97 | 98 | ROSIMP(bool, Time, useSystemTime)() { 99 | return ros::Time::useSystemTime(); 100 | } 101 | 102 | ROSIMP(bool, Time, isSimTime)() { 103 | return ros::Time::isSimTime(); 104 | } 105 | 106 | ROSIMP(bool, Time, isSystemTime)() { 107 | return ros::Time::isSystemTime(); 108 | } 109 | 110 | ROSIMP(bool, Time, isValid)() { 111 | return ros::Time::isValid(); 112 | } 113 | -------------------------------------------------------------------------------- /lua/std/StringMap.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | local std = ros.std 6 | 7 | local StringMap = torch.class('std.StringMap', std) 8 | 9 | function init() 10 | local StringMap_method_names = { 11 | "new", 12 | "clone", 13 | "delete", 14 | "size", 15 | "clear", 16 | "getAt", 17 | "setAt", 18 | "insert", 19 | "erase", 20 | "exists", 21 | "keys", 22 | "values" 23 | } 24 | 25 | return utils.create_method_table("std_StringMap_", StringMap_method_names) 26 | end 27 | 28 | local f = init() 29 | 30 | function StringMap:__init(x) 31 | rawset(self, 'o', f.new()) 32 | if x ~= nil and type(x) == 'table' then 33 | self:insertFromTable(x) 34 | end 35 | end 36 | 37 | function StringMap:cdata() 38 | return self.o 39 | end 40 | 41 | function StringMap:clone() 42 | local c = torch.factory('std.StringMap')() 43 | rawset(c, 'o', f.clone(self.o)) 44 | return c 45 | end 46 | 47 | function StringMap:size() 48 | return f.size(self.o) 49 | end 50 | 51 | function 
StringMap:__len() 52 | return self:size() 53 | end 54 | 55 | function StringMap:clear() 56 | f.clear(self.o) 57 | end 58 | 59 | function StringMap:insert(key, value) 60 | local o = rawget(self, 'o') 61 | return f.insert(o, key, value) 62 | end 63 | 64 | function StringMap:insertFromTable(t) 65 | for k,v in pairs(t) do 66 | self:setAt(k, tostring(v)) 67 | end 68 | end 69 | 70 | function StringMap:erase(key) 71 | return f.erase(self.o, key) 72 | end 73 | 74 | function StringMap:getAt(key) 75 | return ffi.string(f.getAt(self.o, key)) 76 | end 77 | 78 | function StringMap:setAt(key, value) 79 | f.setAt(self.o, key, value) 80 | end 81 | 82 | function StringMap:__index(key) 83 | local v = rawget(self, key) 84 | if not v then 85 | v = StringMap[key] 86 | if not v and type(key) == 'string' then 87 | v = self:getAt(key) 88 | end 89 | end 90 | return v 91 | end 92 | 93 | function StringMap:__newindex(key, value) 94 | local o = rawget(self, 'o') 95 | if type(key) == 'string' then 96 | self:setAt(key, value) 97 | else 98 | rawset(self, key, value) 99 | end 100 | end 101 | 102 | function StringMap:keys() 103 | local v = std.StringVector() 104 | f.keys(self.o, v:cdata()) 105 | return v 106 | end 107 | 108 | function StringMap:values() 109 | local v = std.StringVector() 110 | f.values(self.o, v:cdata()) 111 | return v 112 | end 113 | 114 | function StringMap:totable() 115 | local k,v = self:keys(),self:values() 116 | local r = {} 117 | for i=1,#k do 118 | r[k[i]] = v[i] 119 | end 120 | return r 121 | end 122 | 123 | function StringMap:__tostring() 124 | local t = {} 125 | table.insert(t, '{') 126 | local k,v = self:keys(),self:values() 127 | for i=1,#k do 128 | table.insert(t, ' ' .. k[i] .. ' : ' .. 
v[i]) 129 | end 130 | table.insert(t, '}') 131 | table.insert(t, string.format('[%s]', torch.type(self))) 132 | return table.concat(t, '\n') 133 | end 134 | -------------------------------------------------------------------------------- /src/tf/stamped_transform.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-tf.h" 2 | 3 | TFIMP(tf::StampedTransform *, StampedTransform, new)( 4 | tf::Transform *transform, 5 | ros::Time* timestamp, 6 | const char *frame_id, 7 | const char *child_frame_id 8 | ) { 9 | return new tf::StampedTransform( 10 | transform ? *transform : tf::Transform::getIdentity(), 11 | timestamp ? *timestamp : ros::Time(), 12 | frame_id ? frame_id : "", 13 | child_frame_id ? child_frame_id : "" 14 | ); 15 | } 16 | 17 | TFIMP(tf::StampedTransform *, StampedTransform, clone)(tf::StampedTransform *self) { 18 | return new tf::StampedTransform(*self); 19 | } 20 | 21 | TFIMP(void, StampedTransform, delete)(tf::StampedTransform *self) { 22 | delete self; 23 | } 24 | 25 | TFIMP(tf::Transform *, StampedTransform, getBasePointer)(tf::StampedTransform *self) { 26 | return static_cast(self); 27 | } 28 | 29 | TFIMP(void, StampedTransform, get_stamp)(tf::StampedTransform *self, ros::Time *result) { 30 | *result = self->stamp_; 31 | } 32 | 33 | TFIMP(void, StampedTransform, set_stamp)(tf::StampedTransform *self, ros::Time *stamp) { 34 | self->stamp_ = *stamp; 35 | } 36 | 37 | TFIMP(const char *, StampedTransform, get_frame_id)(tf::StampedTransform *self) { 38 | return self->frame_id_.c_str(); 39 | } 40 | 41 | TFIMP(void, StampedTransform, set_frame_id)(tf::StampedTransform *self, const char *id) { 42 | self->frame_id_ = id; 43 | } 44 | 45 | TFIMP(const char *, StampedTransform, get_child_frame_id)(tf::StampedTransform *self) { 46 | return self->child_frame_id_.c_str(); 47 | } 48 | 49 | TFIMP(void, StampedTransform, set_child_frame_id)(tf::StampedTransform *self, const char *id) { 50 | self->child_frame_id_ = id; 
51 | } 52 | 53 | TFIMP(void, StampedTransform, setData)(tf::StampedTransform *self, tf::Transform *input) { 54 | self->setData(*input); 55 | } 56 | 57 | TFIMP(bool, StampedTransform, eq)(tf::StampedTransform *self, tf::StampedTransform *other) { 58 | return *self == *other; 59 | } 60 | 61 | TFIMP(void, StampedTransform,toStampedTransformMsg)(tf::StampedTransform *self, THByteStorage *output) 62 | { 63 | geometry_msgs::TransformStamped msg; 64 | tf::transformStampedTFToMsg(*self, msg); 65 | 66 | uint32_t length = ros::serialization::serializationLength(msg); 67 | THByteStorage_resize(output, length + sizeof(uint32_t)); 68 | ros::serialization::OStream stream(THByteStorage_data(output), length + sizeof(uint32_t)); 69 | stream.next((uint32_t)length); 70 | ros::serialization::serialize(stream, msg); 71 | } 72 | 73 | TFIMP(void, StampedTransform,toStampedPoseMsg)(tf::StampedTransform *self, THByteStorage *output) 74 | { 75 | const tf::Pose tf_(self->getRotation(), self->getOrigin()); 76 | 77 | geometry_msgs::PoseStamped msg; 78 | tf::Stamped pose (tf_, self->stamp_,self->frame_id_); 79 | tf::poseStampedTFToMsg(pose, msg); 80 | 81 | uint32_t length = ros::serialization::serializationLength(msg); 82 | THByteStorage_resize(output, length + sizeof(uint32_t)); 83 | ros::serialization::OStream stream(THByteStorage_data(output), length + sizeof(uint32_t)); 84 | stream.next((uint32_t)length); 85 | ros::serialization::serialize(stream, msg); 86 | } -------------------------------------------------------------------------------- /src/ros/message_buffer.h: -------------------------------------------------------------------------------- 1 | #ifndef message_buffer_h 2 | #define message_buffer_h 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "raw_message.h" 9 | 10 | class MessageBuffer 11 | : public ros::SubscriptionCallbackHelper { 12 | public: 13 | MessageBuffer(int max_backlog = -1) 14 | : max_backlog(max_backlog) { 15 | } 16 | 17 | virtual ros::VoidConstPtr 
deserialize(const ros::SubscriptionCallbackHelperDeserializeParams ¶ms) { 18 | 19 | // create buffer and copy message bytes 20 | boost::shared_ptr buffer(new RawMessage(params.length + sizeof(uint32_t))); 21 | ros::serialization::OStream stream(buffer->get_OStream()); 22 | stream.next((uint32_t)params.length); 23 | memcpy(buffer->get() + sizeof(uint32_t), params.buffer, params.length); 24 | 25 | if (params.connection_header != NULL) { 26 | buffer->set_header(*params.connection_header); 27 | } 28 | 29 | // lock queue mutex and add new message buffer to queue 30 | { 31 | boost::unique_lock lock(queue_lock); 32 | if (max_backlog >= 0 && message_queue.size() >= static_cast(max_backlog)) 33 | message_queue.pop_front(); // remove oldest message 34 | message_queue.push_back(buffer); 35 | } 36 | 37 | // notify potentially waiting thread 38 | message_available.notify_one(); 39 | return ros::VoidConstPtr(); 40 | } 41 | 42 | virtual void call(ros::SubscriptionCallbackHelperCallParams ¶ms) { 43 | } 44 | 45 | virtual const std::type_info& getTypeInfo() { 46 | return typeid(void); 47 | } 48 | 49 | virtual bool isConst() { 50 | return false; 51 | } 52 | 53 | virtual bool hasHeader() { 54 | return false; 55 | } 56 | 57 | boost::shared_ptr read(int timeout_milliseconds) { 58 | boost::unique_lock lock(queue_lock); 59 | 60 | boost::chrono::system_clock::time_point timeout = boost::chrono::system_clock::now() 61 | + boost::chrono::milliseconds(timeout_milliseconds); 62 | 63 | do { 64 | // check if messages are available 65 | if (!message_queue.empty()) { 66 | boost::shared_ptr msg = message_queue.front(); 67 | message_queue.pop_front(); 68 | return msg; 69 | } 70 | 71 | // wait with timeout for frame to be captured 72 | } while (timeout_milliseconds != 0 73 | && (timeout_milliseconds < 0 || message_available.wait_until(lock, timeout) != boost::cv_status::timeout) 74 | ); 75 | 76 | return boost::shared_ptr(); 77 | } 78 | 79 | size_t count() { 80 | boost::unique_lock lock(queue_lock); 
81 | return message_queue.size(); 82 | } 83 | 84 | void clear() { 85 | boost::unique_lock lock(queue_lock); 86 | message_queue.clear(); 87 | } 88 | 89 | private: 90 | int max_backlog; // unlimited if maxBacklog < 0 91 | boost::mutex queue_lock; 92 | boost::condition_variable message_available; 93 | std::deque > message_queue; 94 | }; 95 | 96 | #endif // message_buffer_h 97 | -------------------------------------------------------------------------------- /lua/tf/Transform.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | local tf = ros.tf 6 | 7 | local Transform = torch.class('tf.Transform', tf) 8 | 9 | function init() 10 | local Transform_method_names = { 11 | "new", 12 | "clone", 13 | "delete", 14 | "setIdentity", 15 | "mul_Quaternion", 16 | "mul_Transform", 17 | "inverse", 18 | "getBasis", 19 | "getOrigin", 20 | "setRotation", 21 | "getRotation", 22 | } 23 | 24 | return utils.create_method_table("tf_Transform_", Transform_method_names ) 25 | end 26 | 27 | local f = init() 28 | 29 | function Transform:__init() 30 | self.o = f.new() 31 | self:setIdentity() 32 | end 33 | 34 | function Transform:cdata() 35 | return self.o 36 | end 37 | 38 | function Transform:clone() 39 | local c = torch.factory('tf.Transform')() 40 | rawset(c, 'o', f.clone(self.o)) 41 | return c 42 | end 43 | 44 | function Transform.fromStamped(st) 45 | local c = torch.factory('tf.Transform')() 46 | rawset(c, 'o', f.clone(st.o)) 47 | return c 48 | end 49 | 50 | function Transform:setIdentity() 51 | f.setIdentity(self.o) 52 | return self 53 | end 54 | 55 | function Transform:getBasis(basis) 56 | basis = basis or torch.DoubleTensor() 57 | f.getBasis(self.o, basis:cdata()) 58 | return basis 59 | end 60 | 61 | function Transform:setBasis(basis) 62 | self:getBasis()[{}] = basis 63 | return self 64 | end 65 | 66 | function 
Transform:mul(t, output) 67 | output = output or tf.Transform() 68 | if torch.isTypeOf(t, tf.Transform) then 69 | f.mul_Transform(self.o, t:cdata(), output:cdata()) 70 | elseif torch.isTypeOf(t, tf.Quaternion) then 71 | f.mul_Quaternion(self.o, t:cdata(), output:cdata()) 72 | else 73 | error('tf.Transform or tf.Quaternion expected') 74 | end 75 | return output 76 | end 77 | 78 | function Transform:getOrigin(origin) 79 | origin = origin or torch.DoubleTensor() 80 | f.getOrigin(self.o, origin:cdata()) 81 | return origin 82 | end 83 | 84 | function Transform:setOrigin(origin) 85 | if not torch.isTensor(origin) then 86 | origin = torch.DoubleTensor(origin) 87 | end 88 | self:getOrigin()[{}] = origin 89 | return self 90 | end 91 | 92 | function Transform:inverse(output) 93 | output = output or tf.Transform() 94 | f.inverse(self.o, output:cdata()) 95 | return output 96 | end 97 | 98 | function Transform:toTensor() 99 | local t = torch.zeros(4,4) 100 | t[{{1,3},{1,3}}] = self:getBasis() 101 | t[{{1,3},{4}}] = self:getOrigin() 102 | t[{4,4}] = 1 103 | return t 104 | end 105 | 106 | function Transform:fromTensor(t) 107 | self:getBasis()[{}] = t[{{1,3},{1,3}}] 108 | self:getOrigin()[{}] = t[{{1,3},{4}}] 109 | return self 110 | end 111 | 112 | function Transform:getRotation(output) 113 | output = output or tf.Quaternion() 114 | f.getRotation(self.o, output:cdata()) 115 | return output 116 | end 117 | 118 | function Transform:setRotation(quaternion) 119 | f.setRotation(self.o, quaternion:cdata()) 120 | return self 121 | end 122 | 123 | function Transform:__tostring() 124 | local t = self:toTensor() 125 | local s = '' 126 | for i=1,4 do 127 | s = s .. 
string.format('%9g %9g %9g %9g\n', t[{i,1}], t[{i,2}], t[{i,3}], t[{i,4}]) 128 | end 129 | return s 130 | end 131 | -------------------------------------------------------------------------------- /lua/actionlib/ActionSpec.lua: -------------------------------------------------------------------------------- 1 | local md5 = require 'md5' 2 | local path = require 'pl.path' 3 | local torch = require 'torch' 4 | local ros = require 'ros.env' 5 | local actionlib = ros.actionlib 6 | 7 | 8 | local ActionSpec = torch.class('ros.actionlib.ActionSpec', actionlib) 9 | local DEFAULT_PACKAGE = 'actionlib' 10 | 11 | 12 | --- (internal) load from iterator 13 | -- @param iterator iterator that returns one line of the specification at a time 14 | local function load_from_iterator(self, iterator) 15 | local goal, result, feedback = {}, {}, {} 16 | 17 | -- extract goal, result and feedback messages from action descriptions 18 | local t = goal 19 | for line in iterator do 20 | if string.find(line, '^%s*---%s*$') ~= nil then 21 | if t == goal then 22 | t = result 23 | else 24 | t = feedback 25 | end 26 | else 27 | table.insert(t, line) 28 | end 29 | end 30 | 31 | -- generate inner structures 32 | self.goal_spec = ros.get_msgspec(self.type .. 'Goal', table.concat(goal, '\n')) 33 | self.result_spec = ros.get_msgspec(self.type .. 'Result', table.concat(result, '\n')) 34 | self.feedback_spec = ros.get_msgspec(self.type .. 'Feedback', table.concat(feedback, '\n')) 35 | 36 | -- generate derived 37 | local action_goal_msg = 'Header header\nactionlib_msgs/GoalID goal_id\n' .. self.type .. 'Goal goal' 38 | local action_result_msg = 'Header header\nactionlib_msgs/GoalStatus status\n' .. self.type .. 'Result result' 39 | local action_feedback_msg = 'Header header\nactionlib_msgs/GoalStatus status\n' .. self.type .. 'Feedback feedback' 40 | 41 | self.action_goal_spec = ros.get_msgspec(self.type .. 'ActionGoal', action_goal_msg) 42 | self.action_result_spec = ros.get_msgspec(self.type .. 
'ActionResult', action_result_msg) 43 | self.action_feedback_spec = ros.get_msgspec(self.type .. 'ActionFeedback', action_feedback_msg) 44 | end 45 | 46 | 47 | local function load_from_action_file(self) 48 | local package_path = ros.find_package(self.package) 49 | local tmp_path = path.join(package_path, 'action') 50 | self.file = path.join(tmp_path, self.short_type .. '.action') 51 | return load_from_iterator(self, io.lines(self.file)) 52 | end 53 | 54 | 55 | --- (internal) Load specification from string. 56 | -- @param s string containing the message specification 57 | local function load_from_string(self, s) 58 | return load_from_iterator(self, s:gmatch('([^\r\n]+)\n?')) 59 | end 60 | 61 | 62 | function ActionSpec:__init(type, specstr) 63 | assert(type, 'Action type is expected') 64 | self.type = type 65 | 66 | local slashpos = type:find('/') 67 | if slashpos then 68 | self.package = type:sub(1, slashpos - 1) 69 | self.short_type = type:sub(slashpos + 1) 70 | else 71 | self.package = DEFAULT_PACKAGE 72 | self.short_type = type 73 | end 74 | 75 | if specstr then 76 | load_from_string(self, specstr) 77 | else 78 | load_from_action_file(self) 79 | end 80 | end 81 | 82 | 83 | function ActionSpec:format_spec(ln) 84 | table.insert(ln, 'Action ' .. 
self.type) 85 | self.action_goal_spec:format_spec(ln) 86 | table.insert(ln, '---') 87 | self.action_result_spec:format_spec(ln) 88 | table.insert(ln, '---') 89 | self.action_feedback_spec:format_spec(ln) 90 | return ln 91 | end 92 | 93 | 94 | function ActionSpec:__tostring() 95 | lines = self:format_spec({}) 96 | table.insert(lines, '') 97 | return table.concat(lines, '\n') 98 | end 99 | -------------------------------------------------------------------------------- /lua/std/StringVector.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | local std = ros.std 6 | 7 | local StringVector = torch.class('std.StringVector', std) 8 | 9 | function init() 10 | local StringVector_method_names = { 11 | "new", 12 | "clone", 13 | "delete", 14 | "size", 15 | "getAt", 16 | "setAt", 17 | "push_back", 18 | "pop_back", 19 | "clear", 20 | "insert", 21 | "erase", 22 | "empty" 23 | } 24 | 25 | return utils.create_method_table("std_StringVector_", StringVector_method_names) 26 | end 27 | 28 | local f = init() 29 | 30 | function StringVector:__init(...) 31 | rawset(self, 'o', f.new()) 32 | if select("#", ...) > 0 then 33 | local x = ... 34 | if type(x) ~= 'table' then 35 | x = { ... 
} 36 | end 37 | self:insertFromTable(1, x) 38 | end 39 | end 40 | 41 | function StringVector:cdata() 42 | return self.o 43 | end 44 | 45 | function StringVector:clone() 46 | local c = torch.factory('std.StringVector')() 47 | rawset(c, 'o', f.clone(self.o)) 48 | return c 49 | end 50 | 51 | function StringVector:size() 52 | return f.size(self.o) 53 | end 54 | 55 | function StringVector:__len() 56 | return self:size() 57 | end 58 | 59 | function StringVector:__index(idx) 60 | local v = rawget(self, idx) 61 | if not v then 62 | v = StringVector[idx] 63 | if not v and type(idx) == 'number' then 64 | local o = rawget(self, 'o') 65 | v = ffi.string(f.getAt(o, idx-1)) 66 | end 67 | end 68 | return v 69 | end 70 | 71 | function StringVector:__newindex(idx, v) 72 | local o = rawget(self, 'o') 73 | if type(idx) == 'number' then 74 | f.setAt(o, idx-1, tostring(v)) 75 | else 76 | rawset(self, idx, v) 77 | end 78 | end 79 | 80 | function StringVector:push_back(value) 81 | f.push_back(self.o, tostring(value)) 82 | end 83 | 84 | function StringVector:pop_back() 85 | local last = self[#self] 86 | f.pop_back(self.o) 87 | return last 88 | end 89 | 90 | function StringVector:clear() 91 | f.clear(self.o) 92 | end 93 | 94 | function StringVector:insert(pos, value, n) 95 | if pos < 1 then 96 | pos = 1 97 | elseif pos > #self+1 then 98 | pos = #self + 1 99 | end 100 | f.insert(self.o, pos-1, n or 1, value) 101 | end 102 | 103 | function StringVector:insertFromTable(pos, t) 104 | if type(pos) == 'table' then 105 | t = pos 106 | pos = #self + 1 107 | end 108 | pos = pos or #self + 1 109 | for _,v in pairs(t) do 110 | self:insert(pos, v) 111 | pos = pos + 1 112 | end 113 | end 114 | 115 | function StringVector:erase(begin_pos, end_pos) 116 | f.erase(self.o, begin_pos-1, (end_pos or begin_pos + 1)-1) 117 | end 118 | 119 | function StringVector:__pairs() 120 | return function (t, k) 121 | local i = k + 1 122 | if i > #t then 123 | return nil 124 | else 125 | local v = t[i] 126 | return i, v 
127 | end 128 | end, self, 0 129 | end 130 | 131 | function StringVector:__ipairs() 132 | return self:__pairs() 133 | end 134 | 135 | function StringVector:totable() 136 | local t = {} 137 | for i,v in ipairs(self) do 138 | table.insert(t, v) 139 | end 140 | return t 141 | end 142 | 143 | function StringVector:__tostring() 144 | local t = self:totable() 145 | return table.concat(t, '\n') 146 | end 147 | -------------------------------------------------------------------------------- /lua/std/Task.lua: -------------------------------------------------------------------------------- 1 | local torch = require 'torch' 2 | local ros = require 'ros.env' 3 | local utils = require 'ros.utils' 4 | local std = ros.std 5 | 6 | local Task = torch.class('std.Task', std) 7 | 8 | local TaskState = { 9 | NotStarted = 1, 10 | Running = 2, 11 | Succeeded = 3, 12 | Failed = 4, 13 | Cancelled = 5, 14 | 'NOT_STARTED', 'RUNNING', 'SUCCEEDED', 'FAILED', 'CANCELLED' 15 | } 16 | std.TaskState = TaskState 17 | 18 | 19 | function Task.create(start_handler, cancel_handler, completion_handler, auto_start) 20 | local task = Task.new(start_handler, cancel_handler) 21 | 22 | if completion_handler ~= nil then 23 | task:addCompletionHandler(completion_handler) 24 | end 25 | 26 | if auto_start == true then 27 | task:start() 28 | end 29 | 30 | return task 31 | end 32 | 33 | 34 | function Task:__init(start_handler, cancel_handler) 35 | self.start_handler = start_handler 36 | self.cancel_handler = cancel_handler 37 | self.state = TaskState.NotStarted 38 | self.completion_handlers = {} 39 | end 40 | 41 | 42 | function Task:getState() 43 | return self.state 44 | end 45 | 46 | 47 | function Task:addCompletionHandler(completion_handler) 48 | self.completion_handlers[#self.completion_handlers + 1] = completion_handler 49 | end 50 | 51 | 52 | function Task:removeCompletionHandler(completion_handler) 53 | local i = utils.indexOf(self.completion_handlers, completion_handler) 54 | if i ~= -1 then 55 | 
table.remove(self.completion_handlers, i) 56 | end 57 | end 58 | 59 | 60 | function Task:start() 61 | if self.state ~= TaskState.NotStarted then 62 | error(string.format('Task already started. Current state is: \'%s\'', TaskState[self.state])) 63 | end 64 | 65 | self.state = TaskState.Running 66 | self:start_handler() 67 | end 68 | 69 | 70 | function Task:cancel(reason) 71 | if self.cancel_handler ~= nil then 72 | self:cancel_handler(reason) 73 | end 74 | end 75 | 76 | 77 | function Task:hasCompleted() 78 | return self.state == TaskState.Succeeded or self.state == TaskState.Failed or self.state == TaskState.Cancelled 79 | end 80 | 81 | 82 | function Task:hasCompletedSuccessfully() 83 | return self.state == TaskState.Succeeded 84 | end 85 | 86 | 87 | function Task:waitForCompletion(timeout_in_ms, spin_rate, spin_function) 88 | spin_rate = spin_rate or 25 89 | if not torch.isTypeOf(spin_rate, ros.Rate) then 90 | spin_rate = ros.Rate(spin_rate) 91 | end 92 | 93 | spin_function = spin_function or ros.spinOnce 94 | local start_time = ros.Time.now() 95 | 96 | while not self:hasCompleted() do 97 | if timeout_in_ms ~= nil and (ros.Time.now() - start_time):toSec() > timeout_in_ms / 1000 then 98 | return false 99 | end 100 | spin_rate:sleep() 101 | spin_function() 102 | end 103 | 104 | return true 105 | end 106 | 107 | 108 | function Task:getResult() 109 | if not self:hasCompleted() then 110 | self:waitForCompletion() 111 | end 112 | 113 | return self.result 114 | end 115 | 116 | 117 | function Task:setResult(terminal_state, value) 118 | if self:hasCompleted() then 119 | error('Task already completed.') 120 | end 121 | 122 | self.state = terminal_state 123 | self.result = value 124 | assert(self:hasCompleted(), 'Invalid terminal state specified.') 125 | 126 | for i, handler in ipairs(self.completion_handlers) do 127 | handler(self) 128 | end 129 | end 130 | -------------------------------------------------------------------------------- /src/tf/quaternion.cpp: 
-------------------------------------------------------------------------------- 1 | #include "torch-tf.h" 2 | 3 | TFIMP(tf::Quaternion *, Quaternion, new)() { 4 | return new tf::Quaternion(); 5 | } 6 | 7 | TFIMP(tf::Quaternion *, Quaternion, clone)(tf::Quaternion *self) { 8 | return new tf::Quaternion(*self); 9 | } 10 | 11 | TFIMP(void, Quaternion, delete)(tf::Quaternion *self) { 12 | delete self; 13 | } 14 | 15 | TFIMP(void, Quaternion, setIdentity)(tf::Quaternion *self) { 16 | *self = tf::Quaternion::getIdentity(); 17 | } 18 | 19 | TFIMP(void, Quaternion, setRotation_Tensor)(tf::Quaternion *self, THDoubleTensor *axis, double angle) { 20 | tf::Vector3 v; 21 | copyTensorToVector3(axis, v); 22 | self->setRotation(v, angle); 23 | } 24 | 25 | TFIMP(void, Quaternion, setEuler)(tf::Quaternion *self, double yaw, double pitch, double roll) { 26 | self->setEuler(yaw, pitch, roll); 27 | } 28 | 29 | TFIMP(void, Quaternion, getRPY)(tf::Quaternion *self, int solution_number, THDoubleTensor *result) { 30 | double roll = 0, pitch = 0, yaw = 0; 31 | tf::Matrix3x3(*self).getRPY(roll, pitch, yaw, solution_number); 32 | copyVector3ToTensor(tf::Vector3(roll, pitch, yaw), result); 33 | } 34 | 35 | TFIMP(void, Quaternion, setRPY)(tf::Quaternion *self, double roll, double pitch, double yaw) { 36 | self->setRPY(roll, pitch, yaw); 37 | } 38 | 39 | TFIMP(double, Quaternion, getAngle)(tf::Quaternion *self) { 40 | return self->getAngle(); 41 | } 42 | 43 | TFIMP(void, Quaternion, getAxis_Tensor)(tf::Quaternion *self, THDoubleTensor *axis) { 44 | const tf::Vector3& a = self->getAxis(); 45 | copyVector3ToTensor(a, axis); 46 | } 47 | 48 | TFIMP(void, Quaternion, inverse)(tf::Quaternion *self, tf::Quaternion *result) { 49 | *result = self->inverse(); 50 | } 51 | 52 | TFIMP(double, Quaternion, length2)(tf::Quaternion *self) { 53 | return self->length2(); 54 | } 55 | 56 | TFIMP(void, Quaternion, normalize)(tf::Quaternion *self) { 57 | self->normalize(); 58 | } 59 | 60 | TFIMP(double, Quaternion, 
angle)(tf::Quaternion *self, tf::Quaternion *other) { 61 | return self->angle(*other); 62 | } 63 | 64 | TFIMP(double, Quaternion, angleShortestPath)(tf::Quaternion *self, tf::Quaternion *other) { 65 | return self->angleShortestPath(*other); 66 | } 67 | 68 | TFIMP(void, Quaternion, add)(tf::Quaternion *self, tf::Quaternion *other, tf::Quaternion *result) { 69 | *result = self->operator+(*other); 70 | } 71 | 72 | TFIMP(void, Quaternion, sub)(tf::Quaternion *self, tf::Quaternion *other, tf::Quaternion *result) { 73 | *result = self->operator-(*other); 74 | } 75 | 76 | TFIMP(void, Quaternion, mul)(tf::Quaternion *self, tf::Quaternion *other, tf::Quaternion *result) { 77 | if (result != self) 78 | *result = *self; 79 | *result = result->operator*=(*other); 80 | } 81 | 82 | TFIMP(void, Quaternion, mul_scalar)(tf::Quaternion *self, double factor, tf::Quaternion *result) { 83 | *result = self->operator*(factor); 84 | } 85 | 86 | TFIMP(void, Quaternion, div_scalar)(tf::Quaternion *self, double divisor, tf::Quaternion *result) { 87 | *result = self->operator/(divisor); 88 | } 89 | 90 | TFIMP(double, Quaternion, dot)(tf::Quaternion *self, tf::Quaternion *other) { 91 | return self->dot(*other); 92 | } 93 | 94 | TFIMP(void, Quaternion, slerp)(tf::Quaternion *self, tf::Quaternion *other, double t, tf::Quaternion *result) { 95 | *result = self->slerp(*other, t); 96 | } 97 | 98 | TFIMP(void, Quaternion, viewTensor)(tf::Quaternion *self, THDoubleTensor *result) { 99 | viewQuaternion(*self, result); 100 | } 101 | -------------------------------------------------------------------------------- /lua/ros.lua: -------------------------------------------------------------------------------- 1 | --- ROS main class 2 | -- @classmod ros 3 | local ffi = require 'ffi' 4 | local torch = require 'torch' 5 | local ros = require 'ros.env' 6 | local utils = require 'ros.utils' 7 | local std = ros.std 8 | 9 | function init() 10 | local ros_method_names = { 11 | "init", 12 | "shutdown", 13 | 
"spinOnce", 14 | "requestShutdown", 15 | "isInitialized", 16 | "isStarted", 17 | "isShuttingDown", 18 | "ok", 19 | "waitForShutdown" 20 | } 21 | 22 | local f = utils.create_method_table("ros___", ros_method_names) 23 | 24 | for n,v in pairs(f) do 25 | ros[n] = v 26 | end 27 | 28 | return f 29 | end 30 | 31 | local f = init() 32 | 33 | ros.init_options = { 34 | NoSigintHandler = 1, 35 | AnonymousName = 2, 36 | NoRosout = 4 37 | } 38 | 39 | --- Initialize ROS 40 | -- @tparam[opt=torch_ros] string name Name 41 | -- @param[opt] int options Options 42 | -- @param[opt] tab remappings Table with with key->value (local_name->external_name) remappings or list with command line args of which entries containing ':=' will be passed as remappings to ros::init(). 43 | function ros.init(name, options, remappings) 44 | remappings = remappings or {} 45 | assert(type(remappings) == "table", "Argument 'remappings' must be of type table. See https://github.com/torch/torch7/blob/master/doc/cmdline.md") 46 | 47 | local remap = std.StringMap() 48 | for k,v in pairs(remappings) do 49 | if type(k) == 'string' and type(v) == 'string' then 50 | remap[k] = v 51 | elseif type(k) == 'number' then 52 | local local_name, external_name = string.match(v, '(.*):=(.*)') 53 | if local_name and external_name then 54 | ros.DEBUG("remap: %s => %s", local_name, external_name) 55 | remap[local_name] = external_name 56 | end 57 | end 58 | end 59 | 60 | if not name then 61 | name = 'torch_ros' 62 | options = ros.init_options.AnonymousName 63 | end 64 | f.init(remap:cdata(), name, options or 0) 65 | end 66 | 67 | --- Will call all the callbacks waiting to be called at the moment. 68 | -- @tparam[opt=true] ros.Duration. Time to wait for callback. 
69 | -- @tparam[opt=false] bool no_default_callbacks If true, the callbacks waiting in the default callback queue will not be called
70 | function ros.spinOnce(timeout, no_default_callbacks)
71 |   if torch.type(timeout) == 'boolean' then  -- ros.spinOnce(true) is shorthand for ros.spinOnce(nil, true)
72 |     no_default_callbacks = timeout
73 |     timeout = nil
74 |   elseif torch.type(timeout) == 'number' then
75 |     timeout = ros.Duration(timeout)
76 |   end
77 |   f.spinOnce()
78 |
79 |   if not no_default_callbacks then
80 |     -- process pending callbacks on default queue
81 |     local queue = ros.DEFAULT_CALLBACK_QUEUE
82 |     if queue ~= nil and ros.ok() then
83 |       queue:callAvailable(timeout)
84 |     end
85 |   end
86 | end
87 |
88 | --- Register a spin callback on ros.DEFAULT_CALLBACK_QUEUE and return the queue's result (e.g. an id).
89 | -- @tparam func fn Callback function
90 | -- @tparam[opt] int round Callbacks with a higher round integer are called after all callbacks with lower round numbers have been called.
91 | function ros.registerSpinCallback(fn, round)
92 |   return ros.DEFAULT_CALLBACK_QUEUE:registerSpinCallback(fn, round)
93 | end
94 |
95 | --- Unregister a spin callback from ros.DEFAULT_CALLBACK_QUEUE.
96 | -- @tparam func fn Callback function previously passed to ros.registerSpinCallback
97 | -- @tparam[opt] int round Round value forwarded to the queue's unregisterSpinCallback
98 | function ros.unregisterSpinCallback(fn, round)
99 |   return ros.DEFAULT_CALLBACK_QUEUE:unregisterSpinCallback(fn, round)
100 | end
101 |
102 | return ros
103 |
--------------------------------------------------------------------------------
/lua/SrvSpec.lua:
--------------------------------------------------------------------------------
1 | local md5 = require 'md5'
2 | local path = require 'pl.path'
3 | local torch = require 'torch'
4 | local ros = require 'ros.env'
5 |
6 | local SrvSpec = torch.class('ros.SrvSpec', ros)
7 |
8 | local srvspec_cache = {}  -- SrvSpec instances keyed by the full service type string
9 | local DEFAULT_PACKAGE = 'roslib'  -- package assumed when the type has no 'pkg/' prefix
10 |
11 | --- Get (and cache) a service specification; later calls with the same srv_type return the cached object and ignore specstr.
12 | -- @param srv_type service type (e.g. roscpp/SetLoggerLevel); should include the package, otherwise 'roslib' is assumed
13 | -- @param[opt] specstr .srv definition text; when omitted the .srv file is loaded from the package
14 | local function get_srvspec(srv_type, specstr)
15 |   if not srvspec_cache[srv_type] then
16 |     srvspec_cache[srv_type] = ros.SrvSpec(srv_type, specstr)
17 |   end
18 |
19 |   return srvspec_cache[srv_type]
20 | end
21 | ros.get_srvspec = get_srvspec
22 |
23 | --- (internal) Split a .srv definition into request and response MsgSpecs.
24 | -- @param iterator iterator that returns one line of the specification at a time; a line containing only '---' separates request from response
25 | local function load_from_iterator(self, iterator)
26 |   local request, response = {}, {}
27 |
28 |   -- extract the request and response message descriptions
29 |   local t = request
30 |   for line in iterator do
31 |     if string.find(line, '^%s*---%s*$') ~= nil then
32 |       t = response
33 |     else
34 |       table.insert(t, line)
35 |     end
36 |   end
37 |
38 |   self.request_spec = ros.MsgSpec(self.type .. '_Request', table.concat(request, '\n'))
39 |   self.response_spec = ros.MsgSpec(self.type .. '_Response', table.concat(response, '\n'))
40 | end
41 |
42 | local function load_srvspec(self)  -- (internal) read <package path>/srv/<short_type>.srv
43 |   local package_path = ros.find_package(self.package)
44 |   local tmp_path = path.join(package_path, 'srv')
45 |   self.file = path.join(tmp_path, self.short_type .. ".srv")
46 |   return load_from_iterator(self, io.lines(self.file))
47 | end
48 |
49 | --- (internal) Load specification from string.
50 | -- @param s string containing the service specification (empty lines are skipped by the gmatch pattern)
51 | local function load_from_string(self, s)
52 |   return load_from_iterator(self, s:gmatch('([^\r\n]+)\n?'))
53 | end
54 |
55 | -- (internal) Calculate MD5 sum.
56 | -- Generates the MD5 sum for this service type from the request and response hash texts and stores it in self.md5sum.
57 | -- @return MD5 sum as text
58 | local function calc_md5(self)
59 |   local s = self.request_spec:generate_hashtext() .. self.response_spec:generate_hashtext()
60 |   self.md5sum = md5.sumhexa(s)
61 |   return self.md5sum
62 | end
63 |
64 | function SrvSpec:__init(type, specstr)  -- type: 'pkg/Name' or bare 'Name'; specstr: optional .srv text
65 |   assert(type, 'Service type is expected')
66 |   self.type = type
67 |
68 |   local slashpos = type:find('/')
69 |   if slashpos then
70 |     self.package = type:sub(1, slashpos - 1)
71 |     self.short_type = type:sub(slashpos + 1)
72 |   else
73 |     self.package = DEFAULT_PACKAGE
74 |     self.short_type = type
75 |   end
76 |
77 |   if specstr then
78 |     load_from_string(self, specstr)
79 |   else
80 |     load_srvspec(self)
81 |   end
82 | end
83 |
84 | --- Get MD5 sum of type specification.
85 | -- This will create a text representation of the service specification and
86 | -- generate the MD5 sum for it. The value is cached, so subsequent calls
87 | -- return the cached value.
88 | -- @return MD5 sum of service specification
89 | function SrvSpec:md5()
90 |   return self.md5sum or calc_md5(self)
91 | end
92 |
93 | function SrvSpec:format_spec(ln)  -- append a human-readable description to the line list ln and return it
94 |   table.insert(ln, 'Service ' .. self.type)
95 |   table.insert(ln, 'MD5: ' .. self:md5())
96 |   self.request_spec:format_spec(ln)
97 |   table.insert(ln, '---')
98 |   self.response_spec:format_spec(ln)
99 |   return ln
100 | end
101 |
102 | function SrvSpec:__tostring()
103 |   local lines = self:format_spec({})  -- local: this previously leaked a global named `lines`
104 |   table.insert(lines, '')
105 |   return table.concat(lines, '\n')
106 | end
107 |
--------------------------------------------------------------------------------
/lua/Subscriber.lua:
--------------------------------------------------------------------------------
1 | local ffi = require 'ffi'
2 | local torch = require 'torch'
3 | local ros = require 'ros.env'
4 | local utils = require 'ros.utils'
5 | local tf = ros.tf
6 | local std = ros.std  -- getTopic() needs std.String; do not depend on a global `std` set by another module
7 | local Subscriber = torch.class('ros.Subscriber', ros)
8 | local Subscriber_ptr_ct = ffi.typeof('ros_Subscriber *')
9 |
10 | local function init()
11 |   local Subscriber_method_names = {
12 |     "clone",
13 |     "delete",
14 |     "shutdown",
15 |     "getTopic",
16 |     "getNumPublishers"
17 |   }
18 |
19 |   return utils.create_method_table("ros_Subscriber_", Subscriber_method_names)
20 | end
21 |
22 | local f = init()
23 |
24 | function Subscriber:__init(ptr, buffer, msg_spec, callback_queue, serialization_handlers)
25 |   if not ptr or not ffi.istype(Subscriber_ptr_ct, ptr) then  -- `not ffi.typeof(ptr) == ct` was always false
26 |     error('argument 1: ros.Subscriber * expected.')
27 |   end
28 |   self.o = ptr
29 |   self.buffer = buffer  -- message buffer drained by read() / triggerCallbacks()
30 |   self.msg_spec = msg_spec
31 |   ffi.gc(ptr, f.delete)
32 |   self.callback_queue = callback_queue or ros.DEFAULT_CALLBACK_QUEUE
33 |   self.callbacks = {}  -- set of registered message callbacks (function -> true)
34 |   self.serialization_handlers = serialization_handlers
35 | end
36 |
37 | function Subscriber:cdata()
38 |   return self.o
39 | end
40 |
41 | function Subscriber:clone()
42 |   local c = torch.factory('ros.Subscriber')()
43 |   rawset(c, 'o', f.clone(self.o))  -- NOTE(review): only the handle is copied (no buffer/msg_spec, no ffi.gc finalizer) - confirm intended
44 |   return c
45 | end
46 |
47 | function Subscriber:shutdown()
48 |   if self.spin_callback_function ~= nil then
49 |     self.callback_queue:unregisterSpinCallback(self.spin_callback_function)  -- same queue registerCallback used
50 |     self.spin_callback_function = nil
51 |   end
52 |
53 |   f.shutdown(self.o)
54 | end
55 |
56 | function Subscriber:getTopic()
57 |   local s = ros.std.String()  -- explicit ros.std instead of relying on a global `std`
58 |   f.getTopic(self.o, s:cdata())
59 |   return s:get()
60 | end
61 |
62 | function Subscriber:getNumPublishers()
63 |   return f.getNumPublishers(self.o)
64 | end
65 |
66 | function Subscriber:hasMessage()
67 |   return self:getMessageCount() > 0
68 | end
69 |
70 | function Subscriber:getMessageCount()
71 |   return self.buffer:getCount()
72 | end
73 |
74 | function Subscriber:read(timeout_milliseconds, result)  -- returns msg, header; msg is nil when the buffer yields no bytes
75 |   local msg_bytes, msg_header = self.buffer:read(timeout_milliseconds)
76 |   local msg
77 |   if msg_bytes then
78 |     local sr = ros.StorageReader(msg_bytes, 0, nil, nil, self.serialization_handlers)
79 |     -- a registered serialization handler takes precedence over generic ros.Message decoding
80 |     local handler = sr:getHandler(self.msg_spec.type)
81 |     if handler ~= nil then
82 |       sr:readUInt32()  -- skip the total-length prefix; its value is not needed
83 |       msg = handler:read(sr)
84 |     else
85 |       msg = result or ros.Message(self.msg_spec, true)
86 |       msg:deserialize(sr)
87 |     end
88 |   end
89 |   return msg, msg_header
90 | end
91 |
92 | function Subscriber:triggerCallbacks()
93 |   local cbs
94 |   local count = self:getMessageCount()  -- only drain messages already buffered when this call started
95 |   while count > 0 do
96 |     count = count - 1
97 |     local msg, header = self:read(0)
98 |     if msg ~= nil then
99 |       cbs = cbs or utils.getTableKeys(self.callbacks) -- lazy isolation copy of callbacks
100 |       for _, cb in ipairs(cbs) do
101 |         cb(msg, header, self)
102 |       end
103 |     end
104 |   end
105 | end
106 |
107 | function Subscriber:registerCallback(message_cb)
108 |   self.callbacks[message_cb] = true -- table used as set
109 |   if self.spin_callback_function == nil then  -- first callback: hook this subscriber into its callback queue once
110 |     self.spin_callback_function = function() self:triggerCallbacks() end
111 |     self.spin_callback_id = self.callback_queue:registerSpinCallback(self.spin_callback_function)
112 |   end
113 | end
114 |
115 | function Subscriber:unregisterCallback(message_cb)
116 |   self.callbacks[message_cb] = nil
117 |   if self.spin_callback_function ~= nil and next(self.callbacks) == nil then  -- last callback removed: unhook from the queue
118 |     self.callback_queue:unregisterSpinCallback(self.spin_callback_function)
119 |     self.spin_callback_function = nil
120 |   end
121 | end
122 |
--------------------------------------------------------------------------------
/lua/std/VariableTable.lua:
--------------------------------------------------------------------------------
1 | local ffi = require 'ffi'
2 | local torch = require 'torch'
3 | local ros = require 'ros.env'
4 | local utils = require 'ros.utils'
5 | std = ros.std  -- NOTE(review): global on purpose? Other modules appear to rely on a global `std` - confirm before making it local
6 |
7 | local VariableTable = torch.class('std.VariableTable', std)
8 | local TYPE_CODE = std.Variable.TYPE_CODE
9 |
10 | local function init()
11 |   local VariableTable_method_names = {
12 |     'new',
13 |     'delete',
14 |     'clone',
15 |     'size',
16 |     'clear',  -- NOTE(review): bound here, but no VariableTable:clear() wrapper is defined in this file
17 |     'getField',
18 |     'setField',
19 |     'erase',
20 |     'exists',
21 |     'keys',
22 |     'values'
23 |   }
24 |
25 |   return utils.create_method_table("std_VariableTable_", VariableTable_method_names)
26 | end
27 |
28 | local f = init()
29 |
30 | function VariableTable:__init(x)  -- an optional Lua table x is copied in via insertFromTable
31 |   rawset(self, 'o', f.new())
32 |   if x ~= nil and type(x) == 'table' then
33 |     self:insertFromTable(x)
34 |   end
35 | end
36 |
37 | function VariableTable:cdata()
38 |   return rawget(self, 'o')
39 | end
40 |
41 | function VariableTable:clone()
42 |   local c = torch.factory('std.VariableTable')()
43 |   rawset(c, 'o', f.clone(self.o))
44 |   return c
45 | end
46 |
47 | function VariableTable:size()
48 |   return f.size(self.o)
49 | end
50 |
51 | function VariableTable:__len()
52 |   return self:size()
53 | end
54 |
55 | function VariableTable:erase(key)
56 |   return f.erase(self.o, key)
57 | end
58 |
59 | function VariableTable:insertFromTable(t)  -- every value is wrapped in std.Variable before being stored
60 |   for k,v in pairs(t) do
61 |     if not torch.isTypeOf(v, std.Variable) then
62 |       v = std.Variable(v)
63 |     end
64 |     self:setField(k, v)
65 |   end
66 | end
67 |
68 | function VariableTable:getField(key)  -- nil if key is absent; vector/table values are returned unwrapped via get()
69 |   local o = rawget(self, 'o')
70 |   local v = std.Variable()
71 |   if not f.getField(o, key, v:cdata()) then
72 |     return nil
73 |   end
74 |   local t = v:get_type()
75 |   if t == TYPE_CODE.vector or t == TYPE_CODE.table then
76 |     v = v:get()
77 |   end
78 |   return v
79 | end
80 |
81 | function VariableTable:setField(key, value)
82 |   local o = rawget(self, 'o')
83 |   if not torch.isTypeOf(value, std.Variable) then
84 |     value = std.Variable(value)
85 |   end
86 |   f.setField(o, key, value:cdata())
87 | end
88 |
89 | function VariableTable:__index(key)  -- lookup order: raw field, class method, then stored entry (string keys only)
90 |   local v = rawget(self, key)
91 |   if not v then
92 |     v = VariableTable[key]
93 |     if not v and type(key) == 'string' then
94 |       v = self:getField(key)
95 |     end
96 |   end
97 |   return v
98 | end
99 |
100 | function VariableTable:__newindex(key, value)
101 |   -- string keys are stored in the native table; any other key becomes a plain Lua field
102 |   if type(key) == 'string' then
103 |     self:setField(key, value)
104 |   else
105 |     rawset(self, key, value)
106 |   end
107 | end
108 |
109 | function VariableTable:exists(key)
110 |   return f.exists(self.o, key)
111 | end
112 |
113 | function VariableTable:keys()
114 |   local v = std.StringVector()
115 |   f.keys(self.o, v:cdata())
116 |   return v
117 | end
118 |
119 | function VariableTable:values()
120 |   local v = std.VariableVector()
121 |   f.values(self.o, v:cdata())
122 |   return v
123 | end
124 |
125 | function VariableTable:totable()  -- recursive conversion into plain Lua tables
126 |   local k,v = self:keys(),self:values()
127 |   local r = {}
128 |   for i=1,#k do
129 |     local x = v[i]:get()
130 |     if torch.isTypeOf(x, std.VariableTable) then
131 |       x = x:totable()
132 |     elseif torch.isTypeOf(x, std.VariableVector) then
133 |       x = x:totable()
134 |     end
135 |     r[k[i]] = x
136 |   end
137 |   return r
138 | end
139 |
140 | function VariableTable:__tostring()
141 |   local t = {}
142 |   table.insert(t, '{')
143 |   local k,v = self:keys(),self:values()
144 |   for i=1,#k do
145 |     table.insert(t, ' ' .. k[i] .. ' : ' .. tostring(v[i]))
146 |   end
147 |   table.insert(t, '}')
148 |   table.insert(t, string.format('[%s]', torch.type(self)))
149 |   return table.concat(t, '\n')
150 | end
151 |
--------------------------------------------------------------------------------
/lua/master.lua:
--------------------------------------------------------------------------------
1 | --- Collection of functions to query information about the ROS master
2 | -- @module master
3 | local ffi = require 'ffi'
4 | local torch = require 'torch'
5 | local ros = require 'ros.env'
6 | local utils = require 'ros.utils'
7 | local std = ros.std
8 |
9 | local master = {}
10 | ros.master = master
11 |
12 | local function init()
13 |   local names = {
14 |     'execute',
15 |     'getHost',
16 |     'getPort',
17 |     'getURI',
18 |     'check',
19 |     'getTopics',
20 |     'getNodes',
21 |     'setRetryTimeout'
22 |   }
23 |
24 |   return utils.create_method_table("ros_Master_", names)
25 | end
26 |
27 | local f = init()
28 |
29 | --- Execute an XMLRPC call on the master.
30 | -- @tparam string method The RPC method to invoke
31 | -- @param request The arguments to the RPC call (a std.Variable, or any value accepted by std.Variable())
32 | -- @tparam[opt=false] bool wait_for_master Whether or not this call should loop until it can contact the master
33 | -- @treturn bool result true if the master is available, false otherwise.
34 | -- @return response The response that was received (unwrapped via std.Variable:get()).
35 | -- @return payload The payload that was received (unwrapped via std.Variable:get()).
36 | function master.execute(method, request, wait_for_master)
37 |   if not torch.isTypeOf(request, std.Variable) then
38 |     request = std.Variable(request)
39 |   end
40 |   local response, payload = std.Variable(), std.Variable()
41 |   local result = f.execute(method, request:cdata(), response:cdata(), payload:cdata(), wait_for_master or false)
42 |   return result, response:get(), payload:get()
43 | end
44 |
45 | --- Get the hostname where the master runs.
46 | -- @treturn string The master's hostname
47 | function master.getHost()
48 |   return ffi.string(f.getHost())
49 | end
50 |
51 | --- Get the port where the master runs.
52 | -- @treturn number The master's port (returned as-is from the binding, unlike getHost/getURI, which convert via ffi.string)
53 | function master.getPort()
54 |   return f.getPort()
55 | end
56 |
57 | --- Get the full URI to the master (e.g. http://host:port/).
58 | -- @treturn string The URI of the master
59 | function master.getURI()
60 |   return ffi.string(f.getURI())
61 | end
62 |
63 | --- Check if the master is running.
64 | -- This method tries to contact the master. You can call it any time after ros::init has been called.
65 | -- The intended usage is to check whether the master is up before trying to make other
66 | -- requests (subscriptions, advertisements, etc.).
67 | -- @treturn bool true if the master is available, false otherwise.
68 | function master.check()
69 |   return f.check()
70 | end
71 |
72 | --- Get the list of topics that are being published by all nodes.
73 | -- @tparam[opt] std.VariableTable output Will be filled with the topic information. If not given, a new table is created by this function.
74 | -- @treturn table Plain Lua table converted (via totable()) from the filled std.VariableTable
75 | function master.getTopics(output)
76 |   local v = output or std.VariableTable()
77 |   f.getTopics(v:cdata())
78 |   return v:totable()
79 | end
80 |
81 | --- Retrieves the currently-known list of nodes from the master.
82 | -- @tparam[opt] std.StringVector output Will be filled with the list of nodes. If not given, a new vector is created by this function.
83 | -- @treturn std.StringVector List of nodes (the same object as output when output was given)
84 | function master.getNodes(output)
85 |   local v = output or std.StringVector()
86 |   f.getNodes(v:cdata())
87 |   return v
88 | end
89 |
90 | --- Set the max time this node should spend looping trying to connect to the master.
91 | -- @tparam ?number|ros.Duration _1 If number: time duration in secons, fractional number possible 92 | -- @tparam[opt] number _2 If present, _1 represends the seconds and _2 represends the nanoseconds 93 | function master.setRetryTimeout(_1, _2) 94 | local d = _1 95 | if not torch.isTypeOf(d, ros.Duration) then 96 | d = ros.Duration(_1, _2) 97 | end 98 | f.setRetryTimeout(d:get_sec(), d:get_nsec()) 99 | end 100 | -------------------------------------------------------------------------------- /lua/tf/StampedTransform.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | local tf = ros.tf 6 | 7 | local moveit_msgs_StampedTransform = ros.get_msgspec('geometry_msgs/TransformStamped') 8 | local moveit_msgs_StampedPose = ros.get_msgspec('geometry_msgs/PoseStamped') 9 | 10 | local StampedTransform, parent = torch.class('tf.StampedTransform', 'tf.Transform', tf) 11 | 12 | function init() 13 | local StampedTransform_method_names = { 14 | "new", 15 | "clone", 16 | "delete", 17 | "getBasePointer", 18 | "get_stamp", 19 | "set_stamp", 20 | "get_frame_id", 21 | "set_frame_id", 22 | "get_child_frame_id", 23 | "set_child_frame_id", 24 | "setData", 25 | "eq", 26 | "toStampedTransformMsg", 27 | "toStampedPoseMsg" 28 | } 29 | 30 | return utils.create_method_table("tf_StampedTransform_", StampedTransform_method_names) 31 | end 32 | 33 | local f = init() 34 | 35 | function StampedTransform:__init(transform, stamp, frame_id, child_frame_id) 36 | transform = transform or tf.Transform() 37 | if torch.isTensor(transform) then 38 | if transform:nDimension() == 2 and transform:size(1) == 4 and transform:size(2) == 4 then 39 | transform = tf.Transform():fromTensor(transform) 40 | else 41 | error('Invalid tensor specified. 
4x4 matrix expected.') 42 | end 43 | end 44 | stamp = stamp or ros.Time.getNow() 45 | self.t = f.new(transform:cdata(), stamp:cdata(), frame_id or '', child_frame_id or '') 46 | self.o = f.getBasePointer(self.t) 47 | end 48 | 49 | function StampedTransform:cdata() 50 | return self.t 51 | end 52 | 53 | function StampedTransform:clone() 54 | local c = torch.factory('tf.StampedTransform')() 55 | local _t = f.clone(self.t) 56 | rawset(c, 't', _t) 57 | rawset(c, 'o', f.getBasePointer(_t)) 58 | return c 59 | end 60 | 61 | function StampedTransform:toTransform() 62 | return tf.Transform.fromStamped(self) 63 | end 64 | 65 | function StampedTransform:get_stamp(result) 66 | result = result or ros.Time() 67 | f.get_stamp(self.t, result:cdata()) 68 | return result 69 | end 70 | 71 | function StampedTransform:set_stamp(stamp) 72 | f.set_stamp(self.t, stamp:cdata()) 73 | end 74 | 75 | function StampedTransform:get_frame_id() 76 | return ffi.string(f.get_frame_id(self.t)) 77 | end 78 | 79 | function StampedTransform:set_frame_id(frame_id) 80 | f.set_frame_id(self.t, frame_id) 81 | end 82 | 83 | function StampedTransform:get_child_frame_id() 84 | return ffi.string(f.get_child_frame_id(self.t)) 85 | end 86 | 87 | function StampedTransform:set_child_frame_id(child_frame_id) 88 | f.set_child_frame_id(self.t, child_frame_id) 89 | end 90 | 91 | function StampedTransform:setData(transform) 92 | f.setData(self.t, transform:cdata()) 93 | end 94 | 95 | function StampedTransform:toStampedTransformMsg(output) 96 | local msg_bytes = torch.ByteStorage() 97 | f.toStampedTransformMsg(self.t, msg_bytes:cdata()) 98 | local msg = output or ros.Message(moveit_msgs_StampedTransform, true) 99 | msg:deserialize(msg_bytes) 100 | return msg 101 | end 102 | 103 | function StampedTransform:toStampedPoseMsg(output) 104 | local msg_bytes = torch.ByteStorage() 105 | f.toStampedPoseMsg(self.t, msg_bytes:cdata()) 106 | local msg = output or ros.Message(moveit_msgs_StampedPose, true) 107 | 
msg:deserialize(msg_bytes) 108 | return msg 109 | end 110 | 111 | function StampedTransform:__eq(other) 112 | f.eq(self.t, other:cdata()) 113 | end 114 | 115 | function StampedTransform:__tostring() 116 | local s = string.format('{\n stamp: %s\n frame_id: \'%s\'\n child_frame_id: \'%s\'\n transform:\n%s\n}', 117 | tostring(self:get_stamp()), 118 | self:get_frame_id(), 119 | self:get_child_frame_id(), 120 | parent.__tostring(self) 121 | ) 122 | return s 123 | end 124 | -------------------------------------------------------------------------------- /lua/ServiceClient.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | local std = ros.std 6 | 7 | local ServiceClient = torch.class('ros.ServiceClient', ros) 8 | local ServiceClient_ptr_ct = ffi.typeof('ros_ServiceClient *') 9 | 10 | function init() 11 | local ServiceClient_method_names = { 12 | "new", 13 | "clone", 14 | "delete", 15 | "call", 16 | "isPersistent", 17 | "getService", 18 | "waitForExistence", 19 | "exists", 20 | "shutdown", 21 | "isValid" 22 | } 23 | 24 | return utils.create_method_table("ros_ServiceClient_", ServiceClient_method_names) 25 | end 26 | 27 | local f = init() 28 | 29 | function ServiceClient:__init(service_name, service_spec, persistent, header_values, serialization_handlers) 30 | self.serialization_handlers = serialization_handlers 31 | if ffi.istype(ServiceClient_ptr_ct, service_name) then 32 | self.o = service_name 33 | self.spec = service_spec 34 | ffi.gc(self.o, f.delete) 35 | else 36 | if type(service_spec) == 'string' then 37 | service_spec = ros.SrvSpec(service_spec) 38 | end 39 | if not torch.isTypeOf(service_spec, ros.SrvSpec) then 40 | error("ServiceClient:ctor(): invalid 'service_spec' argument.") 41 | end 42 | self.spec = service_spec 43 | self.o = f.new(service_name, persistent or false, 
utils.cdata(header_values), self.spec:md5()) 44 | end 45 | end 46 | 47 | function ServiceClient:clone() 48 | local c = torch.factory('ros.ServiceClient')() 49 | rawset(c, 'o', f.clone(self.o)) 50 | rawset(c, 'spec', self.spec) 51 | return c 52 | end 53 | 54 | function ServiceClient:createRequest() 55 | return ros.Message(self.spec.request_spec) 56 | end 57 | 58 | function ServiceClient:call(request_msg) 59 | if not self:isValid() then 60 | error('ros.ServiceClient instance is not valid.') 61 | end 62 | 63 | local sw = ros.StorageWriter(nil, 0, self.serialization_handlers) 64 | if not torch.isTypeOf(request_msg, ros.Message) then 65 | -- support filling request message from simple value or table 66 | if type(request_msg) ~= 'table' then 67 | request_msg = { request_msg } 68 | end 69 | local req = self:createRequest() 70 | req:fillFromTable(request_msg) 71 | request_msg = req 72 | end 73 | 74 | request_msg:serialize(sw) 75 | sw:shrinkToFit() 76 | 77 | local response_serialized_msg = ros.SerializedMessage() 78 | local result = f.call(self.o, sw.storage:cdata(), response_serialized_msg:cdata(), self.spec:md5()) 79 | local response_msg 80 | if result == true then 81 | local view = response_serialized_msg:view() 82 | local storage = view:storage() or torch.ByteStorage() -- storage may be nil if response is empty messages 83 | local sr = ros.StorageReader(storage, view:storageOffset()-1, nil, nil, self.serialization_handlers) 84 | response_msg = ros.Message(self.spec.response_spec, true) 85 | response_msg:deserialize(sr, true) -- true singals not that no total length was prepended to message 86 | end 87 | 88 | return response_msg 89 | end 90 | 91 | function ServiceClient:isPersistent() 92 | return f.isPersistent(self.o) 93 | end 94 | 95 | function ServiceClient:getService() 96 | local s = std.String() 97 | f.getService(self.o, s:cdata()) 98 | return s:get() 99 | end 100 | 101 | function ServiceClient:waitForExistence(timeout) 102 | if timeout and not 
torch.isTypeOf(timeout, ros.Duration) then 103 | timeout = ros.Duration(timeout) 104 | end 105 | return f.waitForExistence(self.o, utils.cdata(timeout)) 106 | end 107 | 108 | function ServiceClient:exists() 109 | return f.exists(self.o) 110 | end 111 | 112 | function ServiceClient:shutdown() 113 | f.shutdown(self.o) 114 | end 115 | 116 | function ServiceClient:isValid() 117 | return f.isValid(self.o) 118 | end 119 | -------------------------------------------------------------------------------- /src/ros/torch-ros.h: -------------------------------------------------------------------------------- 1 | #ifndef torch_ros_h 2 | #define torch_ros_h 3 | 4 | extern "C" { 5 | #include 6 | } 7 | 8 | #include 9 | #include 10 | #include "../std/variable.h" 11 | 12 | #define ROSIMP(return_type, class_name, name) extern "C" return_type TH_CONCAT_4(ros_, class_name, _, name) 13 | 14 | class RosWrapperException 15 | : public std::runtime_error { 16 | public: 17 | RosWrapperException(const std::string& reason) 18 | : runtime_error(reason) { 19 | } 20 | }; 21 | 22 | 23 | inline void VariableToXmlRpcValue(const xamla::Variable &src, XmlRpc::XmlRpcValue &dst) { 24 | switch (src.get_type()) { 25 | case xamla::VariableType::Void: dst.clear(); break; 26 | case xamla::VariableType::Bool: dst = XmlRpc::XmlRpcValue(src.get_bool()); break; 27 | case xamla::VariableType::Int8: dst = XmlRpc::XmlRpcValue(static_cast(src.get_int8())); break; 28 | case xamla::VariableType::Int16: dst = XmlRpc::XmlRpcValue(static_cast(src.get_int16())); break; 29 | case xamla::VariableType::Int32: dst = XmlRpc::XmlRpcValue(static_cast(src.get_int32())); break; 30 | case xamla::VariableType::Int64: dst = XmlRpc::XmlRpcValue(static_cast(src.get_int64())); break; 31 | case xamla::VariableType::UInt8: dst = XmlRpc::XmlRpcValue(static_cast(src.get_uint8())); break; 32 | case xamla::VariableType::UInt16: dst = XmlRpc::XmlRpcValue(static_cast(src.get_uint16())); break; 33 | case xamla::VariableType::UInt32: dst = 
XmlRpc::XmlRpcValue(static_cast(src.get_uint32())); break; 34 | case xamla::VariableType::UInt64: dst = XmlRpc::XmlRpcValue(static_cast(src.get_uint64())); break; 35 | case xamla::VariableType::Float32: dst = XmlRpc::XmlRpcValue(src.get_float32()); break; 36 | case xamla::VariableType::Float64: dst = XmlRpc::XmlRpcValue(src.get_float64()); break; 37 | case xamla::VariableType::String: dst = XmlRpc::XmlRpcValue(src.get_string()); break; 38 | case xamla::VariableType::Vector: { 39 | const xamla::VariableVector &v = *src.get_vector(); 40 | dst.setSize(v.size()); 41 | for (size_t i=0; i < v.size(); ++i) { 42 | VariableToXmlRpcValue(v[i], dst[i]); 43 | } 44 | } break; 45 | case xamla::VariableType::Table: { 46 | const xamla::VariableTable &t = *src.get_table(); 47 | for (xamla::VariableTable::const_iterator i=t.begin(); i != t.end(); ++i) { 48 | const std::string& name = i->first; 49 | const xamla::Variable& value = i->second; 50 | VariableToXmlRpcValue(value, dst[name]); 51 | } 52 | } break; 53 | } 54 | } 55 | 56 | inline void XmlRpcValueToVariable(XmlRpc::XmlRpcValue &src, xamla::Variable &dst) { 57 | switch (src.getType()) { 58 | case XmlRpc::XmlRpcValue::TypeInvalid: dst.clear(); break; 59 | case XmlRpc::XmlRpcValue::TypeBoolean: dst.set_bool(src); break; 60 | case XmlRpc::XmlRpcValue::TypeInt: dst.set_int32(src); break; 61 | case XmlRpc::XmlRpcValue::TypeDouble: dst.set_float64(src); break; 62 | case XmlRpc::XmlRpcValue::TypeString: dst.set_string(src); break; 63 | case XmlRpc::XmlRpcValue::TypeArray: { 64 | xamla::VariableVector_ptr v(new xamla::VariableVector()); 65 | for (int i = 0; i < src.size(); ++i) { 66 | xamla::Variable x; 67 | XmlRpcValueToVariable(src[i], x); 68 | v->push_back(x); 69 | } 70 | dst.set_vector(v); 71 | } break; 72 | case XmlRpc::XmlRpcValue::TypeStruct: { 73 | xamla::VariableTable_ptr t(new xamla::VariableTable()); 74 | 75 | for (XmlRpc::XmlRpcValue::iterator i = src.begin(); i != src.end(); ++i) { 76 | const std::string &name = i->first; 
77 | XmlRpc::XmlRpcValue &value = i->second; 78 | xamla::Variable v; 79 | XmlRpcValueToVariable(value, v); 80 | (*t)[name] = v; 81 | } 82 | dst.set_table(t); 83 | } break; 84 | case XmlRpc::XmlRpcValue::TypeDateTime: 85 | case XmlRpc::XmlRpcValue::TypeBase64: 86 | break; 87 | } 88 | } 89 | #endif // torch_ros_h 90 | -------------------------------------------------------------------------------- /src/ros/console.cpp: -------------------------------------------------------------------------------- 1 | #include "torch-ros.h" 2 | #include 3 | #include 4 | 5 | namespace xamla { 6 | 7 | typedef boost::unordered_map > > NamedLoggerMap; 8 | NamedLoggerMap global_named_locs; 9 | 10 | ros::console::LogLocation *getNamedLoc(const std::string& name, int level) { 11 | if (level < 0 || level >= ::ros::console::levels::Count) 12 | throw std::out_of_range("level"); 13 | 14 | NamedLoggerMap::iterator i = global_named_locs.find(name); 15 | if (i == global_named_locs.end()) { 16 | static ros::console::LogLocation __empty = { false, false, ::ros::console::levels::Count, 0 }; 17 | 18 | boost::shared_ptr > levelLoggers(new std::vector(::ros::console::levels::Count, __empty)); 19 | for (int j = 0; j < ::ros::console::levels::Count; ++j) { 20 | ros::console::initializeLogLocation(&levelLoggers->operator[](j), name, (ros::console::Level)j); 21 | } 22 | 23 | i = global_named_locs.insert(std::make_pair(name, levelLoggers)).first; 24 | } 25 | return &i->second->operator[](level); 26 | } 27 | 28 | inline std::string getLoggerName(const char *name, bool no_default_prefix) { 29 | if (!name) 30 | return ROSCONSOLE_DEFAULT_NAME; 31 | else if (no_default_prefix) 32 | return name; 33 | else 34 | return std::string(ROSCONSOLE_NAME_PREFIX) + "." 
+ name; 35 | } 36 | 37 | } // namespace xamla 38 | 39 | 40 | ROSIMP(const char *, Console, initialize)() { 41 | ros::console::initialize(); 42 | xamla::getNamedLoc(ROSCONSOLE_DEFAULT_NAME, (int)ros::console::levels::Info); 43 | return ROSCONSOLE_NAME_PREFIX; 44 | } 45 | 46 | ROSIMP(void, Console, shutdown)() { 47 | ros::console::shutdown(); 48 | xamla::global_named_locs.clear(); 49 | } 50 | 51 | ROSIMP(void, Console, set_logger_level)(const char *name, int level, bool no_default_prefix) { 52 | const std::string& name_ = xamla::getLoggerName(name, no_default_prefix); 53 | ros::console::setLogLocationLevel(xamla::getNamedLoc(name_, level), (ros::console::levels::Level)level); 54 | if (ros::console::set_logger_level(name_, (ros::console::levels::Level)level)) 55 | ros::console::notifyLoggerLevelsChanged(); 56 | } 57 | 58 | ROSIMP(bool, Console, get_loggers)(std::vector *names, THShortTensor *levels) { 59 | std::map loggers; 60 | if (!ros::console::get_loggers(loggers)) 61 | return false; 62 | 63 | THShortTensor *levels_ = THShortTensor_newContiguous(levels); 64 | THShortTensor_resize1d(levels_, loggers.size()); 65 | names->clear(); 66 | names->reserve(loggers.size()); 67 | 68 | std::map::const_iterator i = loggers.begin(); 69 | short *levels_data = THShortTensor_data(levels_); 70 | for (; i != loggers.end(); ++i, ++levels_data) { 71 | names->push_back(i->first); 72 | *levels_data = i->second; 73 | } 74 | THShortTensor_freeCopyTo(levels_, levels); 75 | return true; 76 | } 77 | 78 | ROSIMP(bool, Console, check_loglevel)(const char *name, int level, bool no_default_prefix) { 79 | const std::string& name_ = xamla::getLoggerName(name, no_default_prefix); 80 | ros::console::LogLocation *loc = xamla::getNamedLoc(name_, level); 81 | return loc != NULL && loc->logger_enabled_; 82 | } 83 | 84 | ROSIMP(void*, Console, get_logger)(const char *name, bool no_default_prefix) { 85 | const std::string& name_ = xamla::getLoggerName(name, no_default_prefix); 86 | return 
xamla::getNamedLoc(name_, ros::console::levels::Info)->logger_; 87 | } 88 | 89 | ROSIMP(void, Console, print)(void *logger, int level, const char *text, const char *file, const char *function_name, int line) { 90 | if (!logger) { 91 | logger = xamla::getNamedLoc(ROSCONSOLE_DEFAULT_NAME, level)->logger_; 92 | } 93 | 94 | ros::console::print( 95 | NULL, 96 | logger, 97 | (ros::console::levels::Level)level, 98 | file, 99 | line, 100 | function_name, 101 | "%s", 102 | text 103 | ); 104 | } 105 | -------------------------------------------------------------------------------- /lua/std/VariableVector.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | std = ros.std 6 | 7 | local VariableVector = torch.class('std.VariableVector', std) 8 | 9 | function init() 10 | local VariableVector_method_names = { 11 | "new", 12 | "delete", 13 | "clone", 14 | "size", 15 | "getAt", 16 | "setAt", 17 | "push_back", 18 | "pop_back", 19 | "clear", 20 | "insert", 21 | "erase", 22 | "empty" 23 | } 24 | 25 | return utils.create_method_table("std_VariableVector_", VariableVector_method_names) 26 | end 27 | 28 | local f = init() 29 | 30 | function VariableVector:__init(...) 31 | rawset(self, 'o', f.new()) 32 | if select("#", ...) > 0 then 33 | local x = ... 34 | if type(x) ~= 'table' then 35 | x = { ... 
} 36 | end 37 | self:insertFromTable(1, x) 38 | end 39 | end 40 | 41 | function VariableVector:cdata() 42 | return self.o 43 | end 44 | 45 | function VariableVector:clone() 46 | local c = torch.factory('std.VariableVector')() 47 | rawset(c, 'o', f.clone(self.o)) 48 | return c 49 | end 50 | 51 | function VariableVector:size() 52 | return f.size(self.o) 53 | end 54 | 55 | function VariableVector:__len() 56 | return self:size() 57 | end 58 | 59 | function VariableVector:__index(idx) 60 | local v = rawget(self, idx) 61 | if not v then 62 | v = VariableVector[idx] 63 | if not v and type(idx) == 'number' then 64 | local o = rawget(self, 'o') 65 | v = std.Variable() 66 | f.getAt(o, idx-1, v:cdata()) 67 | end 68 | end 69 | return v 70 | end 71 | 72 | function VariableVector:__newindex(idx, v) 73 | local o = rawget(self, 'o') 74 | if type(idx) == 'number' then 75 | if not torch.isTypeOf(v, std.Variable) then 76 | v = std.Variable(v) 77 | end 78 | f.setAt(o, idx-1, v:cdata()) -- the native setter takes the Variable's cdata handle, exactly as push_back/insert pass it 79 | else 80 | rawset(self, idx, v) 81 | end 82 | end 83 | 84 | function VariableVector:push_back(value) 85 | if not torch.isTypeOf(value, std.Variable) then 86 | value = std.Variable(value) 87 | end 88 | f.push_back(self.o, value:cdata()) 89 | end 90 | 91 | function VariableVector:pop_back() -- removes the last element and returns it; returns nil when the vector is empty 92 | if #self == 0 then return nil end -- NOTE(review): f.pop_back presumably wraps std::vector::pop_back, which is undefined on an empty vector 93 | local last = self[#self]; f.pop_back(self.o) 94 | return last 95 | end 96 | 97 | function VariableVector:clear() 98 | f.clear(self.o) 99 | end 100 | 101 | function VariableVector:insert(pos, value, n) 102 | if not torch.isTypeOf(value, std.Variable) then 103 | value = std.Variable(value) 104 | end 105 | if pos < 1 then 106 | pos = 1 107 | elseif pos > #self+1 then 108 | pos = #self + 1 109 | end 110 | f.insert(self.o, pos-1, n or 1, 
value:cdata()) 111 | end 112 | 113 | function VariableVector:insertFromTable(pos, t) 114 | if type(pos) == 'table' then 115 | t = pos 116 | pos = #self + 1 117 | end 118 | pos = pos or #self + 1 119 | for _,v in pairs(t) do -- NOTE(review): the Lua manual leaves pairs() visiting order unspecified; ipairs() would guarantee sequence order 120 | self:insert(pos, v) 121 | pos = pos + 1 122 | end 123 |
end 124 | 125 | function VariableVector:erase(begin_pos, end_pos) 126 | f.erase(self.o, begin_pos-1, (end_pos or begin_pos + 1)-1) 127 | end 128 | 129 | function VariableVector:__pairs() 130 | return function (t, k) 131 | local i = k + 1 132 | if i > #t then 133 | return nil 134 | else 135 | local v = t[i] 136 | return i, v 137 | end 138 | end, self, 0 139 | end 140 | 141 | function VariableVector:__ipairs() 142 | return self:__pairs() 143 | end 144 | 145 | function VariableVector:totable() 146 | local t = {} 147 | for i,v in ipairs(self) do 148 | v = v:get() 149 | if torch.isTypeOf(v, std.VariableTable) then 150 | v = v:totable() 151 | elseif torch.isTypeOf(v, std.VariableVector) then 152 | v = v:totable() 153 | end 154 | table.insert(t, v) 155 | end 156 | return t 157 | end 158 | 159 | function VariableVector:__tostring() 160 | local t = {} 161 | for i,v in ipairs(self) do 162 | table.insert(t, tostring(v)) 163 | end 164 | table.insert(t, string.format('[%s of size %d]', torch.type(self), self:size())) 165 | return table.concat(t, '\n') 166 | end 167 | -------------------------------------------------------------------------------- /src/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _utils_h 2 | #define _utils_h 3 | 4 | /* 5 | #include 6 | #include 7 | #include 8 | 9 | inline Eigen::Vector4d Tensor2Vec4d(THDoubleTensor *tensor) 10 | { 11 | if (!tensor || THDoubleTensor_nElement(tensor) < 4) 12 | throw MoveItWrapperException("A tensor with at least 4 elements was expected."); 13 | 14 | THDoubleTensor *tensor_ = THDoubleTensor_newContiguous(tensor); 15 | double* data = THDoubleTensor_data(tensor_); 16 | Eigen::Vector4d v(data[0], data[1], data[2], data[3]); 17 | THDoubleTensor_free(tensor_); 18 | return v; 19 | } 20 | 21 | inline Eigen::Vector3d Tensor2Vec3d(THDoubleTensor *tensor) 22 | { 23 | if (!tensor || THDoubleTensor_nElement(tensor) < 3) 24 | throw MoveItWrapperException("A Tensor with at least 3 
elements was expected."); 25 | 26 | THDoubleTensor *tensor_ = THDoubleTensor_newContiguous(tensor); 27 | double* data = THDoubleTensor_data(tensor_); 28 | Eigen::Vector3d v(data[0], data[1], data[2]); 29 | THDoubleTensor_free(tensor_); 30 | return v; 31 | } 32 | 33 | template 34 | inline Eigen::Matrix Tensor2Mat(THDoubleTensor *tensor) 35 | { 36 | THArgCheck(tensor != NULL && tensor->nDimension == 2 && tensor->size[0] == rows && tensor->size[1] == cols, 1, "invalid tensor"); 37 | tensor = THDoubleTensor_newContiguous(tensor); 38 | Eigen::Matrix output(Eigen::Map >(THDoubleTensor_data(tensor))); 39 | THDoubleTensor_free(tensor); 40 | return output; 41 | } 42 | 43 | template void viewMatrix(Eigen::Matrix &m, THDoubleTensor *output) 44 | { 45 | // create new storage that views into the matrix 46 | THDoubleStorage* storage = NULL; 47 | if ((Eigen::Matrix::Options & Eigen::RowMajor) == Eigen::RowMajor) 48 | storage = THDoubleStorage_newWithData(m.data(), (m.rows() * m.rowStride())); 49 | else 50 | storage = THDoubleStorage_newWithData(m.data(), (m.cols() * m.colStride())); 51 | 52 | storage->flag = TH_STORAGE_REFCOUNTED; 53 | THDoubleTensor_setStorage2d(output, storage, 0, rows, m.rowStride(), cols, m.colStride()); 54 | THDoubleStorage_free(storage); // tensor took ownership 55 | } 56 | 57 | inline void viewArray(double* array, size_t length, THDoubleTensor *output) 58 | { 59 | THDoubleStorage* storage = THDoubleStorage_newWithData(array, length); 60 | storage->flag = TH_STORAGE_REFCOUNTED; 61 | THDoubleTensor_setStorage1d(output, storage, 0, length, 1); 62 | THDoubleStorage_free(storage); // tensor took ownership 63 | } 64 | 65 | template void copyMatrix(const Eigen::Matrix &m, THDoubleTensor *output) 66 | { 67 | THDoubleTensor_resize2d(output, m.rows(), m.cols()); 68 | THDoubleTensor* output_ = THDoubleTensor_newContiguous(output); 69 | // there are strange static-asserts in Eigen to disallow specifying RowMajor for vectors... 
70 | Eigen::Map >(THDoubleTensor_data(output_)) = m; 71 | THDoubleTensor_freeCopyTo(output_, output); 72 | }*/ 73 | 74 | #define DECL_vector2Tensor(T, name) inline void vector2Tensor(const std::vector &v, TH##name##Tensor *output) \ 75 | { \ 76 | TH##name##Tensor_resize1d(output, v.size()); \ 77 | TH##name##Tensor *output_ = TH##name##Tensor_newContiguous(output); \ 78 | std::copy(v.begin(), v.end(), TH##name##Tensor_data(output_)); \ 79 | TH##name##Tensor_freeCopyTo(output_, output); \ 80 | } 81 | 82 | DECL_vector2Tensor(double, Double) 83 | DECL_vector2Tensor(float, Float) 84 | DECL_vector2Tensor(int, Int) 85 | 86 | #define DECL_Tensor2vector(T, name) inline void Tensor2vector(TH##name##Tensor *input, std::vector &v) \ 87 | { \ 88 | long n = TH##name##Tensor_nElement(input); \ 89 | v.resize(n); \ 90 | input = TH##name##Tensor_newContiguous(input); \ 91 | T *begin = TH##name##Tensor_data(input); \ 92 | std::copy(begin, begin + n, v.begin()); \ 93 | TH##name##Tensor_free(input); \ 94 | } 95 | 96 | DECL_Tensor2vector(double, Double) 97 | DECL_Tensor2vector(float, Float) 98 | DECL_Tensor2vector(int, Int) 99 | 100 | #endif //_utils_h 101 | -------------------------------------------------------------------------------- /lua/Time.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | 6 | local Time = torch.class('ros.Time', ros) 7 | 8 | function init() 9 | local Time_method_names = { 10 | "new", 11 | "clone", 12 | "delete", 13 | "isZero", 14 | "fromSec", 15 | "toSec", 16 | "set", 17 | "assign", 18 | "get_sec", 19 | "set_sec", 20 | "get_nsec", 21 | "set_nsec", 22 | "lt", 23 | "eq", 24 | "add_Duration", 25 | "sub", 26 | "sub_Duration", 27 | "sleepUntil", 28 | "getNow", 29 | "setNow", 30 | "waitForValid", 31 | "init", 32 | "shutdown", 33 | "useSystemTime", 34 | "isSimTime", 35 | "isSystemTime", 36 | 
"isValid" 37 | } 38 | 39 | return utils.create_method_table("ros_Time_", Time_method_names) 40 | end 41 | 42 | local f = init() 43 | 44 | function Time:__init(_1, _2) 45 | self.o = f.new() 46 | if _1 or _2 then 47 | self:set(_1, _2) 48 | end 49 | end 50 | 51 | function Time:cdata() 52 | return self.o 53 | end 54 | 55 | function Time:clone() 56 | local c = torch.factory('ros.Time')() 57 | rawset(c, 'o', f.clone(self.o)) 58 | return c 59 | end 60 | 61 | function Time:isZero() 62 | return f.isZero(self.o) 63 | end 64 | 65 | function Time:fromSec(t) 66 | f.fromSec(self.o, t) 67 | end 68 | 69 | function Time:toSec() 70 | return f.toSec(self.o) 71 | end 72 | 73 | function Time:set(_1, _2) 74 | if torch.isTypeOf(_1, ros.Time) then 75 | self:assign(_1) 76 | elseif not _2 then 77 | self:fromSec(_1) 78 | else 79 | f.set(self.o, _1, _2) 80 | end 81 | end 82 | 83 | function Time:get() 84 | return self:get_sec(), self:get_nsec() 85 | end 86 | 87 | function Time:assign(other) 88 | f.assign(self.o, other:cdata()) 89 | end 90 | 91 | function Time:get_sec() 92 | return f.get_sec(self.o) 93 | end 94 | 95 | function Time:set_sec(sec) 96 | f.set_sec(self.o, sec) 97 | end 98 | 99 | function Time:get_nsec() 100 | return f.get_nsec(self.o) 101 | end 102 | 103 | function Time:set_nsec(nsec) 104 | f.set_nsec(self.o, nsec) 105 | end 106 | 107 | function Time:__lt(other) 108 | return f.lt(self.o, other:cdata()) 109 | end 110 | 111 | function Time:__eq(other) 112 | return f.eq(self.o, other:cdata()) 113 | end 114 | 115 | function Time:__le(other) 116 | return f.lt(self.o, other:cdata()) or f.eq(self.o, other:cdata()) 117 | end 118 | 119 | function Time:add(d, result) 120 | result = result or self 121 | if type(d) == 'number' then 122 | d = ros.Duration(d) 123 | end 124 | f.add_Duration(self.o, d:cdata(), result:cdata()) 125 | return result 126 | end 127 | 128 | function Time:sub(x, result) 129 | if type(x) == 'number' then 130 | x = ros.Duration(x) 131 | end 132 | if torch.isTypeOf(x, 
ros.Time) then 133 | result = result or ros.Duration() 134 | f.sub(self.o, x:cdata(), result:cdata()) 135 | elseif torch.isTypeOf(x, ros.Duration) then 136 | result = result or self 137 | f.sub_Duration(self.o, x:cdata(), result:cdata()) 138 | else 139 | error('cannot sub from ros.Time with specified argument type') 140 | end 141 | return result 142 | end 143 | 144 | function Time:__add(d) 145 | local result = ros.Time() 146 | return self:add(d, result) 147 | end 148 | 149 | function Time:__sub(x) -- Time - Time yields a new ros.Duration; Time - Duration (or seconds given as a number) yields a new ros.Time; self is left unchanged 150 | if type(x) == 'number' then x = ros.Duration(x) end -- accept plain seconds, as sub() and __add() do 151 | if torch.isTypeOf(x, ros.Time) then 152 | return self:sub(x, ros.Duration()) 153 | elseif torch.isTypeOf(x, ros.Duration) then 154 | return self:sub(x, ros.Time()) -- explicit result object, otherwise sub() would write into self 155 | end 156 | -- unsupported operand type: no result 157 | return nil 158 | end 159 | 160 | 161 | function Time:__tostring() 162 | return string.format("%f", self:toSec()) 163 | end 164 | 165 | -- static functions 166 | 167 | function Time.init() 168 | f.init() 169 | end 170 | 171 | function Time.shutdown() 172 | f.shutdown() 173 | end 174 | 175 | function Time.now() 176 | return Time.getNow() 177 | end 178 | 179 | function Time.getNow(result) 180 | result = result or ros.Time() 181 | f.getNow(result:cdata()) 182 | return result 183 | end 184 | 185 | function Time.setNow(time) 186 | f.setNow(time:cdata()) 187 | end 188 | 189 | function Time.sleepUntil(time) 190 | f.sleepUntil(time:cdata()) 191 | end 192 | 193 | function Time.waitForValid() 194 | f.waitForValid() 195 | end 196 | 197 | function Time.useSystemTime() 198 | return f.useSystemTime() 199 | end 200 | 201 | function Time.isSimTime() 202 | return f.isSimTime() 203 | end 204 | 205 | function Time.isSystemTime() 206 | return f.isSystemTime() -- query the system-time flag (isSimTime() is the separate sim-time query) 207 | end 
208 | 209 | function Time.isValid() 210 | return f.isValid() 211 | end 212 | 213 | Time.init() 214 | -------------------------------------------------------------------------------- /src/tf/transform_listener.cpp:
-------------------------------------------------------------------------------- 1 | #include "torch-tf.h" 2 | #include 3 | 4 | TFIMP(tf::TransformListener *, TransformListener, new)() { 5 | return new tf::TransformListener(); 6 | } 7 | 8 | TFIMP(void, TransformListener, delete)(tf::TransformListener *self) { 9 | delete self; 10 | } 11 | 12 | TFIMP(void, TransformListener, clear)(tf::TransformListener *self) { 13 | self->clear(); 14 | } 15 | 16 | TFIMP(void, TransformListener, getFrameStrings)(tf::TransformListener *self, std::vector *result) { 17 | try 18 | { 19 | self->getFrameStrings(*result); 20 | } 21 | catch (std::runtime_error& e) 22 | { 23 | ROS_ERROR("Exception: [%s]", e.what()); 24 | } 25 | } 26 | 27 | TFIMP(void, TransformListener, lookupTransform)( 28 | tf::TransformListener *self, 29 | const char *target_frame, 30 | const char *source_frame, ros::Time *time, 31 | tf::StampedTransform *result 32 | ) { 33 | try 34 | { 35 | self->lookupTransform(target_frame, source_frame, *time, *result); 36 | } 37 | catch (std::runtime_error& e) 38 | { 39 | ROS_ERROR("Exception: [%s]", e.what()); 40 | } 41 | } 42 | 43 | TFIMP(bool, TransformListener, waitForTransform)( 44 | tf::TransformListener *self, 45 | const char *target_frame, 46 | const char *source_frame, 47 | ros::Time *time, 48 | ros::Duration *timeout, 49 | std::string *error_msg 50 | ) { 51 | return self->waitForTransform(target_frame, source_frame, *time, *timeout, ros::Duration(0.01), error_msg); 52 | } 53 | 54 | TFIMP(bool, TransformListener, canTransform)( 55 | tf::TransformListener *self, 56 | const char *target_frame, 57 | const char *source_frame, 58 | ros::Time *time 59 | ) { 60 | return self->canTransform(target_frame, source_frame, *time, NULL); 61 | } 62 | 63 | TFIMP(void, TransformListener, lookupTransformFull)(tf::TransformListener *self, 64 | const char *target_frame, ros::Time *target_time, 65 | const char *source_frame, ros::Time *source_time, 66 | const char *fixed_frame, 
tf::StampedTransform *result 67 | ) { 68 | try { self->lookupTransform(target_frame, *target_time, source_frame, *source_time, fixed_frame, *result); } catch (std::runtime_error& e) { ROS_ERROR("Exception: [%s]", e.what()); } // handled like the non-Full lookupTransform wrapper: log instead of letting a C++ exception unwind through this extern "C" entry point into the Lua caller 69 | } 70 | 71 | TFIMP(bool, TransformListener, waitForTransformFull)(tf::TransformListener *self, 72 | const char *target_frame, ros::Time *target_time, 73 | const char *source_frame, ros::Time *source_time, 74 | const char *fixed_frame, ros::Duration *timeout, std::string *error_msg 75 | ) { 76 | return self->waitForTransform(target_frame, *target_time, source_frame, *source_time, fixed_frame, *timeout, ros::Duration(0.01), error_msg); 77 | } 78 | 79 | TFIMP(bool, TransformListener, canTransformFull)(tf::TransformListener *self, 80 | const char *target_frame, ros::Time *target_time, 81 | const char *source_frame, ros::Time *source_time, 82 | const char *fixed_frame 83 | ) { 84 | return self->canTransform(target_frame, *target_time, source_frame, *source_time, fixed_frame, NULL); 85 | } 86 | 87 | TFIMP(void, TransformListener, resolve)(tf::TransformListener *self, const char *frame_name, std::string *result) { 88 | *result = self->resolve(frame_name); 89 | } 90 | 91 | TFIMP(int, TransformListener, getLatestCommonTime)( 92 | tf::TransformListener *self, 93 | const char *source_frame, 94 | const char *target_frame, 95 | ros::Time *time, 96 | std::string *error_string 97 | ) { 98 | return self->getLatestCommonTime(source_frame, target_frame, *time, error_string); 99 | } 100 | 101 | TFIMP(void, TransformListener, chainAsVector)(tf::TransformListener *self, 102 | const char *target_frame, ros::Time *target_time, 103 | const char *source_frame, ros::Time *source_time, 104 | const char *fixed_frame, std::vector *result 105 | ) { 106 | try { self->chainAsVector(target_frame, *target_time, source_frame, *source_time, fixed_frame, *result); } catch (std::runtime_error& e) { ROS_ERROR("Exception: [%s]", 
e.what()); } // same handling as the lookupTransform wrappers: do not let tf lookup errors escape the extern "C" boundary 107 | } 108 | 109 | TFIMP(bool, TransformListener, getParent)( 110 | tf::TransformListener *self, 111 | const char* frame_id, 112 | ros::Time *time, 113 | std::string *result 114 | ) { 115 | return
self->getParent(frame_id, *time, *result); 116 | } 117 | 118 | TFIMP(bool, TransformListener, frameExists)(tf::TransformListener *self, const char *frame_id) { 119 | try 120 | { 121 | return self->frameExists(frame_id); 122 | } 123 | catch (std::runtime_error& e) 124 | { 125 | ROS_ERROR("Exception: [%s]", e.what()); 126 | return false; 127 | } 128 | } 129 | 130 | TFIMP(void, TransformListener, getCacheLength)(tf::TransformListener *self, ros::Duration *result) { 131 | *result = self->getCacheLength(); 132 | } 133 | 134 | TFIMP(void, TransformListener, getTFPrefix)(tf::TransformListener *self, std::string *result) { 135 | *result = self->getTFPrefix(); 136 | } 137 | -------------------------------------------------------------------------------- /lua/tf/TransformListener.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 | local ros = require 'ros.env' 4 | local utils = require 'ros.utils' 5 | local std = ros.std 6 | local tf = ros.tf 7 | 8 | local TransformListener = torch.class('tf.TransformListener', tf) 9 | 10 | function init() 11 | local TransformListener_method_names = { 12 | "new", 13 | "delete", 14 | "clear", 15 | "getFrameStrings", 16 | "lookupTransform", 17 | "waitForTransform", 18 | "canTransform", 19 | "lookupTransformFull", 20 | "waitForTransformFull", 21 | "canTransformFull", 22 | "resolve", 23 | "getLatestCommonTime", 24 | "chainAsVector", 25 | "getParent", 26 | "frameExists", 27 | "getCacheLength", 28 | "getTFPrefix" 29 | } 30 | 31 | return utils.create_method_table("tf_TransformListener_", TransformListener_method_names) 32 | end 33 | 34 | local f = init() 35 | 36 | function TransformListener:__init() 37 | self.o = f.new() 38 | end 39 | 40 | function TransformListener:clear() 41 | f.clear(self.o) 42 | end 43 | 44 | function TransformListener:getFrameStrings() 45 | local result = std.StringVector() 46 | f.getFrameStrings(self.o, result:cdata()) 47 | 
return result 48 | end 49 | 50 | function TransformListener:lookupTransform(target_frame, source_frame, time, result) 51 | local result = result or tf.StampedTransform() 52 | f.lookupTransform(self.o, target_frame, source_frame, time:cdata(), result:cdata()) 53 | return result 54 | end 55 | 56 | function TransformListener:waitForTransform(target_frame, source_frame, time, timeout, may_throw) 57 | timeout = timeout or ros.Duration(10) 58 | if may_throw then 59 | local error_msg = std.String() 60 | if not f.waitForTransform(self.o, target_frame, source_frame, time:cdata(), timeout:cdata(), error_msg:cdata()) then 61 | error(error_msg:get()) 62 | end 63 | return true 64 | else 65 | return f.waitForTransform(self.o, target_frame, source_frame, time:cdata(), timeout:cdata(), ffi.NULL) 66 | end 67 | end 68 | 69 | function TransformListener:canTransform(target_frame, source_frame, time) 70 | return f.canTransform(self.o, target_frame, source_frame, time:cdata()) 71 | end 72 | 73 | function TransformListener:lookupTransformFull(target_frame, target_time, source_frame, source_time, fixed_frame, result) 74 | result = result or tf.StampedTransform() 75 | f.lookupTransformFull(self.o, target_frame, target_time:cdata(), source_frame, source_time:cdata(), fixed_frame, result:cdata()) 76 | return result 77 | end 78 | 79 | function TransformListener:waitForTransformFull(target_frame, target_time, source_frame, source_time, fixed_frame, timeout, may_throw) 80 | timeout = timeout or ros.Duration(10) 81 | if may_throw then 82 | local error_msg = std.String() 83 | if not f.waitForTransformFull(self.o, target_frame, target_time:cdata(), source_frame, source_time:cdata(), fixed_frame, timeout:cdata(), error_msg:cdata()) then 84 | error(error_msg:get()) 85 | end 86 | return true 87 | else 88 | return f.waitForTransformFull(self.o, target_frame, target_time:cdata(), source_frame, source_time:cdata(), fixed_frame, timeout:cdata(), ffi.NULL) 89 | end 90 | end 91 | 92 | function 
TransformListener:canTransformFull(target_frame, target_time, source_frame, source_time, fixed_frame) 93 | return f.canTransformFull(self.o, target_frame, target_time:cdata(), source_frame, source_time:cdata(), fixed_frame) 94 | end 95 | 96 | function TransformListener:resolve(frame_name) 97 | local name = std.String() 98 | f.resolve(self.o, frame_name, name:cdata()) 99 | return name:get() 100 | end 101 | 102 | function TransformListener:getLatestCommonTime(source_frame, target_frame) -- returns the latest ros.Time common to both frames; raises an error carrying tf's message on failure 103 | local result = ros.Time() 104 | local error_msg = std.String() 105 | local code = f.getLatestCommonTime(self.o, source_frame, target_frame, result:cdata(), error_msg:cdata()) 106 | if code ~= 0 then error(error_msg:get()) end -- tf returns a TF2Error code where 0 (NO_ERROR) means success 107 | return result 108 | end 109 | 110 | function TransformListener:chainAsVector(target_frame, target_time, source_frame, source_time, fixed_frame, result) 111 | result = result or std.StringVector() -- when no result is supplied, a new std.StringVector instance receives the frame chain 112 | f.chainAsVector(self.o, target_frame, target_time:cdata(), source_frame, source_time:cdata(), fixed_frame, result:cdata()) 113 | return result 114 | end 115 | 116 | function TransformListener:getParent(frame_id, time) 117 | time = time or ros.Time(0) 118 | local parent = std.String() 119 | f.getParent(self.o, frame_id, time:cdata(), parent:cdata()) 120 | return parent:get() 121 | end 122 | 123 | function TransformListener:frameExists(frame_id) 124 | return f.frameExists(self.o, frame_id) 125 | end 126 | 127 | function TransformListener:getCacheLength() 128 | local duration = ros.Duration() 129 | f.getCacheLength(self.o, duration:cdata()) 130 | return duration 131 | end 132 | 133 | function TransformListener:getTFPrefix() 134 | local prefix = std.String() 135 | f.getTFPrefix(self.o, prefix:cdata()) 136 | return prefix -- NOTE(review): returns the std.String object, whereas resolve()/getParent() return Lua strings via :get() 137 | end 
138 | -------------------------------------------------------------------------------- /lua/StorageReader.lua: -------------------------------------------------------------------------------- 1 | local ffi = require 'ffi' 2 | local torch = require 'torch' 3 |
local ros = require 'ros.env' 4 | 5 | local StorageReader = torch.class('ros.StorageReader', ros) 6 | local SIZE_OF_UINT32 = ffi.sizeof('uint32_t') 7 | 8 | local function ensurePosReadable(self, pos) 9 | if pos < 0 or pos > self.length then 10 | error(string.format('Read position out of range (buffer size: %d, read position: %d).', self.length, pos)) 11 | end 12 | end 13 | 14 | function StorageReader:__init(storage, offset, length, byteOrder, serialization_handlers) 15 | byteOrder = byteOrder or ffi.abi('le') and 'le' or 'be' 16 | if byteOrder ~= 'le' then 17 | error('Big-endian systems not yet supported.') 18 | end 19 | 20 | if not torch.isTypeOf(storage, torch.ByteStorage) then 21 | error('argument 1: torch.ByteStorage expected') 22 | end 23 | 24 | offset = offset or 0 25 | if offset < 0 or offset > storage:size() then 26 | error('argument 2: offset outside storage bounds') 27 | end 28 | 29 | self.storage = storage 30 | self.data = storage:data() 31 | self.offset = offset or 0 32 | self.length = length or storage:size() 33 | self.length = math.min(self.length, storage:size()) 34 | self.serialization_handlers = serialization_handlers 35 | end 36 | 37 | local function createReadMethod(type) 38 | local element_size = ffi.sizeof(type) 39 | if ffi.arch == 'arm' then 40 | -- use ffi.copy() instead of plain cast on ARM to avoid bus errors 41 | local buffer = ffi.typeof(type .. '[1]')() 42 | return function(self, offset) 43 | local offset_ = offset or self.offset 44 | ensurePosReadable(self, offset_ + element_size - 1) 45 | if not offset then 46 | self.offset = self.offset + element_size 47 | end 48 | ffi.copy(buffer, self.data + offset_, element_size) 49 | return buffer[0] 50 | end 51 | else 52 | local ptr_type = ffi.typeof(type .. 
'*') 53 | return function(self, offset) 54 | local offset_ = offset or self.offset 55 | ensurePosReadable(self, offset_ + element_size - 1) 56 | if not offset then 57 | self.offset = self.offset + element_size 58 | end 59 | return ffi.cast(ptr_type, self.data + offset_)[0] 60 | end 61 | end 62 | end 63 | 64 | StorageReader.readInt8 = createReadMethod('int8_t') 65 | StorageReader.readInt16 = createReadMethod('int16_t') 66 | StorageReader.readInt32 = createReadMethod('int32_t') 67 | StorageReader.readInt64 = createReadMethod('int64_t') 68 | StorageReader.readUInt8 = createReadMethod('uint8_t') 69 | StorageReader.readUInt16 = createReadMethod('uint16_t') 70 | StorageReader.readUInt32 = createReadMethod('uint32_t') 71 | StorageReader.readUInt64 = createReadMethod('uint64_t') 72 | StorageReader.readFloat32 = createReadMethod('float') 73 | StorageReader.readFloat64 = createReadMethod('double') 74 | 75 | function StorageReader:readString(offset) 76 | local offset_ = offset or self.offset 77 | local length = self:readUInt32(offset) 78 | offset_ = offset_ + SIZE_OF_UINT32 79 | ensurePosReadable(self, offset_ + length - 1) 80 | if not offset then 81 | self.offset = offset_ + length 82 | end 83 | return ffi.string(self.data + offset_, length) 84 | end 85 | 86 | function StorageReader:readTensor(tensor_ctor, offset, fixed_array_size) 87 | local offset_ = offset or self.offset 88 | local n = fixed_array_size or self:readUInt32(offset_) 89 | if fixed_array_size == nil then 90 | offset_ = offset_ + SIZE_OF_UINT32 91 | end 92 | 93 | local t, sizeInBytes 94 | if tensor_ctor == torch.ByteTensor then -- special handling for ByteTensor (direct view into storage for perf reasons) 95 | sizeInBytes = n 96 | ensurePosReadable(self, offset_ + sizeInBytes - 1) 97 | t = tensor_ctor(self.storage, offset_ + 1, n) -- offset_ + 1 to handle zero based arrays (C) vs. 
one based arrays (lua) 98 | else 99 | t = tensor_ctor() 100 | sizeInBytes = n * t:elementSize() 101 | ensurePosReadable(self, offset_ + sizeInBytes - 1) 102 | t:resize(n) 103 | ffi.copy(t:data(), self.data + offset_, sizeInBytes) 104 | end 105 | if not offset then 106 | self.offset = offset_ + sizeInBytes 107 | end 108 | return t 109 | end 110 | 111 | function StorageReader:setOffset(offset) 112 | ensurePosReadable(self, offset) 113 | self.offset = offset 114 | end 115 | 116 | function StorageReader:getHandler(message_type) 117 | return self.serialization_handlers and self.serialization_handlers[message_type] 118 | end 119 | 120 | local function createReadTensorMethod(tensor_ctor) 121 | return function(self, offset) 122 | return self:readTensor(tensor_ctor, offset) 123 | end 124 | end 125 | 126 | StorageReader.readByteTensor = createReadTensorMethod(torch.ByteTensor) 127 | StorageReader.readIntTensor = createReadTensorMethod(torch.IntTensor) 128 | StorageReader.readShortTensor = createReadTensorMethod(torch.ShortTensor) 129 | StorageReader.readLongTensor = createReadTensorMethod(torch.LongTensor) 130 | StorageReader.readFloatTensor = createReadTensorMethod(torch.FloatTensor) 131 | StorageReader.readDoubleTensor = createReadTensorMethod(torch.DoubleTensor) 132 | -------------------------------------------------------------------------------- /demo/robotiq_c_model.lua: -------------------------------------------------------------------------------- 1 | ros = require 'ros' 2 | 3 | --[[ 4 | make sure robitq gripper node is running, e.g. 
for connection via modbus RTU run: 5 | 6 | rosrun robotiq_c_model_control CModelRtuNode.py /dev/ttyUSB1 7 | ]] 8 | 9 | ros.init('robotiq_c_model_demo') 10 | nodehandle = ros.NodeHandle() 11 | 12 | grippers = {} 13 | 14 | RobotiqCModel = torch.class('grippers.RobotiqCModel', grippers) 15 | 16 | local function clamp(x, lo, hi) 17 | if x < lo then 18 | return lo 19 | elseif x > hi then 20 | return hi 21 | else 22 | return x 23 | end 24 | end 25 | 26 | local CModel_robot_input_spec = ros.MsgSpec('robotiq_c_model_control/CModel_robot_input') 27 | local CModel_robot_output_spec = ros.MsgSpec('robotiq_c_model_control/CModel_robot_output') 28 | 29 | 30 | local InitStatus = { 31 | Reset = 0, -- Gripper reset. 32 | Activation = 1 -- Gripper activation. 33 | } 34 | 35 | local ActionStatus = { 36 | Stopped = 0, -- Stopped (or performing activation / automatic release). 37 | GoToPosition = 1 -- Go to Position Request. 38 | } 39 | 40 | local GripperStatus = { 41 | Reset = 0, -- Gripper is in reset ( or automatic release ) state. See Fault Status if Gripper is activated. 42 | Activating = 1, -- Activation in progress. 43 | Activated = 3 -- Activation is completed. 44 | } 45 | 46 | local ObjStatus = { 47 | Moving = 0, -- Fingers are in motion towards requested position. No object detected. 48 | DetectedOpening = 1, -- Fingers have stopped due to a contact while opening before requested position. Object detected opening. 49 | DetectedClosing = 2, -- Fingers have stopped due to a contact while closing before requested position. Object detected closing. 50 | NoObject = 3 -- Fingers are at requested position. No object detected or object has been lost / dropped. 51 | } 52 | 53 | local FaultStatus = { 54 | NoFault = 0x00, -- No fault (LED is blue) 55 | 56 | -- Priority faults (LED is blue) 57 | ActionDelayed = 0x05, -- Action delayed, activation (reactivation) must be completed prior to renewed action. 58 | Deactivated = 0x07, -- The activation bit must be set prior to action.
59 | 60 | -- Minor faults (LED continuous red) 61 | HighTemp = 0x08, -- Maximum operating temperature exceeded, wait for cool-down. 62 | 63 | -- Major faults (LED blinking red/blue) - Reset is required (rising edge on activation bit rACT needed). 64 | LowVoltage = 0x0A, -- Under minimum operating voltage. 65 | AutoReleaseBusy = 0x0B, -- Automatic release in progress. 66 | CPUFault = 0x0C, -- Internal processor fault. 67 | ActivationFault = 0x0D, -- Activation fault, verify that no interference or other error occurred. 68 | Overcurrent = 0x0E, -- Overcurrent triggered. 69 | AutoReleaseDone = 0x0F -- Automatic release completed. 70 | } 71 | 72 | function RobotiqCModel:__init(nodehandle) 73 | self.nodehandle = nodehandle 74 | self.input = self.nodehandle:subscribe("/CModelRobotInput", CModel_robot_input_spec, 100) 75 | self.last_state = ros.Message(CModel_robot_input_spec) 76 | self.publisher = nodehandle:advertise("/CModelRobotOutput", CModel_robot_output_spec, 100) 77 | if not self.publisher:waitForSubscriber(1, 5) then 78 | error('Timeout: No subscription for CModelRobotOutput.') 79 | end 80 | self:reset() 81 | end 82 | 83 | function RobotiqCModel:shutdown() 84 | if self.input then 85 | self.input:shutdown() 86 | self.input = nil 87 | end 88 | end 89 | 90 | function RobotiqCModel:reset() 91 | self.cmd = ros.Message(CModel_robot_output_spec) 92 | self.cmd.rACT = 0 93 | self.publisher:publish(self.cmd) 94 | end 95 | 96 | function RobotiqCModel:publish() 97 | self.publisher:publish(self.cmd) 98 | end 99 | 100 | function RobotiqCModel:activate() 101 | self.cmd = ros.Message(CModel_robot_output_spec) 102 | self.cmd.rACT = 1 103 | self.cmd.rGTO = 1 104 | self.cmd.rSP = 255 105 | self.cmd.rFR = 10 106 | self:publish() 107 | end 108 | 109 | function RobotiqCModel:setSpeed(speed) 110 | self.cmd.rSP = clamp(speed, 0, 255) 111 | self:publish() 112 | end 113 | 114 | function RobotiqCModel:setForce(force) 115 | self.cmd.rFR = clamp(force, 0, 255) 116 | self:publish() 117 | end 

function RobotiqCModel:setPosition(pos)
  self.cmd.rPR = clamp(pos, 0, 255)
  self:publish()
end

function RobotiqCModel:open()
  self:setPosition(0)
end

function RobotiqCModel:close()
  self:setPosition(255)
end

--- Drain all pending status messages into self.last_state.
-- Does nothing once shutdown() has released the subscriber.
function RobotiqCModel:spin()
  if not self.input then
    return  -- already shut down; self.input was cleared by shutdown()
  end
  -- process available messages
  while self.input:hasMessage() do
    self.input:read(100, self.last_state)
  end
end

function RobotiqCModel:__tostring()
  return tostring(self.last_state)
end

spinner = ros.AsyncSpinner()
spinner:start()

gripper = grippers.RobotiqCModel(nodehandle)

gripper:reset()
gripper:activate()
gripper:setSpeed(5)
gripper:setForce(0)
ros.spinOnce()
gripper:open()
ros.spinOnce()
sys.sleep(2)
gripper:close()
ros.spinOnce()
sys.sleep(2)
gripper:shutdown()

--------------------------------------------------------------------------------
-- /lua/CallbackQueue.lua:
--------------------------------------------------------------------------------
--- Callback queue to handle the callbacks within ROS
-- @classmod CallbackQueue

local ffi = require 'ffi'
local torch = require 'torch'
local ros = require 'ros.env'
local utils = require 'ros.utils'

local CallbackQueue = torch.class('ros.CallbackQueue', ros)

-- Resolve the ros_CallbackQueue_* C entry points into a method table.
-- Local so that loading this module does not define (or overwrite) a global `init`.
local function init()
  local CallbackQueue_method_names = {
    'new',
    'delete',
    'callOne',
    'callAvailable',
    'waitCallAvailable',
    'isEmpty',
    'clear',
    'enable',
    'disable',
    'isEnabled'
  }

  return utils.create_method_table("ros_CallbackQueue_", CallbackQueue_method_names)
end

local f = init()

--- Constructor
-- @tparam[opt=true] bool enabled indicating if the queue is enabled or disabled after construction
function CallbackQueue:__init(enabled)
  -- `enabled or true` evaluated to true even for an explicit `false`, so a disabled
  -- queue could never be constructed. Only apply the default when omitted.
  if enabled == nil then
    enabled = true
  end
  self.o = f.new(enabled)
  -- five callback lists, indexed by the `round` argument of registerSpinCallback()
  self.spin_callbacks = { {}, {}, {}, {}, {} }
end

--- Get the underlying c++ instance
-- @return pointer to the c++ instance
function CallbackQueue:cdata()
  return self.o
end

--- Register a function for callback
-- Registering the same function twice in the same round has no effect.
-- @tparam func fn The function to execute as callback function
-- @tparam[opt=1] int round Priority/order of the callback (1..5). If unsure which value to use, omit the parameter
function CallbackQueue:registerSpinCallback(fn, round)
  local list = self.spin_callbacks[round or 1]
  if utils.indexOf(list, fn) < 0 then
    list[#list+1] = fn
  end
end

--- Remove a function from the callback list
-- @tparam func fn The function to be removed
-- @tparam[opt=1] int round Priority/order used when it was added. Has to be the same value as used for registerSpinCallback()
function CallbackQueue:unregisterSpinCallback(fn, round)
  local list = self.spin_callbacks[round or 1]
  local i = utils.indexOf(list, fn)
  if i >= 0 then
    table.remove(list, i)
  end
end

--- Trigger all callbacks one time
-- Rounds are processed in ascending order; iteration stops early once ros.ok() turns false.
function CallbackQueue:callSpinCallbacks()
  for _, cbs in ipairs(self.spin_callbacks) do
    -- iterate over a copy: (un)registrations made by a callback take effect on the next pass
    local isolation_copy = utils.cloneList(cbs)
    for _, callback in ipairs(isolation_copy) do
      if not ros.ok() then return end
      callback()
    end
  end
end
jit.off(CallbackQueue.callSpinCallbacks)

--- Pop a single callback off the front of the queue and invoke it. If the callback was not ready to be called, pushes it back onto the queue.
-- @tparam ?number|ros:Duration timeout Timeout for the callback, either in seconds (fractional numbers like 1.5 possible) or as an instance of ros:Duration
-- @return the result value returned by the underlying C++ callOne()
function CallbackQueue:callOne(timeout)
  if timeout and not torch.isTypeOf(timeout, ros.Duration) then
    timeout = ros.Duration(timeout)
  end
  return f.callOne(self.o, utils.cdata(timeout))
end
jit.off(CallbackQueue.callOne)

--- Invoke all callbacks currently in the queue.
-- If a callback was not ready to be called, pushes it back onto the queue.
-- Afterwards the registered spin callbacks are run (see callSpinCallbacks()).
-- @tparam ?number|ros:Duration timeout Timeout for the callback, either in seconds (fractional numbers like 1.5 possible) or as an instance of ros:Duration
-- @tparam[opt] bool no_spin_callbacks if true, skip running the registered spin callbacks
function CallbackQueue:callAvailable(timeout, no_spin_callbacks)
  if timeout and not torch.isTypeOf(timeout, ros.Duration) then
    timeout = ros.Duration(timeout)
  end
  f.callAvailable(self.o, utils.cdata(timeout))

  if not no_spin_callbacks then
    self:callSpinCallbacks()
  end
end
jit.off(CallbackQueue.callAvailable)

--- Waits until the next callback is enqueued or the timeout is expired
-- @tparam ?number|ros:Duration timeout Timeout for the callback, either in seconds (fractional numbers like 1.5 possible) or as an instance of ros:Duration
-- @treturn bool true if a callback is enqueued, false if the queue is disabled or the timeout expired
function CallbackQueue:waitCallAvailable(timeout)
  if timeout and not torch.isTypeOf(timeout, ros.Duration) then
    timeout = ros.Duration(timeout)
  end
  return f.waitCallAvailable(self.o, utils.cdata(timeout))
end
jit.off(CallbackQueue.waitCallAvailable)

--- Returns whether or not the queue is empty
-- @treturn bool true if queue is empty, false otherwise
function CallbackQueue:isEmpty()
  return f.isEmpty(self.o)
end

--- Removes all callbacks from the queue. Does not wait for calls currently in progress to finish.
function CallbackQueue:clear()
  f.clear(self.o)
end

--- Enable the queue
function CallbackQueue:enable()
  f.enable(self.o)
end

--- Disable the queue, meaning any calls to addCallback() will have no effect.
function CallbackQueue:disable()
  f.disable(self.o)
end

--- Returns whether or not this queue is enabled.
-- @treturn bool true if the queue is enabled, false otherwise
function CallbackQueue:isEnabled()
  return f.isEnabled(self.o)
end

function CallbackQueue:__tostring()
  return string.format("CallbackQueue {empty: %s, enabled: %s}", self:isEmpty(), self:isEnabled())
end

-- Shared default queue, created enabled.
ros.DEFAULT_CALLBACK_QUEUE = ros.CallbackQueue()

--------------------------------------------------------------------------------
-- /lua/Publisher.lua:
--------------------------------------------------------------------------------
--- Manages an advertisement on a specific topic.
-- A Publisher should always be created through a call to
-- NodeHandle::advertise(), or copied from one that was. Once all
-- copies of a specific Publisher go out of scope, any subscriber
-- status callbacks associated with that handle will stop being
-- called. Once all Publishers for a given topic go out of scope the
-- topic will be unadvertised.
-- @classmod Publisher
local ffi = require 'ffi'
local torch = require 'torch'
local ros = require 'ros.env'
local utils = require 'ros.utils'
local tf = ros.tf

local Publisher = torch.class('ros.Publisher', ros)
local Publisher_ptr_ct = ffi.typeof('ros_Publisher *')

-- Resolve the ros_Publisher_* C entry points into a method table.
-- Local so that loading this module does not define (or overwrite) a global `init`.
local function init()
  local Publisher_method_names = {
    'clone',
    'delete',
    'shutdown',
    'getTopic',
    'getNumSubscribers',
    'isLatched',
    'publish'
  }

  return utils.create_method_table('ros_Publisher_', Publisher_method_names)
end

local f = init()

--- Constructor.
-- A Publisher should always be created through a call to
-- NodeHandle::advertise(), or copied from one that was. Therefore parameters are not documented in detail
-- @param ptr non-NULL `ros_Publisher *`; it is deleted (f.delete) when garbage collected
-- @param msg_spec message specification of the advertised topic
-- @param connect_cb connect callback; its :free() is called when ptr is collected
-- @param disconnect_cb disconnect callback; its :free() is called when ptr is collected
-- @param serialization_handlers optional table mapping message type names to custom serialization handlers
function Publisher:__init(ptr, msg_spec, connect_cb, disconnect_cb, serialization_handlers)
  -- `not ffi.typeof(ptr) == ct` parsed as `(not ffi.typeof(ptr)) == ct` and never fired.
  -- Check the C type first, then reject NULL (a NULL cdata is truthy but == nil).
  if not ffi.istype(Publisher_ptr_ct, ptr) or ptr == nil then
    error('argument 1: ros::Publisher * expected.')
  end
  self.o = ptr
  self.msg_spec = msg_spec
  self.connect_cb = connect_cb
  self.disconnect_cb = disconnect_cb
  self.serialization_handlers = serialization_handlers

  ffi.gc(ptr,
    function(p)
      f.delete(p)
      if self.connect_cb ~= nil then
        self.connect_cb:free()      -- free connect callback
        self.connect_cb = nil
      end
      if self.disconnect_cb ~= nil then
        self.disconnect_cb:free()   -- free disconnect callback
        self.disconnect_cb = nil
      end
    end
  )
end

--- Get the cdata of this object
function Publisher:cdata()
  return self.o
end

--- Create a deep copy
-- The copy owns its own C++ handle (released on garbage collection) and shares
-- msg_spec and serialization_handlers with the original; connect/disconnect
-- callbacks stay owned by the original.
-- @treturn ros.Publisher The new object
function Publisher:clone()
  local c = torch.factory('ros.Publisher')()
  -- previously the cloned handle had no finalizer and was never deleted
  rawset(c, 'o', ffi.gc(f.clone(self.o), f.delete))
  rawset(c, 'msg_spec', self.msg_spec)
  -- without the handlers, publishing custom (non ros.Message) types failed on clones
  rawset(c, 'serialization_handlers', self.serialization_handlers)
  return c
end

--- Shutdown the publisher.
-- This method usually does not need to be explicitly called, as
-- automatic shutdown happens when all copies of this Publisher go out
-- of scope. Calling it again after a shutdown has no effect.
function Publisher:shutdown()
  if self.o == nil then
    return
  end
  f.shutdown(self.o)
  self.o = nil
end

--- Returns the topic that this Publisher will publish on.
-- @treturn string The topic that this Publisher will publish on.
function Publisher:getTopic()
  -- `std` is not a local of this module; resolve the wrapper through ros.std
  local s = ros.std.String()
  f.getTopic(self.o, s:cdata())
  return s:get()
end

--- Number of subscribers of this publisher.
-- @treturn int Number of subscribers
function Publisher:getNumSubscribers()
  return f.getNumSubscribers(self.o)
end

--- Returns whether or not this topic is latched.
-- @treturn bool whether or not this topic is latched
function Publisher:isLatched()
  return f.isLatched(self.o)
end

--- Publish the given message
-- A ros.Message is serialized directly; any other value is written with the
-- serialization handler registered for this publisher's message type.
-- @tparam ros.Message msg The message to publish
function Publisher:publish(msg)
  local sw = ros.StorageWriter(nil, 0, self.serialization_handlers)
  if torch.isTypeOf(msg, ros.Message) then
    -- serialize message to storage writer
    msg:serialize(sw)
  else
    -- get serialization handler by message type
    local handler = sw:getHandler(self.msg_spec.type)
    if handler == nil then
      error('No serialization handler defined for custom message type')
    end

    local offset = sw.offset
    sw:writeUInt32(0)   -- reserve space for message size
    handler:write(sw, msg)
    sw:writeUInt32(sw.offset - 4, offset)   -- back-patch the size (excluding the size field)
  end

  sw:shrinkToFit()

  --[[ debug
  print('Publisher:publish(msg)')
  print('sending:')
  print(sw.storage) ]]

  f.publish(self.o, sw.storage:cdata(), 0, sw.length)
end

--- Wait for subscribers
-- Spins (ros.spinOnce) while polling every millisecond.
-- @tparam[opt=1] int min_count Minimum numbers of subscribers to wait for
-- @tparam[opt] ?ros.Duration|number timeout Maximum number of seconds to wait for subscribers; waits forever if omitted
-- @treturn bool true if the number of subscribers is reached, false if timed out
function Publisher:waitForSubscriber(min_count, timeout)
  if not ros.Time.isValid() then
    ros.Time.init()
  end

  min_count = min_count or 1
  if timeout and not torch.isTypeOf(timeout, ros.Duration) then
    timeout = ros.Duration(timeout)
  end

  local start = ros.Time.getNow()
  while true do
    if timeout and (ros.Time.getNow() - start) > timeout then
      return false
    elseif self:getNumSubscribers() >= min_count then
      return true
    end
    ros.spinOnce()
    sys.sleep(0.001)
  end
end

--- Create an empty message of the type this publisher advertises.
-- @treturn ros.Message
function Publisher:createMessage()
  return ros.Message(self.msg_spec)
end

--------------------------------------------------------------------------------
-- /lua/actionlib/ServerGoalHandle.lua:
--------------------------------------------------------------------------------
local torch = require 'torch'
local ros = require 'ros.env'
local GoalStatus = require 'ros.actionlib.GoalStatus'
local std = ros.std
local actionlib = ros.actionlib


--- Server-side handle of a single action goal: tracks its GoalStatus message and
-- reports status, results and feedback through the owning action server.
local ServerGoalHandle = torch.class('ros.actionlib.ServerGoalHandle', actionlib)


-- Set status/text and publish the server's status.
local function ServerGoalHandle_setGoalStatus(self, status, text)
  self.goal_status.status = status
  self.goal_status.text = text or ''
  self.action_server:publishStatus()
end


-- Set a terminal status/text, publish the result and record when the handle finished.
local function ServerGoalHandle_setGoalResult(self, status, text, result)
  self.goal_status.status = status
  self.goal_status.text = text or ''
  self.action_server:publishResult(self.goal_status, result)
  self.handle_destruction_time = ros.Time.now()
end


function ServerGoalHandle:__init(action_server, goal_id, status, goal)
  self.action_server = action_server
  self.goal_status = ros.Message('actionlib_msgs/GoalStatus') -- http://docs.ros.org/jade/api/actionlib_msgs/html/msg/GoalStatus.html
  self.goal_status.goal_id:assign(goal_id)

  -- a zero stamp is replaced by the current time
  if self.goal_status.goal_id.stamp == ros.Time() then
    self.goal_status.goal_id.stamp = ros.Time.now()
  end

  self.goal_status.status = status
  self.goal = goal
end


function ServerGoalHandle:createResult()
  return self.action_server:createResult()
end


--- Create a feedback message via the action server (historical misspelled name).
function ServerGoalHandle:createFeeback()
  return self.action_server:createFeeback()
end
-- correctly spelled alias; the original name is kept for existing callers
ServerGoalHandle.createFeedback = ServerGoalHandle.createFeeback


function ServerGoalHandle:setAccepted(text)
  text = text or ''
  ros.DEBUG_NAMED("actionlib", "Accepting goal, id: %s, stamp: %.2f", self:getGoalID().id, self:getGoalID().stamp:toSec())
  if self.goal_status.status == GoalStatus.PENDING then
    ServerGoalHandle_setGoalStatus(self, GoalStatus.ACTIVE, text)
  elseif self.goal_status.status == GoalStatus.RECALLING then
    ServerGoalHandle_setGoalStatus(self, GoalStatus.PREEMPTING, text)
  else
    ros.ERROR_NAMED("actionlib", "To transition to an active state, the goal must be in a pending or recalling state, it is currently in state: %d",
      self.goal_status.status)
  end
end


function ServerGoalHandle:setCanceled(result, text)
  text = text or ''
  ros.DEBUG_NAMED("actionlib", "Setting status to canceled on goal, id: %s, stamp: %.2f", self:getGoalID().id, self:getGoalID().stamp:toSec())
  if self.goal_status.status == GoalStatus.PENDING or self.goal_status.status == GoalStatus.RECALLING then
    ServerGoalHandle_setGoalResult(self, GoalStatus.RECALLED, text, result)
  elseif self.goal_status.status == GoalStatus.ACTIVE or self.goal_status.status == GoalStatus.PREEMPTING then
    ServerGoalHandle_setGoalResult(self, GoalStatus.PREEMPTED, text, result)
  else
    ros.ERROR_NAMED("actionlib", "To transition to a cancelled state, the goal must be in a pending, recalling, active, or preempting state, it is currently in state: %d",
      self.goal_status.status)
  end
end


function ServerGoalHandle:setRejected(result, text)
  text = text or ''
  ros.DEBUG_NAMED("actionlib", "Setting status to rejected on goal, id: %s, stamp: %.2f", self:getGoalID().id, self:getGoalID().stamp:toSec())
  if self.goal_status.status == GoalStatus.PENDING or self.goal_status.status == GoalStatus.RECALLING then
    ServerGoalHandle_setGoalResult(self, GoalStatus.REJECTED, text, result)
  else
    ros.ERROR_NAMED("actionlib", "To transition to a rejected state, the goal must be in a pending or recalling state, it is currently in state: %d",
      self.goal_status.status)
  end
end


function ServerGoalHandle:setAborted(result, text)
  text = text or ''
  ros.DEBUG_NAMED("actionlib", "Setting status to aborted on goal, id: %s, stamp: %.2f", self:getGoalID().id, self:getGoalID().stamp:toSec())
  if self.goal_status.status == GoalStatus.PREEMPTING or self.goal_status.status == GoalStatus.ACTIVE then
    ServerGoalHandle_setGoalResult(self, GoalStatus.ABORTED, text, result)
  else
    ros.ERROR_NAMED("actionlib", "To transition to an aborted state, the goal must be in a preempting or active state, it is currently in state: %d",
      self.goal_status.status)
  end
end


function ServerGoalHandle:setSucceeded(result, text)
  text = text or ''
  ros.DEBUG_NAMED("actionlib", "Setting status to succeeded on goal, id: %s, stamp: %.2f", self:getGoalID().id, self:getGoalID().stamp:toSec())
  if self.goal_status.status == GoalStatus.PREEMPTING or self.goal_status.status == GoalStatus.ACTIVE then
    ServerGoalHandle_setGoalResult(self, GoalStatus.SUCCEEDED, text, result)
  else
    ros.ERROR_NAMED("actionlib", "To transition to a succeeded state, the goal must be in a preempting or active state, it is currently in state: %d",
      self.goal_status.status)
  end
end


function ServerGoalHandle:publishFeedback(feedback)
  self.action_server:publishFeedback(self.goal_status, feedback)
end


function ServerGoalHandle:getGoal()
  return self.goal
end


function ServerGoalHandle:getGoalID()
  return self.goal_status.goal_id
end


function ServerGoalHandle:getGoalStatus()
  return self.goal_status
end


--- Move a PENDING goal to RECALLING or an ACTIVE goal to PREEMPTING.
-- @treturn bool true if a transition happened, false otherwise
function ServerGoalHandle:setCancelRequested()
  ros.DEBUG_NAMED("actionlib", "Transisitoning to a cancel requested state on goal id: %s, stamp: %.2f", self:getGoalID().id, self:getGoalID().stamp:toSec())
  if self.goal_status.status == GoalStatus.PENDING then
    ServerGoalHandle_setGoalStatus(self, GoalStatus.RECALLING, 'RECALLING')
    return true
  end
  if self.goal_status.status == GoalStatus.ACTIVE then
    ServerGoalHandle_setGoalStatus(self, GoalStatus.PREEMPTING, 'PREEMPTING')
    return true
  end
  return false
end

--------------------------------------------------------------------------------
-- /lua/Duration.lua:
--------------------------------------------------------------------------------
--- ROS data class to handle time intervals and durations
-- @classmod Duration

local ffi = require 'ffi'
local torch = require 'torch'
local ros = require 'ros.env'
local utils = require 'ros.utils'

local Duration = torch.class('ros.Duration', ros)

-- Resolve the ros_Duration_* C entry points into a method table.
-- Local so that loading this module does not define (or overwrite) a global `init`.
local function init()
  local Duration_method_names = {
    "new",
    "clone",
    "delete",
    "set",
    "assign",
    "get_sec",
    "set_sec",
    "get_nsec",
    "set_nsec",
    "add",
    "sub",
    "mul",
    "eq",
    "lt",
    "toSec",
    "fromSec",
    "fromNSec",
    "isZero",
    "sleep"
  }

  return utils.create_method_table("ros_Duration_", Duration_method_names)
end

local f = init()

--- Construct a time duration
-- @tparam[opt] ?number|ros.Duration _1 If number: time duration in seconds, fractional number possible; if ros.Duration: value to copy
-- @tparam[opt] number _2 If present, _1 represents the seconds and _2 represents the nanoseconds
function Duration:__init(_1, _2)
  -- NOTE(review): unlike Publisher, no ffi.gc(…, f.delete) finalizer is attached to the
  -- object returned by f.new(); confirm whether it is released elsewhere.
  self.o = f.new()
  if _1 or _2 then
    self:set(_1, _2)
  end
end

--- Get the underlying data structure
-- @return pointer to the c++ instance
function Duration:cdata()
  return self.o
end

--- Creates a deep copy of the object
-- @return A copy of the object
function Duration:clone()
  local c = torch.factory('ros.Duration')()
  rawset(c, 'o', f.clone(self.o))
  return c
end

--- Set the time value
-- @tparam ?number|ros.Duration _1 If number: time duration in seconds, fractional number possible
-- @tparam[opt] number _2 If present, _1 represents the seconds and _2 represents the nanoseconds
function Duration:set(_1, _2)
  if torch.isTypeOf(_1, ros.Duration) then
    self:assign(_1)
  elseif not _2 then
    self:fromSec(_1)
  else
    f.set(self.o, _1, _2)
  end
end

--- Get the time value as seconds and nanoseconds
-- @treturn number seconds
-- @treturn number nanoseconds
function Duration:get()
  return self:get_sec(), self:get_nsec()
end

--- Assign operation
-- @tparam ros.Duration other Instance to set the values from
function Duration:assign(other)
  f.assign(self.o, other:cdata())
end

--- Get the seconds part of the stored duration
-- @treturn number seconds
function Duration:get_sec()
  return f.get_sec(self.o)
end

--- Set the seconds part of the duration
-- @tparam number sec seconds
function Duration:set_sec(sec)
  f.set_sec(self.o, sec)
end

--- Get the nanoseconds part of the stored duration
-- @treturn number nanoseconds
function Duration:get_nsec()
  return f.get_nsec(self.o)
end

--- Set the nanoseconds part of the duration
-- @tparam number nsec nanoseconds
function Duration:set_nsec(nsec)
  f.set_nsec(self.o, nsec)
end

--- Add two durations and return the result
-- @tparam ?number|ros:Duration other The value to add, either the number of seconds (fractional numbers possible) or a ros:Duration
-- @tparam[opt] ros:Duration result If present, the sum is assigned to result. Otherwise the operation is performed inplace.
function Duration:add(other, result)
  result = result or self
  if type(other) == 'number' then
    other = ros.Duration(other)
  end
  f.add(self.o, other:cdata(), result:cdata())
  return result
end

--- Subtract a duration from this and return the result
-- @tparam ?number|ros:Duration other The value to subtract, either the number of seconds (fractional numbers possible) or a ros:Duration
-- @tparam[opt] ros:Duration result If present, the difference is assigned to result. Otherwise the operation is performed inplace.
function Duration:sub(other, result)
  result = result or self
  if type(other) == 'number' then
    other = ros.Duration(other)
  end
  f.sub(self.o, other:cdata(), result:cdata())
  return result
end

--- Multiply this duration by a scalar and return the result
-- @tparam number factor The scalar to multiply with
-- @tparam[opt] ros:Duration result If present, the product is assigned to result. Otherwise the operation is performed inplace.
function Duration:mul(factor, result)
  result = result or self
  f.mul(self.o, factor, result:cdata())
  return result
end

--- Operator `*`: returns a new Duration; supports both `d * k` and `k * d`.
function Duration:__mul(factor)
  -- the parameter used to be named `f`, shadowing the FFI method table
  local duration = self
  if type(duration) == 'number' then
    duration, factor = factor, duration
  end
  return duration:mul(factor, ros.Duration())
end

--- Operator `-`: returns a new Duration (right operand may be seconds as a number).
function Duration:__sub(d)
  local result = ros.Duration()
  return self:sub(d, result)
end

--- Operator `+`: returns a new Duration; supports a number of seconds on either side.
function Duration:__add(d)
  local duration = self
  if type(duration) == 'number' then
    duration, d = d, duration
  end
  return duration:add(d, ros.Duration())
end

function Duration:__eq(other)
  return self ~= nil and other ~=nil and f.eq(self.o, other:cdata())
end

function Duration:__lt(other)
  return self ~= nil and other ~= nil and f.lt(self.o, other:cdata())
end

function Duration:__le(other)
  return self ~= nil and other ~= nil and (f.lt(self.o, other:cdata()) or f.eq(self.o, other:cdata()))
end

--- Converts the stored duration to seconds with fractional part
-- @treturn number Number of seconds
function Duration:toSec()
  return f.toSec(self.o)
end

--- Stores a (fractional) number of seconds in this data structure
-- @tparam number sec Number of seconds
function Duration:fromSec(sec)
  f.fromSec(self.o, sec)
end

--- Check if the Duration is zero
-- @treturn bool
function Duration:isZero()
  -- previously an undefined global `sec` was passed as a surplus argument
  return f.isZero(self.o)
end

--- Sleep for the (fractional) number of seconds stored in this data structure
function Duration:sleep()
  f.sleep(self.o)
end

function Duration:__tostring()
  return string.format("%f", self:toSec())
end

--------------------------------------------------------------------------------
-- /lua/StorageWriter.lua:
--------------------------------------------------------------------------------
local ffi = require 'ffi'
local torch = require 'torch'
local ros = require 'ros.env'

local StorageWriter = torch.class('ros.StorageWriter', ros)
local SIZE_OF_UINT32 = ffi.sizeof('uint32_t')

-- Make sure byte position `pos` (exclusive end, zero-based) lies inside the storage.
-- Grows the capacity to max(capacity * growth_factor, pos) when needed (default factor 2)
-- and extends the recorded length so that it covers `pos`.
local function ensurePosWriteable(self, pos, growth_factor)
  local factor = growth_factor or 2
  if pos > self.capacity then
    local grown = self.capacity * factor
    self:setCapacity(grown > pos and grown or pos)
  end
  if pos > self.length then
    self.length = pos
  end
end

--- Create a little-endian serialization writer.
-- @param storage optional torch.ByteStorage (or ros.SerializedMessage) to write into;
--   a new 16-byte ByteStorage is allocated when omitted
-- @param offset zero-based write position (default 0); also the initial length
-- @param serialization_handlers optional table of custom handlers looked up by getHandler()
function StorageWriter:__init(storage, offset, serialization_handlers) -- offset is zero based
  if not ffi.abi('le') then
    error('Big-endian systems not yet supported.')
  end

  local start = offset or 0
  self.offset = start
  self.serialization_handlers = serialization_handlers
  self.length = start

  if not storage then
    self.storage = torch.ByteStorage()
    self:setCapacity(16)
    return
  end

  local accepted = torch.isTypeOf(storage, torch.ByteStorage) or torch.isTypeOf(storage, ros.SerializedMessage)
  if not accepted then
    error('argument 1: torch.ByteStorage expected')
  end
  self.storage = storage
  self.capacity = storage:size()
  self.data = storage:data()
end

--- Raw data pointer of the underlying storage.
-- NOTE(review): __init and setCapacity store the pointer in the instance field `data`,
-- which takes precedence over this class method, so `writer:data()` on an instance
-- appears unreachable — confirm before relying on it.
function StorageWriter:data()
  return self.storage:data()
end

--- Number of bytes currently allocated in the storage.
function StorageWriter:getCapacity()
  return self.capacity
end

--- Resize the storage to exactly `capacity` bytes and refresh the cached data pointer.
function StorageWriter:setCapacity(capacity)
  local storage = self.storage
  storage:resize(capacity)
  self.data = storage:data()
  self.capacity = capacity
end

--- Resynchronize after the storage was modified externally; moves the write position to `newOffset`.
function StorageWriter:storageChanged(newOffset)
  self.data = self.storage:data()
  self.capacity = self.storage:size()
  self.offset = newOffset
  self.length = math.max(self.length, newOffset)
end

--- Shrink the storage to the number of bytes written (self.length).
function StorageWriter:shrinkToFit()
  self:setCapacity(self.length)
end

--- Ensure the storage covers at least `length` bytes (the recorded length never decreases).
-- @param length non-negative byte count
-- @param shrinkToFit if truthy, grow to exactly `length`; otherwise grow by a factor of 1.5
function StorageWriter:setLength(length, shrinkToFit)
  if length == nil or length == false or length < 0 then
    error('argument 1: positive length expected')
  end
  local factor
  if shrinkToFit then
    factor = 0
  else
    factor = 1.5
  end
  ensurePosWriteable(self, length, factor)
end

-- Build a `write<Type>(value[, offset])` method for the scalar C type `ctype_name`.
-- Without `offset` the value goes to the current write position, which then advances
-- by the element size; with `offset` (zero based) the value is written there and the
-- write position is left unchanged. Storage is grown via ensurePosWriteable().
-- (The parameter used to be named `type`, shadowing the Lua builtin.)
local function createWriteMethod(ctype_name)
  local element_size = ffi.sizeof(ctype_name)

  if ffi.arch == 'arm' then
    -- use ffi.copy() instead of plain cast on ARM to avoid bus errors
    local buffer = ffi.typeof(ctype_name .. '[1]')()
    return function(self, value, offset)
      local offset_ = offset or self.offset
      ensurePosWriteable(self, offset_ + element_size)
      buffer[0] = value
      ffi.copy(self.data + offset_, buffer, element_size)
      if not offset then
        self.offset = self.offset + element_size
      end
    end
  end

  local ptr_type = ffi.typeof(ctype_name .. '*')
  return function(self, value, offset)
    local offset_ = offset or self.offset
    ensurePosWriteable(self, offset_ + element_size)
    ffi.cast(ptr_type, self.data + offset_)[0] = value
    if not offset then
      self.offset = self.offset + element_size
    end
  end
end

StorageWriter.writeInt8 = createWriteMethod('int8_t')
StorageWriter.writeInt16 = createWriteMethod('int16_t')
StorageWriter.writeInt32 = createWriteMethod('int32_t')
StorageWriter.writeInt64 = createWriteMethod('int64_t')
StorageWriter.writeUInt8 = createWriteMethod('uint8_t')
StorageWriter.writeUInt16 = createWriteMethod('uint16_t')
StorageWriter.writeUInt32 = createWriteMethod('uint32_t')
StorageWriter.writeUInt64 = createWriteMethod('uint64_t')
StorageWriter.writeFloat32 = createWriteMethod('float')
StorageWriter.writeFloat64 = createWriteMethod('double')

--- Write a string as a uint32 byte length followed by its raw bytes.
-- @tparam string value
function StorageWriter:writeString(value)
  if type(value) ~= 'string' then
    error('argument 1: string expected')
  end
  ensurePosWriteable(self, self.offset + SIZE_OF_UINT32 + #value)
  self:writeUInt32(#value)   -- write length of string
  ffi.copy(self.data + self.offset, value, #value)   -- copy string value
  self.offset = self.offset + #value
end

--- Write a one-dimensional tensor's raw element bytes.
-- Variable-length arrays get a uint32 element-count prefix; when `fixed_array_size`
-- is given no prefix is written and the element count must match it exactly.
-- @param value tensor with at most one dimension
-- @tparam[opt] number fixed_array_size expected element count of a fixed-size array
function StorageWriter:writeTensor(value, fixed_array_size)
  -- only tensors with a single dimension are supported for now (sufficient for ROS array support)
  if not torch.isTensor(value) or value:nDimension() > 1 then
    error('argument 1: tensor with one dimension expected')
  end

  value = value:contiguous()   -- ensure we are dealing with a contiguous piece of memory

  local n = value:nElement()
  local sizeInBytes = n * value:elementSize()

  if fixed_array_size == nil then
    ensurePosWriteable(self, self.offset + SIZE_OF_UINT32 + sizeInBytes)
    self:writeUInt32(n)   -- length of array
  else
    if n ~= fixed_array_size then
      error(string.format('Wrong number of elements in fixed size array (expected: %d; actual: %d).', fixed_array_size, n))
    end
    ensurePosWriteable(self, self.offset + sizeInBytes)
  end
  ffi.copy(self.data + self.offset, value:data(), sizeInBytes)   -- binary data
  self.offset = self.offset + sizeInBytes
end

-- Build a type-checked tensor writer that only accepts tensors of `tensor_ctor`.
-- The optional `fixed_array_size` is forwarded to writeTensor (previously it was dropped,
-- so typed writers could only emit variable-length arrays).
local function createTypedWriteTensorMethod(tensor_ctor)
  return function(self, value, fixed_array_size)
    if not torch.isTypeOf(value, tensor_ctor) then
      error('argument 1: tensor has unexpeted type')
    end
    self:writeTensor(value, fixed_array_size)
  end
end

StorageWriter.writeByteTensor = createTypedWriteTensorMethod(torch.ByteTensor)
StorageWriter.writeShortTensor = createTypedWriteTensorMethod(torch.ShortTensor)
StorageWriter.writeIntTensor = createTypedWriteTensorMethod(torch.IntTensor)
StorageWriter.writeLongTensor = createTypedWriteTensorMethod(torch.LongTensor)
StorageWriter.writeFloatTensor = createTypedWriteTensorMethod(torch.FloatTensor)
StorageWriter.writeDoubleTensor = createTypedWriteTensorMethod(torch.DoubleTensor)

--- Look up the custom serialization handler for `message_type`.
-- @return the handler, or nil/false when no handler table was supplied or none matches
function StorageWriter:getHandler(message_type)
  return self.serialization_handlers and self.serialization_handlers[message_type]
end
--------------------------------------------------------------------------------